├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── conf ├── config.yaml └── task2vec │ ├── montecarlo.yaml │ └── variational.yaml ├── dataset ├── __init__.py ├── cifar.py ├── cub.py ├── dataset.py ├── expansion.py ├── imat.py ├── inat.py └── mnist.py ├── datasets.py ├── main.py ├── models.py ├── plot_distance_cub_inat.py ├── requirements.txt ├── scripts ├── download_cub.sh └── download_inat2018.sh ├── small_datasets_example.ipynb ├── static ├── distance_matrix.png └── taxonomical distance.png ├── support_files ├── cub │ └── final_tasks_map.json └── inat2018 │ ├── classes_tasks.json │ ├── final_tasks_map.json │ └── tasks.json ├── task2vec.py ├── task_similarity.py ├── utils.py └── variational.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below).
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Task2Vec 2 | 3 | This is an implementation of the Task2Vec method described in the paper [Task2Vec: Task Embedding for Meta-Learning](https://arxiv.org/abs/1902.03545). 4 | 5 | 6 | Task2Vec provides vectorial representations of learning tasks (datasets) which can be used to reason about the nature of 7 | those tasks and their relations. 8 | In particular, it provides a fixed-dimensional embedding of the task that is independent of details such as the number of 9 | classes and does not require any understanding of the class label semantics. The distance between embeddings 10 | matches our intuition about semantic and taxonomic relations between different visual tasks 11 | (e.g., tasks based on classifying different types of plants are similar). The resulting vector can be used to 12 | represent a dataset in meta-learning applications, and allows, for example, selecting the best feature extractor for a task 13 | without an expensive brute force search.
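Once two tasks have been embedded (see the Quick start below for the full pipeline), their embeddings can be compared directly. The following is a minimal sketch; it assumes that `task_similarity` exposes a `pdist` helper returning the matrix of pairwise distances, alongside the `plot_distance_matrix` function used later in this README, and that `get_dataset` accepts the same dataset names as the examples below:

```python
from task2vec import Task2Vec
from models import get_model
from datasets import get_dataset
import task_similarity

# Embed two datasets with the same probe architecture (both have 10 classes)
emb_cifar = Task2Vec(get_model('resnet34', pretrained=True, num_classes=10)).embed(get_dataset('cifar10'))
emb_mnist = Task2Vec(get_model('resnet34', pretrained=True, num_classes=10)).embed(get_dataset('mnist'))

# pdist (assumed helper) computes pairwise distances between task embeddings;
# small distances indicate semantically related tasks.
print(task_similarity.pdist([emb_cifar, emb_mnist]))
```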
14 | 15 | ## Quick start 16 | 17 | To compute an embedding using Task2Vec, you just need to provide a dataset and a probe network, for example: 18 | ```python 19 | from task2vec import Task2Vec 20 | from models import get_model 21 | from datasets import get_dataset 22 | 23 | dataset = get_dataset('cifar10') 24 | probe_network = get_model('resnet34', pretrained=True, num_classes=10) 25 | embedding = Task2Vec(probe_network).embed(dataset) 26 | ``` 27 | Task2Vec uses the diagonal of the Fisher Information Matrix to compute an embedding of the task. In this implementation 28 | we provide two methods, `montecarlo` and `variational`. The first is the fastest and is the default, but `variational` 29 | may be more robust in some situations (in particular it is the one used in the paper). You can try it using: 30 | ```python 31 | embedding = Task2Vec(probe_network, method='variational').embed(dataset) 32 | ``` 33 | Now, let's try computing several embeddings and plotting the distance matrix between the tasks: 34 | ```python 35 | from task2vec import Task2Vec 36 | from models import get_model 37 | import datasets 38 | import task_similarity 39 | 40 | dataset_names = ('mnist', 'cifar10', 'cifar100', 'letters', 'kmnist') 41 | dataset_list = [datasets.__dict__[name]('./data')[0] for name in dataset_names] 42 | 43 | embeddings = [] 44 | for name, dataset in zip(dataset_names, dataset_list): 45 | print(f"Embedding {name}") 46 | probe_network = get_model('resnet34', pretrained=True, num_classes=int(max(dataset.targets)+1)).cuda() 47 | embeddings.append( Task2Vec(probe_network, max_samples=1000, skip_layers=6).embed(dataset) ) 48 | task_similarity.plot_distance_matrix(embeddings, dataset_names) 49 | ``` 50 | You can also look at the notebook `small_datasets_example.ipynb` for a runnable implementation of this code snippet. 51 | 52 | ## Experiments on iNaturalist and CUB 53 | 54 | ### Downloading the data 55 | First, decide where you will store all the data. For example: 56 | ``` 57 | export DATA_ROOT=./data 58 | ``` 59 | To download [CUB-200](http://www.vision.caltech.edu/visipedia/CUB-200.html), 60 | from the repository root run: 61 | ```sh 62 | ./scripts/download_cub.sh $DATA_ROOT 63 | ``` 64 | 65 | To download [iNaturalist 2018](https://github.com/visipedia/inat_comp/tree/master/2018), 66 | from the repository root run: 67 | ```sh 68 | ./scripts/download_inat2018.sh $DATA_ROOT 69 | ``` 70 | **WARNING:** iNaturalist needs ~319 GB for download and extraction. 71 | Consider downloading and extracting it manually following the instructions 72 | [here](https://github.com/visipedia/inat_comp/tree/master/2018). 73 | 74 | ### Computing the embedding of all tasks 75 | To compute the embedding on a single task of CUB + iNat2018, run: 76 | ```sh 77 | python main.py task2vec.method=montecarlo dataset.root=$DATA_ROOT dataset.name=cub_inat2018 dataset.task_id=$TASK_ID -m 78 | ``` 79 | This will use the `montecarlo` Fisher approximation to compute the embedding of task number `$TASK_ID` in the CUB + iNat2018 meta-task. 80 | The result is stored in a pickle file inside `outputs`. 81 | 82 | To compute all embeddings at once, we can use Hydra's multi-run mode as follows: 83 | ```sh 84 | python main.py task2vec.method=montecarlo dataset.root=$DATA_ROOT dataset.name=cub_inat2018 dataset.task_id=`seq -s , 0 50` -m 85 | ``` 86 | This will compute the embeddings of tasks 0 through 50 in the CUB + iNat2018 meta-task.
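Each multirun job stores its embedding as a pickle file in its own subdirectory of `./multirun/montecarlo`. As a minimal sketch for loading the results back for ad-hoc analysis (the `embedding.p` filename is an assumption; check the actual name that `main.py` writes in your output tree):

```python
import pickle
from pathlib import Path

import task_similarity

# NOTE: 'embedding.p' is an assumed filename; adjust it to the pickle written by main.py.
paths = sorted(Path('./multirun/montecarlo').glob('*/embedding.p'))
embeddings = [pickle.load(open(p, 'rb')) for p in paths]
task_similarity.plot_distance_matrix(embeddings, [p.parent.name for p in paths])
```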
87 | To plot the distance matrix between these tasks, first download all the `iconic_taxa` 88 | [image files](https://github.com/inaturalist/inaturalist/tree/master/app/assets/images/iconic_taxa) 89 | to `./static/iconic_taxa`, and then run: 90 | ```sh 91 | python plot_distance_cub_inat.py --data-root $DATA_ROOT ./multirun/montecarlo 92 | ``` 93 | The result should look like the following. Note that tasks regarding classification of similar life forms 94 | (e.g., different types of birds, plants, and mammals) cluster together. 95 | 96 | ![task2vec distance matrix](static/distance_matrix.png?raw=1) 97 | -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | device: "cuda:0" 2 | 3 | task2vec: 4 | # Maximum number of samples in the dataset used to estimate the Fisher 5 | max_samples: 10000 6 | skip_layers: 0 7 | 8 | # Whether to put batch normalization in eval mode (true) or train mode (false) when computing the Fisher 9 | # fix_batch_norm: true 10 | 11 | classifier_opts: 12 | optimizer: adam 13 | epochs: 10 14 | learning_rate: 0.0004 15 | weight_decay: 0.0001 16 | 17 | defaults: 18 | - task2vec: montecarlo 19 | 20 | dataset: 21 | name: inat2018 22 | task_id: 0 23 | root: ~/data 24 | 25 | # Probe network to use 26 | model: 27 | arch: resnet34 28 | pretrained: true 29 | 30 | loader: 31 | batch_size: 100 32 | num_workers: 6 33 | balanced_sampling: true 34 | num_samples: 10000 35 | 36 | hydra: 37 | sweep: 38 | dir: ./multirun/${task2vec.method} 39 | subdir: ${hydra.job.num}_${hydra.job.override_dirname} 40 | # subdir: ${hydra.job.num}_${hydra.job.num}_${hydra.job.override_dirname} 41 | 42 | -------------------------------------------------------------------------------- /conf/task2vec/montecarlo.yaml: -------------------------------------------------------------------------------- 1 | task2vec: 2 | method: montecarlo 3 | method_opts: 4 | epochs: 1 -------------------------------------------------------------------------------- /conf/task2vec/variational.yaml: -------------------------------------------------------------------------------- 1 | task2vec: 2 | method: variational 3 | method_opts: 4 | beta: 1.0e-7 5 | epochs: 2 -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | -------------------------------------------------------------------------------- /dataset/cifar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License.
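# Usage sketch (illustrative comment, not part of the original file): the classes
# below build Split CIFAR tasks, e.g. task 0 is CIFAR-10 and task 2 is the second
# 10-class chunk of CIFAR-100.
#
#   c10 = CIFAR10Dataset('./data', train=True, download=True)
#   c100 = CIFAR100Dataset('./data', train=True, download=True)
#   task = SplitCIFARTask(c10, c100).generate(task_id=2)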
13 | 14 | import numpy as np 15 | from torchvision.datasets import CIFAR10, CIFAR100 16 | 17 | from .dataset import ClassificationTaskDataset 18 | from .expansion import ClassificationTaskExpander 19 | 20 | 21 | class SplitCIFARTask: 22 | """SplitCIFARTask generates Split CIFAR tasks 23 | 24 | Parameters 25 | ---------- 26 | cifar10_dataset : CIFAR10Dataset 27 | cifar100_dataset : CIFAR100Dataset 28 | """ 29 | 30 | def __init__(self, cifar10_dataset, cifar100_dataset): 31 | self.cifar10_dataset = cifar10_dataset 32 | self.cifar100_dataset = cifar100_dataset 33 | 34 | def generate(self, task_id=0, transform=None, target_transform=None): 35 | """Generate tasks given the classes 36 | 37 | Parameters 38 | ---------- 39 | task_id : int 0-10 (default 0) 40 | 0 = CIFAR10, 1 = first 10 of CIFAR100, 2 = second 10 of CIFAR100, ... 41 | transform : callable (default None) 42 | Optional transform to be applied on a sample. 43 | target_transform : callable (default None) 44 | Optional transform to be applied on a label. 45 | 46 | Returns 47 | ------- 48 | Task 49 | """ 50 | assert isinstance(task_id, int) 51 | assert 0 <= task_id <= 10, task_id 52 | 53 | task_expander = ClassificationTaskExpander() 54 | if task_id == 0: 55 | classes = tuple(range(10)) 56 | return task_expander(self.cifar10_dataset, 57 | {c: new_c for new_c, c in enumerate(classes)}, 58 | label_names={c: name for c, name in self.cifar10_dataset.label_names_map.items()}, 59 | task_id=task_id, 60 | task_name='Split CIFAR: CIFAR-10 {}'.format(classes), 61 | transform=transform, 62 | target_transform=target_transform) 63 | else: 64 | classes = tuple([int(c) for c in np.arange(10) + 10 * (task_id - 1)]) 65 | return task_expander(self.cifar100_dataset, 66 | {c: new_c for new_c, c in enumerate(classes)}, 67 | label_names={classes.index(old_c): name for old_c, name in 68 | self.cifar100_dataset.label_names_map.items() if old_c in classes}, 69 | task_id=task_id, 70 | task_name='Split CIFAR: CIFAR-100 {}'.format(classes), 71 | transform=transform, 72 | target_transform=target_transform) 73 | 74 | 75 | class CIFAR10Dataset(ClassificationTaskDataset): 76 | """CIFAR10 Dataset 77 | 78 | Parameters 79 | ---------- 80 | path : str 81 | path to the dataset directory (passed directly to the 82 | torchvision CIFAR10 constructor) 83 | train : bool (default True) 84 | if True, load train split otherwise load test split 85 | download: bool (default False) 86 | if True, downloads the dataset from the internet and 87 | puts it in path directory; otherwise if dataset is already downloaded, 88 | it is not downloaded again 89 | metadata : dict (default empty) 90 | extra arbitrary metadata 91 | transform : callable (default None) 92 | Optional transform to be applied on a sample. 93 | target_transform : callable (default None) 94 | Optional transform to be applied on a label.
95 | """ 96 | 97 | def __init__(self, path, train=True, download=False, 98 | metadata={}, transform=None, target_transform=None): 99 | num_classes, task_name = self._get_settings() 100 | assert isinstance(path, str) 101 | assert isinstance(train, bool) 102 | 103 | self.cifar = self._get_cifar(path, train, transform, target_transform, download) 104 | 105 | super(CIFAR10Dataset, self).__init__(list(self.cifar.data), 106 | [int(x) for x in self.cifar.targets], 107 | label_names={l: str(l) for l in range(num_classes)}, 108 | root=path, 109 | task_id=None, 110 | task_name=task_name, 111 | metadata=metadata, 112 | transform=transform, 113 | target_transform=target_transform) 114 | 115 | def _get_settings(self): 116 | return 10, 'CIFAR10' 117 | 118 | def _get_cifar(self, path, train, transform, target_transform, download=True): 119 | return CIFAR10(path, train=train, transform=transform, 120 | target_transform=target_transform, download=download) 121 | 122 | 123 | class CIFAR100Dataset(CIFAR10Dataset): 124 | """CIFAR100 Dataset 125 | 126 | Parameters 127 | ---------- 128 | path : str 129 | path to the dataset directory (passed directly to the 130 | torchvision CIFAR100 constructor) 131 | train : bool (default True) 132 | if True, load train split otherwise load test split 133 | download: bool (default False) 134 | if True, downloads the dataset from the internet and 135 | puts it in path directory; otherwise if dataset is already downloaded, 136 | it is not downloaded again 137 | metadata : dict (default empty) 138 | extra arbitrary metadata 139 | transform : callable (default None) 140 | Optional transform to be applied on a sample. 141 | target_transform : callable (default None) 142 | Optional transform to be applied on a label. 143 | """ 144 | 145 | def __init__(self, path, train=True, download=False, 146 | metadata={}, transform=None, target_transform=None): 147 | super(CIFAR100Dataset, self).__init__(path=path, 148 | train=train, download=download,  # forward the download flag (it was previously dropped) 149 | metadata=metadata, 150 | transform=transform, 151 | target_transform=target_transform) 152 | 153 | def _get_settings(self): 154 | return 100, 'CIFAR100' 155 | 156 | def _get_cifar(self, path, train, transform, target_transform, download=True): 157 | return CIFAR100(path, train=train, transform=transform, 158 | target_transform=target_transform, download=download) 159 | -------------------------------------------------------------------------------- /dataset/cub.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License.
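# Usage sketch (illustrative comment, not part of the original file): build the
# CUB train split and expand it into one classification task per taxonomic order.
#
#   dataset = CUBDataset('./data', split='train')
#   tasks = CUBTasks(dataset).generate(task='order')   # dict: task_id -> task dataset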
13 | 14 | import os 15 | import csv 16 | 17 | from .dataset import ClassificationTaskDataset 18 | from .expansion import MetaclassClassificationTaskExpander 19 | 20 | 21 | class CUBTasks: 22 | """CUBTasks generates tasks from the CUB dataset 23 | 24 | Parameters 25 | ---------- 26 | cub_dataset : CUBDataset 27 | """ 28 | TAXONOMY_FILE = 'taxonomy.txt' 29 | 30 | ORDER_TASK = 'order' 31 | FAMILY_TASK = 'family' 32 | GENUS_TASK = 'genus' 33 | SPECIES_TASK = 'species' 34 | POSSIBLE_TASKS = [ORDER_TASK, FAMILY_TASK, GENUS_TASK, SPECIES_TASK] 35 | TAXONOMY_COLUMN_MAP = {ORDER_TASK: 2, FAMILY_TASK: 3, GENUS_TASK: 4, SPECIES_TASK: 5} 36 | 37 | def __init__(self, cub_dataset): 38 | self.cub_dataset = cub_dataset 39 | 40 | def generate(self, task='order', task_id=None, taxonomy_file=TAXONOMY_FILE, 41 | use_species_names=False, transform=None, target_transform=None): 42 | """Generate tasks given the specified task (order|family|genus|species) 43 | 44 | Parameters 45 | ---------- 46 | task: str (default 'order') 47 | tasks to generate 48 | task_id : int or None (default None) 49 | if None, generate all tasks otherwise only generate for task_id 50 | taxonomy_file : str 51 | taxonomy file name if provided 52 | use_species_names : bool (default False) 53 | if True, use species name in taxonomy file instead of default label names 54 | transform : callable (default None) 55 | Optional transform to be applied on a sample. 56 | target_transform : callable (default None) 57 | Optional transform to be applied on a label. 58 | 59 | Returns 60 | ------- 61 | dict : task id -> Task if task_id is None 62 | or 63 | Task if task_id is set 64 | """ 65 | assert isinstance(task, str) 66 | assert task in CUBTasks.POSSIBLE_TASKS, task 67 | 68 | taxonomy_path = os.path.join(self.cub_dataset.root, taxonomy_file) 69 | assert os.path.exists(taxonomy_path), taxonomy_path 70 | 71 | task_map = [] 72 | task_col_index = CUBTasks.TAXONOMY_COLUMN_MAP[task] 73 | task_names_to_ids = {} 74 | with open(taxonomy_path, 'r') as f: 75 | reader = csv.reader(f, delimiter=' ') 76 | for row in reader: 77 | label = int(row[0]) 78 | label_name = row[5].replace('_', ' ') if use_species_names else row[1] 79 | task_name = row[task_col_index] 80 | 81 | if task_name not in task_names_to_ids: 82 | new_task_id = len(task_map) 83 | task_names_to_ids[task_name] = new_task_id 84 | new_task = {  # fresh name so we do not shadow the `task` argument 85 | "task_id": new_task_id, 86 | "task_name": task_name, 87 | "class_names": [label_name], 88 | "class_ids": [0], 89 | "label_map": {label: 0}, 90 | } 91 | task_map.append(new_task) 92 | else: 93 | new_task_id = task_names_to_ids[task_name] 94 | new_label = len(task_map[new_task_id]['label_map']) 95 | task_map[new_task_id]['class_names'].append(label_name) 96 | task_map[new_task_id]['class_ids'].append(new_label) 97 | task_map[new_task_id]['label_map'].update({label: new_label}) 98 | 99 | task_expander = MetaclassClassificationTaskExpander() 100 | if task_id is not None: 101 | assert isinstance(task_id, int) 102 | assert 0 <= task_id < len(task_map) 103 | task_map = [task_map[task_id]] 104 | results = task_expander(self.cub_dataset, task_map, 105 | transform=transform, 106 | target_transform=target_transform) 107 | assert len(results) == 1, len(results) 108 | return list(results.values())[0] 109 | else: 110 | return task_expander(self.cub_dataset, task_map, 111 | transform=transform, 112 | target_transform=target_transform) 113 | 114 | 115 | class CUBDataset(ClassificationTaskDataset): 116 | """CUB Dataset 117 | 118 | Parameters 119 | ---------- 120 | root : str
121 | path to the data root (the CUB files are expected under 122 | <root>/cub/CUB_200_2011) 123 | split : str (train|test) or None (default 'train') 124 | only load split if provided, otherwise if None, load train+test 125 | classes_file : str 126 | path to class names file (relative to the CUB_200_2011 directory) 127 | metadata : dict (default empty) 128 | extra arbitrary metadata 129 | transform : callable (default None) 130 | Optional transform to be applied on a sample. 131 | target_transform : callable (default None) 132 | Optional transform to be applied on a label. 133 | """ 134 | IMAGES_FOLDER = 'images' 135 | IMAGES_FILE = 'images.txt' 136 | TRAIN_TEST_SPLIT_FILE = 'train_test_split.txt' 137 | IMAGE_CLASS_LABELS_FILE = 'image_class_labels.txt' 138 | CLASSES_FILE = 'classes.txt' 139 | 140 | TRAIN_SPLIT = 'train' 141 | TEST_SPLIT = 'test' 142 | POSSIBLE_SPLITS = [TRAIN_SPLIT, TEST_SPLIT] 143 | 144 | def __init__(self, root, split='train', classes_file=CLASSES_FILE, 145 | metadata={}, transform=None, target_transform=None): 146 | 147 | path = os.path.join(root, 'cub/CUB_200_2011') 148 | images_folder = os.path.join(path, CUBDataset.IMAGES_FOLDER) 149 | images_file = os.path.join(path, CUBDataset.IMAGES_FILE) 150 | train_test_split_file = os.path.join(path, CUBDataset.TRAIN_TEST_SPLIT_FILE) 151 | image_class_labels_file = os.path.join(path, CUBDataset.IMAGE_CLASS_LABELS_FILE) 152 | classes_file = os.path.join(path, classes_file) 153 | 154 | assert os.path.exists(images_folder), images_folder 155 | assert os.path.exists(images_file), images_file 156 | assert os.path.exists(image_class_labels_file), image_class_labels_file 157 | 158 | # read in splits 159 | ignore_indices = set() 160 | if split is not None: 161 | assert split in CUBDataset.POSSIBLE_SPLITS, split 162 | assert os.path.exists(train_test_split_file), train_test_split_file 163 | 164 | with open(train_test_split_file, 'r') as f: 165 | for l in f: 166 | index, is_train = l.strip().split(' ') 167 | if int(is_train) and split == CUBDataset.TEST_SPLIT: 168 | ignore_indices.add(int(index)) 169 | elif not int(is_train) and split == CUBDataset.TRAIN_SPLIT: 170 | ignore_indices.add(int(index)) 171 | 172 | # read in images 173 | images_list = [] 174 | with open(images_file, 'r') as f: 175 | for l in f: 176 | index, img = l.strip().split(' ') 177 | if int(index) not in ignore_indices: 178 | images_list.append(os.path.join(CUBDataset.IMAGES_FOLDER, img)) 179 | 180 | # read in labels 181 | labels_list = [] 182 | with open(image_class_labels_file, 'r') as f: 183 | for l in f: 184 | index, label = l.strip().split(' ') 185 | if int(index) not in ignore_indices: 186 | labels_list.append(int(label)) 187 | 188 | # read in label names 189 | label_names = {} 190 | if os.path.exists(classes_file): 191 | with open(classes_file, 'r') as f: 192 | for l in f: 193 | label, label_name = l.strip().split(' ', 1) 194 | label_names[int(label)] = label_name 195 | 196 | self.split = split 197 | super(CUBDataset, self).__init__(images_list, 198 | labels_list, 199 | label_names=label_names, 200 | root=path, 201 | task_id=None, 202 | task_name='CUB', 203 | metadata=metadata, 204 | transform=transform, 205 | target_transform=target_transform) 206 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | 15 | import os 16 | import sys 17 | import collections.abc  # Iterable lives in collections.abc (required on Python 3.10+) 18 | from copy import deepcopy 19 | 20 | from PIL import Image 21 | import numpy as np 22 | import torch 23 | import torch.utils.data as data 24 | from torchvision import transforms 25 | from torchvision.datasets.folder import default_loader 26 | from sklearn.preprocessing import MultiLabelBinarizer 27 | 28 | 29 | def is_multi_label(labels_list): 30 | """Whether labels list provided is a multi-label dataset. 31 | 32 | Parameters 33 | ---------- 34 | labels_list : list 35 | list of labels (integers) or multiple labels (lists of integers) or mixed 36 | 37 | Returns 38 | ------- 39 | True if multi-label 40 | """ 41 | return any([isinstance(l, list) for l in labels_list]) 42 | 43 | 44 | class MultilabelTransform: 45 | """MultilabelTransform transforms a list of labels into a multilabel format 46 | 47 | E.g. if possible_labels=[1,2,3,4,5], then 48 | sample=[3,5] => [0,0,1,0,1] 49 | 50 | Parameters 51 | ---------- 52 | possible_labels : list 53 | An ordering for the class labels 54 | """ 55 | 56 | def __init__(self, possible_labels): 57 | assert isinstance(possible_labels, list) 58 | assert len(possible_labels) > 0 59 | assert len(possible_labels) == len(set(possible_labels)) 60 | self._transformer = MultiLabelBinarizer(classes=possible_labels) 61 | 62 | def __call__(self, sample): 63 | """ 64 | Parameters 65 | ---------- 66 | sample : target label 67 | 68 | Returns 69 | ------- 70 | target label in multilabel format 71 | """ 72 | if isinstance(sample, int): 73 | return self._transformer.fit_transform([[sample]])[0] 74 | else: # list 75 | return self._transformer.fit_transform([sample])[0] 76 | 77 | 78 | class TaskDataset(data.Dataset): 79 | """TaskDataset allows task expansion operations and is initialized with a list of images and labels. 80 | 81 | Parameters 82 | ---------- 83 | images_list : list 84 | list of image paths (str) 85 | labels_list : list 86 | list of labels 87 | root : str 88 | root path to append to all images 89 | task_id : int or tuple, default None 90 | task id (simply used as metadata) 91 | task_name : str, default None 92 | task name (simply used as metadata) 93 | metadata : dict (default empty) 94 | extra arbitrary metadata 95 | transform : callable (default None) 96 | Optional transform to be applied on a sample. 97 | target_transform : callable (default None) 98 | Optional transform to be applied on a label.
99 | """ 100 | 101 | def __init__(self, images_list, labels_list, root='', 102 | task_id=None, task_name=None, metadata={}, 103 | transform=None, target_transform=None): 104 | assert isinstance(images_list, list) 105 | assert isinstance(labels_list, list) 106 | assert len(images_list) == len(labels_list) 107 | assert isinstance(root, str) 108 | if task_id is not None: 109 | assert isinstance(task_id, int) or isinstance(task_id, tuple) 110 | if task_name is not None: 111 | assert isinstance(task_name, str) 112 | assert isinstance(metadata, dict) 113 | 114 | self._images = [img for img in images_list] 115 | self._labels = list(labels_list) 116 | self._root = root 117 | 118 | self._task_id = task_id 119 | self._task_name = task_name 120 | self._metadata = deepcopy(metadata) 121 | 122 | self._transform = transform 123 | self._target_transform = target_transform 124 | 125 | self._loader = default_loader 126 | 127 | def __len__(self): 128 | return len(self.images) 129 | 130 | def __getitem__(self, index): 131 | """ 132 | Parameters 133 | ---------- 134 | index : int 135 | 136 | Returns 137 | ------- 138 | img : image 139 | sample 140 | target : int 141 | class_index of target class 142 | """ 143 | if isinstance(self.images[index], str): 144 | img_path = self.images[index] 145 | try: 146 | img = self._loader(os.path.join(self._root, img_path)) 147 | except OSError as ex: 148 | # If the file cannot be read, print error and return grey image instead 149 | print(ex, file=sys.stderr) 150 | img = Image.fromarray(np.ones([224, 224, 3], dtype=np.uint8) * 128) 151 | 152 | elif isinstance(self.images[index], torch.Tensor): 153 | img = Image.fromarray(self.images[index].numpy()) 154 | elif isinstance(self.images[index], np.ndarray): 155 | img = Image.fromarray(self.images[index]) 156 | else: 157 | raise NotImplementedError() 158 | target = self.labels[index] 159 | if self._transform is not None: 160 | img = self._transform(img) 161 | if self._target_transform is not None: 162 | target = self._target_transform(target) 163 | return img, target 164 | 165 | @property 166 | def images(self): 167 | """List of images. 168 | """ 169 | return self._images 170 | 171 | @property 172 | def labels(self): 173 | """List of labels, corresponding to the list of images. 174 | """ 175 | return self._labels 176 | 177 | @property 178 | def possible_labels(self): 179 | """List of possible labels. 180 | 181 | Returns 182 | ------- 183 | list of labels (int) 184 | """ 185 | labels = set() 186 | for label in self.labels: 187 | if isinstance(label, collections.abc.Iterable): 188 | labels.update([l for l in label]) 189 | else: 190 | labels.add(label) 191 | return sorted(list(labels)) 192 | 193 | @property 194 | def num_classes(self): 195 | return len(self.possible_labels) 196 | 197 | @property 198 | def task_id(self): 199 | return self._task_id 200 | 201 | @property 202 | def task_name(self): 203 | return self._task_name 204 | 205 | @property 206 | def metadata(self): 207 | return self._metadata 208 | 209 | @metadata.setter 210 | def metadata(self, metadata): 211 | self._metadata = metadata 212 | 213 | @property 214 | def root(self): 215 | return self._root 216 | 217 | @property 218 | def transform(self): 219 | return self._transform 220 | 221 | @property 222 | def target_transform(self): 223 | return self._target_transform 224 | 225 | 226 | class ClassificationTaskDataset(TaskDataset): 227 | """ClassificationTaskDataset allows task expansion operations for classification tasks.
228 | 229 | Parameters 230 | ---------- 231 | images_list : list 232 | list of image paths (str) 233 | labels_list : list 234 | list of labels (integers) or multiple labels (lists of integers) 235 | (old task) 236 | label_names : dict (default None) 237 | map of label (int) -> name (str) 238 | if task_mapper is not None, should be using new labels 239 | binarize_labels : bool (default False) 240 | if True, binarize labels in dataset iterator (should be used for multilabel) 241 | force_multi_label : bool (default False) 242 | if True, force multi-label dataset 243 | root : str 244 | root path to append to all images 245 | task_id : int, default None 246 | task id (simply used as metadata) 247 | task_name : str, default None 248 | task name (simply used as metadata) 249 | metadata : dict (default empty) 250 | arbitrary metadata 251 | transform : callable (default None) 252 | Optional transform to be applied on a sample. 253 | target_transform : callable (default None) 254 | Optional transform to be applied on a label. 255 | """ 256 | 257 | def __init__(self, images_list, labels_list, label_names=None, root='', 258 | binarize_labels=False, force_multi_label=False, 259 | task_id=None, task_name=None, metadata={}, 260 | transform=None, target_transform=None): 261 | assert isinstance(labels_list, list) 262 | assert all(isinstance(x, int) or isinstance(x, list) for x in labels_list) 263 | assert all(all(isinstance(i, int) for i in x) if isinstance(x, list) else True for x in labels_list) 264 | 265 | # force multi-label format when at least one label instance is multi-label 266 | if force_multi_label or is_multi_label(labels_list): 267 | labels_list = [l if isinstance(l, list) else [l] for l in labels_list] 268 | 269 | super(ClassificationTaskDataset, self).__init__(images_list, 270 | labels_list, root=root, 271 | task_id=task_id, task_name=task_name, metadata=metadata, 272 | transform=transform, target_transform=target_transform) 273 | 274 | # sanity checks for label names 275 | if label_names is not None: 276 | self._verify_label_names(label_names) 277 | self._label_names_map = label_names 278 | 279 | if binarize_labels or self.is_multi_label: 280 | target_transforms = [ 281 | MultilabelTransform(self.possible_labels), 282 | ] 283 | if target_transform is not None: 284 | target_transforms.append(target_transform) 285 | self._target_transform = transforms.Compose(target_transforms) 286 | 287 | def _verify_label_names(self, label_names): 288 | """Verify label names given labels. 289 | 290 | Raises an exception if label names are not proper with respect to the given labels. 291 | 292 | Parameters 293 | ---------- 294 | label_names : dict 295 | map of label -> name 296 | """ 297 | assert isinstance(label_names, dict) 298 | 299 | found_labels = set() 300 | for label in self._labels: 301 | if isinstance(label, collections.abc.Iterable): 302 | found_labels.update([l for l in label]) 303 | else: 304 | found_labels.add(label) 305 | assert found_labels.issubset( 306 | label_names.keys()), 'dataset contains labels not specified in label_names: {} vs. {}'.format(found_labels, 307 | label_names.keys()) 308 | if len(found_labels) < len(label_names.keys()): 309 | print("Warning: label_names contains more labels than discovered labels in the dataset") 310 | 311 | def get_labels(self, flatten=False): 312 | """List of labels, corresponding to the list of images.
313 | 314 | Parameters 315 | ---------- 316 | flatten : bool (default False) 317 | flattens list into all labels (destroys correspondence with images) 318 | comes in handy for calculating label frequency 319 | """ 320 | if flatten: 321 | flattened_labels_list = [] 322 | for l in self._labels: 323 | if isinstance(l, list): 324 | flattened_labels_list.extend(l) 325 | else: 326 | flattened_labels_list.append(l) 327 | return flattened_labels_list 328 | else: 329 | return self._labels 330 | 331 | @property 332 | def is_multi_label(self): 333 | """Whether dataset is a multi-label dataset. 334 | 335 | Returns 336 | ------- 337 | True if multi-label dataset 338 | """ 339 | return is_multi_label(self._labels) 340 | 341 | @property 342 | def label_names(self): 343 | """List of label names for each sample. 344 | 345 | Returns 346 | ------- 347 | list of label names per sample (str) 348 | """ 349 | if not self._label_names_map: 350 | return None 351 | else: 352 | return [[self._label_names_map[l] for l in x] if isinstance(x, list) else self._label_names_map[x] for x in 353 | self.labels] 354 | 355 | @property 356 | def possible_label_names(self): 357 | """List of possible label names. 358 | 359 | Returns 360 | ------- 361 | list of label names (str) 362 | """ 363 | if not self._label_names_map: 364 | return None 365 | else: 366 | return [self._label_names_map[l] for l in self.possible_labels] 367 | 368 | @property 369 | def label_names_map(self): 370 | """Map of label to label name. 371 | """ 372 | return self._label_names_map 373 | -------------------------------------------------------------------------------- /dataset/expansion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | 15 | import os 16 | import collections.abc  # Iterable lives in collections.abc (required on Python 3.10+) 17 | from copy import deepcopy 18 | 19 | from .dataset import TaskDataset, ClassificationTaskDataset 20 | 21 | 22 | class TaskExpander: 23 | """TaskExpander is an abstract class for task expansion functions. 24 | 25 | In general, task expansion functions take old task(s) as input and produce new task(s). 26 | 27 | Implement call functions to accept old task(s) and produce new task(s). 28 | """ 29 | def __call__(self, task): 30 | raise NotImplementedError() 31 | 32 | 33 | class ClassificationTaskExpander(TaskExpander): 34 | """ClassificationTaskExpander remaps old classes to new classes. 35 | """ 36 | def __call__(self, task, task_map, label_names=None, flatten_binary=False, 37 | force_multi_label=False, force_remove_multi_label=False, 38 | task_id=None, task_name=None, metadata=None, 39 | transform=None, target_transform=None): 40 | """Call function.
41 | 42 | Parameters 43 | ---------- 44 | task : ClassificationTaskDataset 45 | old task 46 | task_map : dict 47 | map of old label (int) -> new labels (list) 48 | label_names : dict (default None) 49 | map of new label (int) -> name (str) 50 | flatten_binary : bool (default False) 51 | if True and new label space is binary, set new label to 1 if attribute present 52 | force_multi_label : bool (default False) 53 | if True, force multi-label dataset 54 | force_remove_multi_label : bool (default False) 55 | if True, force removing samples with multiple labels 56 | task_id : int, default None 57 | task id (simply used as metadata) 58 | task_name : str, default None 59 | task name (simply used as metadata) 60 | metadata : dict (default empty) 61 | extra arbitrary metadata 62 | transform : callable (default None) 63 | Optional transform to be applied on a sample. 64 | target_transform : callable (default None) 65 | Optional transform to be applied on a label. 66 | 67 | Returns 68 | ------- 69 | ClassificationTaskDataset 70 | """ 71 | assert isinstance(task_map, dict) 72 | assert all(isinstance(x, int) for x in task_map.keys()) 73 | assert all(isinstance(x, int) or isinstance(x, list) for x in task_map.values()) 74 | is_single_label_task = all(isinstance(x, int) for x in task_map.values()) 75 | if flatten_binary: 76 | assert self._is_binary_new_task(task_map) 77 | 78 | if not task_map:  # an empty mapping means there is nothing to remap (`is {}` was always False) 79 | return task 80 | 81 | # create new labels and filter images 82 | new_images = [] 83 | new_labels = [] 84 | for idx, (img, label) in enumerate(zip(task.images, task.labels)): 85 | if isinstance(label, list): 86 | # map old labels to new labels 87 | new_label = set() 88 | for l in label: 89 | if l in task_map: 90 | if isinstance(task_map[l], int): 91 | new_label.add(task_map[l]) 92 | else: 93 | new_label.update(task_map[l]) 94 | new_label = list(new_label) 95 | if len(new_label) > 0: 96 | # enforce consistent data type 97 | new_label = new_label[0] if len(new_label) == 1 and is_single_label_task else new_label 98 | new_label = 1 if flatten_binary and isinstance(new_label, list) and sorted(new_label) == [0, 1] else new_label  # sorted(): set iteration order is not guaranteed 99 | new_images.append(img) 100 | new_labels.append(new_label) 101 | # otherwise exclude sample 102 | 103 | elif isinstance(label, int): 104 | # map old label to new label 105 | if label in task_map: 106 | new_label = task_map[label] 107 | new_images.append(img) 108 | new_labels.append(new_label) 109 | # otherwise if not specified in task mapper, exclude sample 110 | 111 | else: 112 | raise ValueError("label is not a list or int: {}".format(label)) 113 | 114 | if force_remove_multi_label: 115 | multi_label_samples_mask = [isinstance(l, list) and len(l) != 1 for l in new_labels] 116 | new_images = [img for idx, img in enumerate(new_images) if not multi_label_samples_mask[idx]] 117 | new_labels = [l if isinstance(l, int) else l[0] for idx, l in enumerate(new_labels) if not multi_label_samples_mask[idx]] 118 | 119 | assert len(new_images) == len(new_labels) 120 | assert len(new_images) > 0 121 | 122 | return ClassificationTaskDataset(new_images, 123 | new_labels, 124 | label_names=label_names, 125 | force_multi_label=force_multi_label, 126 | root=task.root, 127 | task_id=task.task_id if task_id is None else task_id, 128 | task_name=task.task_name if task_name is None else task_name, 129 | metadata=task.metadata if metadata is None else metadata, 130 | transform=transform, 131 | target_transform=target_transform) 132 | 133 | def _is_binary_new_task(self, task_map): 134 | """Helper to determine if new task is a binary classification
problem. 135 | 136 | Parameters 137 | ---------- 138 | task_map : dict 139 | map of old label (int) -> new labels (list) 140 | 141 | Returns 142 | ------- 143 | True if new task is binary classification problem 144 | """ 145 | possible_labels = set() 146 | for label in task_map.values(): 147 | if isinstance(label, int): 148 | possible_labels.add(label) 149 | elif isinstance(label, list): 150 | possible_labels.update(label) 151 | else: 152 | raise ValueError('Expected int or list but got {}'.format(type(label))) 153 | return possible_labels == set([0,1]) 154 | 155 | 156 | class MetaclassClassificationTaskExpander(TaskExpander): 157 | """MetaclassClassificationTaskExpander produces new tasks for each meta-class grouping of old classes. 158 | """ 159 | def __call__(self, task, task_map, flatten_binary=False, force_remove_multi_label=False, 160 | transform=None, target_transform=None): 161 | """Call function. 162 | 163 | Parameters 164 | ---------- 165 | task : ClassificationTaskDataset 166 | old task 167 | task_map : list 168 | list of new tasks, where each task is provided as the following dict: 169 | { 170 | "task_id": int, # new task id 171 | "task_name": str, # new task name 172 | "class_names": [str], # NEW class names 173 | "class_ids": [int], # NEW class_ids 174 | "label_map": {int: int}, # old class_id -> new class_id 175 | "multi_label": int(0|1), # whether to force multilabel (optional) 176 | } 177 | flatten_binary : bool (default False) 178 | if True and label space is binary, flatten [0,1] new label to 1 179 | force_remove_multi_label : bool (default False) 180 | if True, force removing samples with multiple labels 181 | transform : callable (default None) 182 | Optional transform to be applied on a sample. 183 | target_transform : callable (default None) 184 | Optional transform to be applied on a label. 
185 | 186 | Returns 187 | ------- 188 | dict : task_id (int) -> ClassificationTaskDataset 189 | """ 190 | assert isinstance(task_map, list) 191 | assert all(isinstance(task, dict) for task in task_map) 192 | assert all(all(x in task for x in ['task_id', 'task_name', 'label_map']) for task in task_map), 'task_map missing some required keys' 193 | 194 | new_tasks = {} 195 | for task_mapper in task_map: 196 | new_task_id = task_mapper['task_id'] 197 | new_task_name = task_mapper['task_name'] 198 | label_map = task_mapper['label_map'] 199 | force_multi_label = task_mapper['multi_label'] if 'multi_label' in task_mapper else 0 200 | assert force_multi_label in [0,1] 201 | 202 | # get new label names if available 203 | label_names = None 204 | if 'class_names' in task_mapper and 'class_ids' in task_mapper: 205 | new_class_ids = task_mapper['class_ids'] 206 | new_class_names = task_mapper['class_names'] 207 | assert len(new_class_names) == len(new_class_ids) 208 | label_names = {i:n for i, n in zip(new_class_ids, new_class_names)} 209 | 210 | # create new task 211 | remapper = ClassificationTaskExpander() 212 | new_task = remapper(task, 213 | label_map, 214 | task_id=new_task_id, 215 | task_name=new_task_name, 216 | metadata=task.metadata, 217 | label_names=label_names, 218 | flatten_binary=flatten_binary, 219 | force_multi_label=bool(force_multi_label), 220 | force_remove_multi_label=force_remove_multi_label, 221 | transform=transform, 222 | target_transform=target_transform) 223 | new_tasks[new_task_id] = new_task 224 | 225 | return new_tasks 226 | 227 | 228 | class BinaryClassificationTaskExpander(MetaclassClassificationTaskExpander): 229 | """BinaryClassificationTaskExpander produces new binary tasks for each attribute. 230 | """ 231 | def __call__(self, task, labels=None, transform=None, target_transform=None): 232 | """Call function. 233 | 234 | Parameters 235 | ---------- 236 | task : ClassificationTaskDataset 237 | old task 238 | labels : list 239 | list of labels to consider 240 | transform : callable (default None) 241 | Optional transform to be applied on a sample. 242 | target_transform : callable (default None) 243 | Optional transform to be applied on a label. 244 | 245 | Returns 246 | ------- 247 | dict : label (int) -> ClassificationTaskDataset 248 | """ 249 | if labels is None: 250 | labels = task.possible_labels 251 | elif isinstance(labels, int): 252 | labels = [labels] 253 | assert isinstance(labels, list) 254 | 255 | task_map = [] 256 | for label in labels: 257 | label_name = task.label_names_map[label] if task.possible_label_names is not None else str(label) 258 | 259 | # create new task 260 | binary_task = {} 261 | binary_task['task_id'] = label 262 | binary_task['task_name'] = label_name 263 | binary_task['label_map'] = { l: 1 if l == label else 0 for l in task.possible_labels } 264 | binary_task['class_ids'] = [1, 0] 265 | binary_task['class_names'] = [label_name, 'not {}'.format(label_name)] 266 | task_map.append(binary_task) 267 | 268 | return super(BinaryClassificationTaskExpander, self).__call__(task, 269 | task_map, 270 | flatten_binary=True, 271 | transform=transform, 272 | target_transform=target_transform) 273 | 274 | 275 | class UnionClassificationTaskExpander(TaskExpander): 276 | """UnionClassificationTaskExpander combines multiple tasks into one task. 
277 | 
278 |     Supports different label merging strategies:
279 |         DISJOINT : do not merge; remap all old classes to new classes
280 |         LABELS : merge common classes
281 |         LABEL_NAMES : merge classes with common names
282 | 
283 |     Parameters
284 |     ----------
285 |     merge_mode : str
286 |         one of the following merge modes:
287 |             DISJOINT : do not merge; remap all old classes to new classes
288 |             LABELS : merge common classes
289 |             LABEL_NAMES : merge classes with common names
290 |     merge_duplicate_images : bool (default True)
291 |         if True, merge duplicate images; otherwise do not
292 |     """
293 |     DISJOINT_MERGE = 'DISJOINT'
294 |     MERGE_LABELS = 'LABELS'
295 |     MERGE_LABEL_NAMES = 'LABEL_NAMES'
296 |     POSSIBLE_MERGE_MODES = [DISJOINT_MERGE, MERGE_LABELS, MERGE_LABEL_NAMES]
297 | 
298 |     def __init__(self, merge_mode=DISJOINT_MERGE, merge_duplicate_images=True):
299 |         assert merge_mode in UnionClassificationTaskExpander.POSSIBLE_MERGE_MODES
300 |         self._merge_mode = merge_mode
301 |         self._merge_duplicates = merge_duplicate_images  # named so it does not shadow the _merge_duplicate_images method
302 | 
303 |     def __call__(self, tasks, task_id=None, task_name=None, metadata=None,
304 |                  transform=None, target_transform=None):
305 |         """Call function.
306 | 
307 |         Parameters
308 |         ----------
309 |         tasks : list
310 |             list of ClassificationTaskDataset
311 |         task_id : int, default None
312 |             task id (simply used as metadata)
313 |         task_name : str, default None
314 |             task name (simply used as metadata)
315 |         metadata : dict (default empty)
316 |             extra arbitrary metadata
317 |         transform : callable (default None)
318 |             Optional transform to be applied on a sample.
319 |         target_transform : callable (default None)
320 |             Optional transform to be applied on a label.
321 | 
322 |         Returns
323 |         -------
324 |         ClassificationTaskDataset
325 |         """
326 |         assert isinstance(tasks, collections.abc.Iterable)  # collections.Iterable was removed in Python 3.10
327 | 
328 |         # build label map depending on mode
329 |         if self._merge_mode == UnionClassificationTaskExpander.DISJOINT_MERGE:
330 |             label_map, label_names_map = self._disjoint_merge(tasks)
331 | 
332 |         elif self._merge_mode == UnionClassificationTaskExpander.MERGE_LABELS:
333 |             label_map, label_names_map = self._merge_labels(tasks)
334 | 
335 |         elif self._merge_mode == UnionClassificationTaskExpander.MERGE_LABEL_NAMES:
336 |             label_map, label_names_map = self._merge_label_names(tasks)
337 | 
338 |         # remap old labels to new labels
339 |         new_images, new_labels = self._remap_tasks(tasks, label_map)
340 | 
341 |         # combine labels for the same image across datasets
342 |         if self._merge_duplicates:
343 |             new_images, new_labels = self._merge_duplicate_images(new_images, new_labels)
344 | 
345 |         # create new metadata
346 |         new_metadata = {}
347 |         for idx, t in enumerate(tasks):
348 |             new_metadata[idx] = deepcopy(t.metadata)
349 | 
350 |         return ClassificationTaskDataset(new_images,
351 |                                          new_labels,
352 |                                          label_names=label_names_map,
353 |                                          task_id=task_id,
354 |                                          task_name=task_name,
355 |                                          metadata=new_metadata if metadata is None else metadata,
356 |                                          transform=transform,
357 |                                          target_transform=target_transform)
358 | 
359 |     def _disjoint_merge(self, tasks):
360 |         """Helper for disjoint merge.
361 | 362 | Parameters 363 | ---------- 364 | tasks : list 365 | list of ClassificationTaskDataset 366 | 367 | Returns 368 | ------- 369 | label_map : dict 370 | map of task (index) -> old label -> new label 371 | label_names_map : dict 372 | map of new label -> new label name 373 | """ 374 | label_map = {} 375 | label_names_map = {} 376 | 377 | label_counter = 0 378 | for idx, t in enumerate(tasks): 379 | label_map[idx] = {l : label_counter+i for i, l in enumerate(t.possible_labels)} 380 | if t.possible_label_names is not None: 381 | label_names_map.update({label_counter+i : n for i, n in enumerate(t.possible_label_names)}) 382 | else: 383 | task_name = t.task_name if t.task_name is not None else str(idx) 384 | label_names_map.update({label_counter+i : '{}_{}'.format(task_name, l) for i, l in enumerate(t.possible_labels)}) 385 | label_counter += len(t.possible_labels) 386 | 387 | return label_map, label_names_map 388 | 389 | def _merge_labels(self, tasks, intersection=False): 390 | """Helper for labels merge. 391 | 392 | Parameters 393 | ---------- 394 | tasks : list 395 | list of ClassificationTaskDataset 396 | intersection : bool (default False) 397 | if True, only return intersection 398 | 399 | Returns 400 | ------- 401 | label_map : dict 402 | map of task (index) -> old label -> new label 403 | label_names_map : dict 404 | map of new label -> new label name 405 | """ 406 | if intersection: 407 | common_labels = set.intersection(*[set(t.possible_labels) for t in tasks]) 408 | if len(common_labels) == 0: 409 | return {}, {} 410 | 411 | label_map = {} 412 | label_names_map = {} 413 | 414 | label_counter = 0 415 | old_labels = {} 416 | for idx, t in enumerate(tasks): 417 | label_map[idx] = {} 418 | for i, l in enumerate(t.possible_labels): 419 | if intersection and l not in common_labels: 420 | continue 421 | 422 | if l in old_labels: 423 | new_label = old_labels[l] 424 | else: 425 | new_label = label_counter 426 | old_labels[l] = new_label 427 | label_counter += 1 428 | label_map[idx][l] = new_label 429 | 430 | # merge label names by concatenating names 431 | if t.possible_label_names is not None: 432 | if new_label not in label_names_map: 433 | label_names_map[new_label] = t.possible_label_names[i] 434 | else: 435 | label_names_map[new_label] += ',' + t.possible_label_names[i] 436 | else: 437 | task_name = t.task_name if t.task_name is not None else str(idx) 438 | if new_label not in label_names_map: 439 | label_names_map[new_label] = '{}_{}'.format(task_name, l) 440 | else: 441 | label_names_map[new_label] += ',{}_{}'.format(task_name, l) 442 | 443 | return label_map, label_names_map 444 | 445 | def _merge_label_names(self, tasks, intersection=False): 446 | """Helper for label names merge. 
447 | 
448 |         Parameters
449 |         ----------
450 |         tasks : list
451 |             list of ClassificationTaskDataset
452 |         intersection : bool (default False)
453 |             if True, only return intersection
454 | 
455 |         Returns
456 |         -------
457 |         label_map : dict
458 |             map of task (index) -> old label -> new label
459 |         label_names_map : dict
460 |             map of new label -> new label name
461 |         """
462 |         if intersection:
463 |             common_label_names = set.intersection(*[set(t.possible_label_names) for t in tasks])
464 |             if len(common_label_names) == 0:
465 |                 return {}, {}
466 | 
467 |         label_map = {}
468 |         label_names_map = {}
469 | 
470 |         label_counter = 0
471 |         old_labels = {}
472 |         for idx, t in enumerate(tasks):
473 |             assert t.possible_label_names is not None, idx
474 |             label_map[idx] = {}
475 |             for i, l in enumerate(t.possible_labels):
476 |                 label_name = t.possible_label_names[i]
477 |                 if intersection and label_name not in common_label_names:
478 |                     continue
479 | 
480 |                 if label_name in old_labels:
481 |                     new_label = old_labels[label_name]
482 |                 else:
483 |                     new_label = label_counter
484 |                     old_labels[label_name] = new_label
485 |                     label_counter += 1
486 |                 label_map[idx][l] = new_label
487 |                 label_names_map[new_label] = label_name
488 | 
489 |         return label_map, label_names_map
490 | 
491 |     def _remap_tasks(self, tasks, label_map):
492 |         """Helper to remap task labels.
493 | 
494 |         Parameters
495 |         ----------
496 |         tasks : list
497 |             list of ClassificationTaskDataset
498 |         label_map : dict
499 |             map of task (index) -> old label -> new label
500 | 
501 |         Returns
502 |         -------
503 |         new_images : list
504 |         new_labels : list
505 |         """
506 |         remapper = ClassificationTaskExpander()
507 |         new_images = []
508 |         new_labels = []
509 |         for idx, t in enumerate(tasks):
510 |             new_task = remapper(t, label_map[idx], label_names=None)
511 |             new_images.extend([os.path.join(t.root, img) if isinstance(img, str) else img for img in new_task.images])
512 |             new_labels.extend(new_task.labels)
513 |         return new_images, new_labels
514 | 
515 |     def _merge_duplicate_images(self, images, labels):
516 |         """Merge duplicate images.
517 | 
518 |         Parameters
519 |         ----------
520 |         images : list
521 |         labels : list
522 | 
523 |         Returns
524 |         -------
525 |         new_images : list
526 |         new_labels : list
527 |         """
528 |         new_data = {}
529 |         for img, label in zip(images, labels):
530 |             assert isinstance(img, str)
531 |             if img not in new_data:
532 |                 new_data[img] = label
533 |             elif isinstance(label, int):
534 |                 if isinstance(new_data[img], int):
535 |                     new_data[img] = [new_data[img], label]
536 |                 else:
537 |                     new_data[img].append(label)
538 |             elif isinstance(label, list):
539 |                 if isinstance(new_data[img], int):
540 |                     new_data[img] = [new_data[img], *label]
541 |                 else:
542 |                     new_data[img].extend(label)
543 |             else:
544 |                 raise ValueError('Expected int or list but got {}'.format(type(label)))
545 | 
546 |         new_images = []
547 |         new_labels = []
548 |         for k, v in new_data.items():
549 |             new_images.append(k)
550 |             new_labels.append(v)
551 | 
552 |         return new_images, new_labels
553 | 
554 | 
555 | class IntersectionClassificationTaskExpander(UnionClassificationTaskExpander):
556 |     """IntersectionClassificationTaskExpander combines multiple tasks with common labels into one task.
557 | 
558 |     Supports different label merging strategies:
559 |         LABELS : merge common classes
560 |         LABEL_NAMES : merge classes with common names
561 | 
562 |     Parameters
563 |     ----------
564 |     merge_mode : str
565 |         one of the following merge modes:
566 |             LABELS : merge common classes
567 |             LABEL_NAMES : merge classes with common names
568 |     """
569 |     MERGE_LABELS = 'LABELS'
570 |     MERGE_LABEL_NAMES = 'LABEL_NAMES'
571 |     POSSIBLE_MERGE_MODES = [MERGE_LABELS, MERGE_LABEL_NAMES]
572 | 
573 |     def __init__(self, merge_mode=MERGE_LABEL_NAMES):
574 |         assert merge_mode in IntersectionClassificationTaskExpander.POSSIBLE_MERGE_MODES
575 |         self._merge_mode = merge_mode
576 | 
577 |     def __call__(self, tasks, task_id=None, task_name=None, metadata=None,
578 |                  transform=None, target_transform=None):
579 |         """Call function.
580 | 
581 |         Parameters
582 |         ----------
583 |         tasks : list
584 |             list of ClassificationTaskDataset
585 |         task_id : int, default None
586 |             task id (simply used as metadata)
587 |         task_name : str, default None
588 |             task name (simply used as metadata)
589 |         metadata : dict (default empty)
590 |             extra arbitrary metadata
591 |         transform : callable (default None)
592 |             Optional transform to be applied on a sample.
593 |         target_transform : callable (default None)
594 |             Optional transform to be applied on a label.
595 | 
596 |         Returns
597 |         -------
598 |         ClassificationTaskDataset
599 |         or
600 |         None if intersection is empty
601 |         """
602 |         assert isinstance(tasks, collections.abc.Iterable)
603 | 
604 |         # build label map depending on mode
605 |         if self._merge_mode == IntersectionClassificationTaskExpander.MERGE_LABELS:
606 |             label_map, label_names_map = self._merge_labels(tasks, intersection=True)
607 | 
608 |         elif self._merge_mode == IntersectionClassificationTaskExpander.MERGE_LABEL_NAMES:
609 |             label_map, label_names_map = self._merge_label_names(tasks, intersection=True)
610 | 
611 |         # if no overlap
612 |         if len(label_map) == 0 and len(label_names_map) == 0:
613 |             return None
614 | 
615 |         # remap old labels to new labels
616 |         new_images, new_labels = self._remap_tasks(tasks, label_map)
617 | 
618 |         # combine labels for the same image across datasets
619 |         new_images, new_labels = self._merge_duplicate_images(new_images, new_labels)
620 | 
621 |         # create new metadata
622 |         new_metadata = {}
623 |         for idx, t in enumerate(tasks):
624 |             new_metadata[idx] = deepcopy(t.metadata)
625 | 
626 |         return ClassificationTaskDataset(new_images,
627 |                                          new_labels,
628 |                                          label_names=label_names_map,
629 |                                          task_id=task_id,
630 |                                          task_name=task_name,
631 |                                          metadata=new_metadata if metadata is None else metadata,
632 |                                          transform=transform,
633 |                                          target_transform=target_transform)
634 | 
635 | 
636 | class TrainTestExpander(UnionClassificationTaskExpander):
637 |     """TrainTestExpander ensures targets from both train and test partitions are the same.
638 | 
639 |     This is basically IntersectionClassificationTaskExpander but does not merge into one dataset.
640 | 
641 |     Supports different label merging strategies:
642 |         LABELS : merge common classes
643 |         LABEL_NAMES : merge classes with common names
644 | 
645 |     Parameters
646 |     ----------
647 |     merge_mode : str
648 |         one of the following merge modes:
649 |             LABELS : merge common classes
650 |             LABEL_NAMES : merge classes with common names
651 |     """
652 |     MERGE_LABELS = 'LABELS'
653 |     MERGE_LABEL_NAMES = 'LABEL_NAMES'
654 |     POSSIBLE_MERGE_MODES = [MERGE_LABELS, MERGE_LABEL_NAMES]
655 | 
656 |     def __init__(self, merge_mode=MERGE_LABEL_NAMES):
657 |         assert merge_mode in TrainTestExpander.POSSIBLE_MERGE_MODES
658 |         self._merge_mode = merge_mode
659 | 
660 |     def __call__(self, *tasks):
661 |         """Call function.
662 | 
663 |         Parameters
664 |         ----------
665 |         *tasks : list (positional arguments)
666 |             list of ClassificationTaskDataset
667 | 
668 |         Returns
669 |         -------
670 |         list of ClassificationTaskDataset
671 |         """
672 |         assert isinstance(tasks, collections.abc.Iterable)
673 | 
674 |         # build label map depending on mode
675 |         if self._merge_mode == TrainTestExpander.MERGE_LABELS:
676 |             label_map, label_names_map = self._merge_labels(tasks, intersection=True)
677 | 
678 |         elif self._merge_mode == TrainTestExpander.MERGE_LABEL_NAMES:
679 |             label_map, label_names_map = self._merge_label_names(tasks, intersection=True)
680 | 
681 |         # if no overlap
682 |         if len(label_map) == 0 and len(label_names_map) == 0:
683 |             raise ValueError('tasks provided do not share any common labels!')
684 | 
685 |         # build new tasks with consistent label map
686 |         new_tasks = []
687 |         remapper = ClassificationTaskExpander()
688 |         for idx, t in enumerate(tasks):
689 |             new_task = remapper(t, label_map[idx], label_names=label_names_map,
690 |                                 task_id=t.task_id, task_name=t.task_name, metadata=t.metadata,
691 |                                 transform=t.transform, target_transform=t.target_transform)
692 |             new_tasks.append(new_task)
693 | 
694 |         return new_tasks
695 | 
-------------------------------------------------------------------------------- /dataset/imat.py: --------------------------------------------------------------------------------
1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"). You
4 | # may not use this file except in compliance with the License. A copy of
5 | # the License is located at
6 | #
7 | #     http://aws.amazon.com/apache2.0/
8 | #
9 | # or in the "license" file accompanying this file. This file is
10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11 | # ANY KIND, either express or implied. See the License for the specific
12 | # language governing permissions and limitations under the License.
13 | 
14 | import os
15 | import json
16 | import csv
17 | 
18 | from .dataset import ClassificationTaskDataset
19 | from .expansion import BinaryClassificationTaskExpander, MetaclassClassificationTaskExpander
20 | 
21 | 
22 | class iMat2018FashionTasks:
23 |     """iMat2018FashionTasks generates binary attribute tasks from the iMaterialist 2018 fashion dataset
24 | 
25 |     Parameters
26 |     ----------
27 |     imat_dataset : iMat2018FashionDataset
28 |     """
29 |     LABEL_MAP_FILE = 'labels.csv'
30 |     CUSTOM_CATEGORY_TYPES = ['pants', 'shoes', 'other']
31 |     CATEGORY_TASK_ID = 1
32 |     ATTRIBUTE_TASK_IDS = [2, 3, 4, 5, 6, 8, 9]
33 | 
34 |     def __init__(self, imat_dataset):
35 |         self.dataset = imat_dataset
36 |         self._setup_label_map()
37 | 
38 |     def _setup_label_map(self):
39 |         """Helper to load annotation data.
40 | """ 41 | label_map_path = os.path.join(self.dataset.root, iMat2018FashionTasks.LABEL_MAP_FILE) 42 | self.label_map = None 43 | if not os.path.exists(label_map_path): 44 | return 45 | 46 | self.label_map = {} 47 | with open(label_map_path, 'r') as f: 48 | reader = csv.reader(f) 49 | for idx, row in enumerate(reader): 50 | if idx < 1: 51 | continue 52 | label_id, task_id, label_name, task_name, is_pants, is_shoe = row 53 | self.label_map[label_id] = row 54 | 55 | def generate(self, task_id=0, transform=None, target_transform=None): 56 | return self.generate_from_attributes(task_id=task_id, 57 | transform=transform, 58 | target_transform=target_transform) 59 | 60 | def generate_from_attributes(self, task_id=0, transform=None, target_transform=None): 61 | """Generate binary attributes tasks. 62 | 63 | Parameters 64 | ---------- 65 | task_id : int or None (default 0) 66 | if None, generate all tasks otherwise only generate for task_id 67 | task_id is iMat label-1 (i.e. zero-indexed) 68 | transform : callable (default None) 69 | Optional transform to be applied on a sample. 70 | target_transform : callable (default None) 71 | Optional transform to be applied on a label. 72 | 73 | Returns 74 | ------- 75 | dict : label (int) -> ClassificationTaskDataset 76 | """ 77 | task_expander = BinaryClassificationTaskExpander() 78 | if task_id is not None: 79 | assert isinstance(task_id, int) 80 | label = task_id + 1 81 | assert label in self.dataset.possible_labels, label 82 | results = task_expander(self.dataset, labels=[label], 83 | transform=transform, 84 | target_transform=target_transform) 85 | assert len(results) == 1, len(results) 86 | return list(results.values())[0] 87 | else: 88 | return task_expander(self.dataset, 89 | transform=transform, 90 | target_transform=target_transform) 91 | 92 | def generate_from_custom_category_types(self, task_id=0, 93 | transform=None, target_transform=None): 94 | """Generate from 3 custom category tasks: pants, shoes, other. 95 | 96 | Parameters 97 | ---------- 98 | task_id : int or None (default 0) 99 | if None, generate all tasks otherwise only generate for task_id 100 | task_id is 0-2: pants, shoes, other 101 | transform : callable (default None) 102 | Optional transform to be applied on a sample. 103 | target_transform : callable (default None) 104 | Optional transform to be applied on a label. 
105 | 106 | Returns 107 | ------- 108 | dict : label (int) -> ClassificationTaskDataset 109 | """ 110 | # initialize task map 111 | task_map = [] 112 | for idx, task_name in enumerate(iMat2018FashionTasks.CUSTOM_CATEGORY_TYPES): 113 | task_map.append({ 114 | "task_id": idx, 115 | "task_name": task_name, 116 | "class_names": [], 117 | "class_ids": [], 118 | "label_map": {}, 119 | }) 120 | 121 | # build task map 122 | for idx, (label_id, data) in enumerate(self.label_map.items()): 123 | label_id, data_task_id, label_name, task_name, is_pants, is_shoe = data 124 | label_id = int(label_id) 125 | data_task_id = int(data_task_id) 126 | is_pants = is_pants == 'yes' 127 | is_shoe = is_shoe == 'yes' 128 | if data_task_id == iMat2018FashionTasks.CATEGORY_TASK_ID: 129 | if is_pants: 130 | task_map[0]['class_names'].append(label_name) 131 | new_label_index = len(task_map[0]['class_ids']) 132 | task_map[0]['class_ids'].append(new_label_index) 133 | task_map[0]['label_map'].update({label_id: new_label_index}) 134 | elif is_shoe: 135 | task_map[1]['class_names'].append(label_name) 136 | new_label_index = len(task_map[1]['class_ids']) 137 | task_map[1]['class_ids'].append(new_label_index) 138 | task_map[1]['label_map'].update({label_id: new_label_index}) 139 | else: 140 | task_map[2]['class_names'].append(label_name) 141 | new_label_index = len(task_map[2]['class_ids']) 142 | task_map[2]['class_ids'].append(new_label_index) 143 | task_map[2]['label_map'].update({label_id: new_label_index}) 144 | 145 | # create tasks 146 | # note: "other" category has at least one bad sample with multiple labels; remove it 147 | task_expander = MetaclassClassificationTaskExpander() 148 | if task_id is not None: 149 | assert isinstance(task_id, int) 150 | assert task_id in range(len(task_map)), task_id 151 | results = task_expander(self.dataset, [task_map[task_id]], 152 | force_remove_multi_label=(task_id == 2), 153 | transform=transform, 154 | target_transform=target_transform) 155 | assert len(results) == 1, len(results) 156 | return list(results.values())[0] 157 | else: 158 | results = task_expander(self.dataset, task_map[0:2], 159 | transform=transform, 160 | target_transform=target_transform) 161 | other_results = task_expander(self.dataset, [task_map[2]], 162 | force_remove_multi_label=True, 163 | transform=transform, 164 | target_transform=target_transform) 165 | results.update(other_results) 166 | return results 167 | 168 | def generate_from_attribute_types(self, task_id=0, 169 | transform=None, target_transform=None): 170 | """Generate from attribute types. 171 | 172 | Parameters 173 | ---------- 174 | task_id : int or None (default 0) 175 | if None, generate all tasks otherwise only generate for task_id 176 | task_id is 0-6: color, gender, material, neckline, pattern, sleeve, style 177 | transform : callable (default None) 178 | Optional transform to be applied on a sample. 179 | target_transform : callable (default None) 180 | Optional transform to be applied on a label. 
181 | 182 | Returns 183 | ------- 184 | dict : label (int) -> ClassificationTaskDataset 185 | """ 186 | # initialize task map 187 | task_map = [] 188 | task_ids_map = {data_task_id: idx for idx, data_task_id in enumerate(iMat2018FashionTasks.ATTRIBUTE_TASK_IDS)} 189 | for idx, data_task_id in enumerate(iMat2018FashionTasks.ATTRIBUTE_TASK_IDS): 190 | task_map.append({ 191 | "task_id": data_task_id, 192 | "task_name": None, 193 | "class_names": [], 194 | "class_ids": [], 195 | "label_map": {}, 196 | "multi_label": 1, 197 | }) 198 | 199 | # build task map 200 | for idx, (label_id, data) in enumerate(self.label_map.items()): 201 | label_id, data_task_id, label_name, task_name, is_pants, is_shoe = data 202 | label_id = int(label_id) 203 | data_task_id = int(data_task_id) 204 | is_pants = is_pants == 'yes' 205 | is_shoe = is_shoe == 'yes' 206 | if data_task_id in iMat2018FashionTasks.ATTRIBUTE_TASK_IDS: 207 | task_idx = task_ids_map[data_task_id] 208 | task_map[task_idx]['task_id'] = data_task_id 209 | task_map[task_idx]['task_name'] = task_name 210 | task_map[task_idx]['class_names'].append(label_name) 211 | new_label_index = len(task_map[task_idx]['class_ids']) 212 | task_map[task_idx]['class_ids'].append(new_label_index) 213 | task_map[task_idx]['label_map'].update({label_id: new_label_index}) 214 | 215 | # create tasks 216 | task_expander = MetaclassClassificationTaskExpander() 217 | if task_id is not None: 218 | assert isinstance(task_id, int) 219 | assert task_id in range(len(task_map)), task_id 220 | results = task_expander(self.dataset, [task_map[task_id]], 221 | transform=transform, 222 | target_transform=target_transform) 223 | assert len(results) == 1, len(results) 224 | return list(results.values())[0] 225 | else: 226 | return task_expander(self.dataset, task_map, 227 | transform=transform, 228 | target_transform=target_transform) 229 | 230 | 231 | class iMat2018FashionDataset(ClassificationTaskDataset): 232 | """iMaterialist 2018 Fashion Dataset 233 | 234 | Parameters 235 | ---------- 236 | path : str (default None) 237 | path to dataset 238 | if None, search using DATA environment variable 239 | split : str (train|validation) 240 | load provided split 241 | binarize_labels : bool (default False) 242 | if True, binarize labels in dataset iterator 243 | metadata : dict (default empty) 244 | extra arbitrary metadata 245 | transform : callable (default None) 246 | Optional transform to be applied on a sample. 247 | target_transform : callable (default None) 248 | Optional transform to be applied on a label. 
249 | """ 250 | TRAIN_SPLIT = 'train' 251 | VAL_SPLIT = 'validation' 252 | POSSIBLE_SPLITS = [TRAIN_SPLIT, VAL_SPLIT] 253 | 254 | LABEL_MAP_FILE = 'labels.csv' 255 | 256 | def __init__(self, path=None, split=TRAIN_SPLIT, binarize_labels=False, 257 | metadata={}, transform=None, target_transform=None): 258 | if path is not None: 259 | assert isinstance(path, str) 260 | path = os.path.join(os.environ['DATA'], 'imat2018', 'fashion') if path is None else path 261 | assert os.path.exists(path), path 262 | assert split in iMat2018FashionDataset.POSSIBLE_SPLITS, split 263 | 264 | self.split = split 265 | 266 | images_path = os.path.join(path, split) 267 | annotations_path = os.path.join(path, '{}.json'.format(split)) 268 | assert os.path.isdir(images_path), images_path 269 | assert os.path.exists(annotations_path), annotations_path 270 | 271 | with open(annotations_path, 'r') as f: 272 | j = json.load(f) 273 | images_list = [os.path.join(split, '{}.jpg').format(img['imageId']) for img in j['images']] 274 | images_to_labels = {os.path.join(split, '{}.jpg').format(labels['imageId']): [int(l) for l in labels['labelId']] 275 | for labels in j['annotations']} 276 | labels_list = [images_to_labels[img] for img in images_list] 277 | 278 | label_names = self._setup_label_map(path) 279 | 280 | super(iMat2018FashionDataset, self).__init__(images_list, 281 | labels_list, 282 | label_names=label_names, 283 | root=path, 284 | binarize_labels=binarize_labels, 285 | task_id=None, 286 | task_name='iMat2018', 287 | metadata=metadata, 288 | transform=transform, 289 | target_transform=target_transform) 290 | 291 | def _setup_label_map(self, root): 292 | label_map_path = os.path.join(root, iMat2018FashionDataset.LABEL_MAP_FILE) 293 | if not os.path.exists(label_map_path): 294 | return None 295 | 296 | label_names = {} 297 | with open(label_map_path, 'r') as f: 298 | reader = csv.reader(f) 299 | for idx, row in enumerate(reader): 300 | if idx < 1: 301 | continue 302 | label_id, task_id, label_name, task_name, is_pants, is_shoe = row 303 | label_names[int(label_id)] = label_name 304 | 305 | return label_names 306 | -------------------------------------------------------------------------------- /dataset/inat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
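Before moving on to the iNaturalist code, here is a minimal usage sketch (untested) for the iMat generators defined above. The local path is an assumption; the data must first be downloaded to it, or be reachable via the `DATA` environment variable as documented in `iMat2018FashionDataset`.

```python
# Minimal sketch; the path '/data/imat2018/fashion' is an assumption.
from dataset.imat import iMat2018FashionDataset, iMat2018FashionTasks

dataset = iMat2018FashionDataset(path='/data/imat2018/fashion', split='train')
tasks = iMat2018FashionTasks(dataset)

# task_id is the zero-indexed iMat label, so task_id=0 builds the
# "label 1 vs. rest" task via BinaryClassificationTaskExpander.
binary_task = tasks.generate(task_id=0)

# The attribute-type grouping instead yields one multi-label task per
# attribute type; task_id=0 corresponds to the first type (color).
color_task = tasks.generate_from_attribute_types(task_id=0)
```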
13 | 
14 | import os
15 | from collections import defaultdict
16 | import json
17 | 
18 | from .dataset import ClassificationTaskDataset
19 | 
20 | 
21 | class iNat2018Dataset(ClassificationTaskDataset):
22 |     """iNaturalist 2018 Dataset
23 | 
24 |     Parameters
25 |     ----------
26 |     root : str
27 |         root data directory
28 |         (the dataset itself is expected under <root>/inat2018)
29 |     split : str (train|val)
30 |         load provided split
31 |     task_id : int (default 0)
32 |         id of task
33 |     level : str (default 'order')
34 |         taxonomy level at which to create tasks
35 |     metadata : dict (default empty)
36 |         extra arbitrary metadata
37 |     transform : callable (default None)
38 |         Optional transform to be applied on a sample.
39 |     target_transform : callable (default None)
40 |         Optional transform to be applied on a label.
41 |     """
42 |     CATEGORIES_FILE = 'categories.json'
43 |     TASKS_FILE = 'tasks.json'
44 |     CLASS_TASKS_FILE = 'classes_tasks.json'
45 |     TRAIN_2018_FILE = 'train2018.json'
46 | 
47 |     ORDER = 'order'
48 |     CLASS = 'class'
49 |     POSSIBLE_LEVELS = [ORDER, CLASS]
50 | 
51 |     TRAIN_SPLIT = 'train'
52 |     VAL_SPLIT = 'val'
53 |     POSSIBLE_SPLITS = [TRAIN_SPLIT, VAL_SPLIT]
54 | 
55 |     def __init__(self, root, split=TRAIN_SPLIT, task_id=0, level=ORDER,
56 |                  metadata={}, transform=None, target_transform=None):
57 |         assert isinstance(root, str)
58 |         path = os.path.join(root, 'inat2018')
59 |         assert os.path.exists(path), path
60 |         assert split in iNat2018Dataset.POSSIBLE_SPLITS, split
61 |         assert isinstance(task_id, int)
62 |         assert level in iNat2018Dataset.POSSIBLE_LEVELS, level
63 | 
64 |         self.split = split
65 | 
66 |         # load categories
67 |         with open(os.path.join(path, iNat2018Dataset.CATEGORIES_FILE)) as f:
68 |             self.categories = json.load(f)
69 | 
70 |         # load or create set of tasks (corresponding to classification inside orders)
71 |         tasks_path = os.path.join(path, iNat2018Dataset.TASKS_FILE) if level == 'order' else os.path.join(path,
72 |                                                                                                           iNat2018Dataset.CLASS_TASKS_FILE)
73 |         if not os.path.exists(tasks_path):
74 |             self._create_tasks(tasks_path, path, level=level)
75 |         with open(tasks_path) as f:
76 |             self.tasks = json.load(f)
77 | 
78 |         annotation_file = os.path.join(path, "{}2018.json".format(split))
79 |         assert os.path.exists(annotation_file), annotation_file
80 |         with open(annotation_file) as f:
81 |             j = json.load(f)
82 |         annotations = j['annotations']
83 |         images_list = [img['file_name'] for img in j['images']]
84 | 
85 |         # get labels
86 |         task = self.tasks[task_id]
87 |         species_ids = task['species_ids']
88 |         species_ids = {k: i for i, k in enumerate(species_ids)}
89 |         labels_list = [species_ids.get(a['category_id'], -1) for a in annotations]
90 | 
91 |         # throw away images we are not training on (identified by label==-1)
92 |         images_list = [img for i, img in enumerate(images_list) if labels_list[i] != -1]
93 |         labels_list = [l for l in labels_list if l != -1]
94 | 
95 |         # get label names
96 |         label_names = {task['label_map'][str(l)]: n for l, n in zip(task['species_ids'], task['species_names'])}
97 | 
98 |         task_name = self.tasks[task_id]['name']
99 |         super(iNat2018Dataset, self).__init__(images_list,
100 |                                               labels_list,
101 |                                               label_names=label_names,
102 |                                               root=path,
103 |                                               task_id=task_id,
104 |                                               task_name=task_name,
105 |                                               metadata=metadata,
106 |                                               transform=transform,
107 |                                               target_transform=target_transform)
108 | 
109 |     def _create_tasks(self, tasks_path, root, level=ORDER):
110 |         """Create tasks file.
111 | 112 | Post-conditions: 113 | Creates a file 114 | 115 | Parameters 116 | ---------- 117 | tasks_path : str 118 | path to tasks file to write 119 | root : str 120 | root folder to train file 121 | level : str (default 'order') 122 | what level of taxonomy to create tasks 123 | """ 124 | assert level in iNat2018Dataset.POSSIBLE_LEVELS, level 125 | 126 | with open(os.path.join(root, iNat2018Dataset.TRAIN_2018_FILE)) as f: 127 | annotations = json.load(f) 128 | annotations = annotations['annotations'] 129 | 130 | level_samples = defaultdict(list) 131 | for r in annotations: 132 | level_samples[self.categories[r['category_id']][level]].append(r['image_id']) 133 | 134 | tasks = [] 135 | for i, (level_name, _) in enumerate(sorted(level_samples.items(), key=lambda x: len(x[1]), reverse=True)): 136 | species_in_level = sorted([(c['name'], c['id']) for c in self.categories if c[level] == level_name], 137 | key=lambda x: x[0]) 138 | species_ids = {k: i for i, (_, k) in enumerate(species_in_level)} 139 | tasks.append({ 140 | 'id': i, 141 | 'name': level_name, 142 | 'species_names': [n for (n, i) in species_in_level], 143 | 'species_ids': [i for (n, i) in species_in_level], 144 | 'label_map': species_ids 145 | }) 146 | with open(tasks_path, 'w') as f: 147 | json.dump(tasks, f) 148 | -------------------------------------------------------------------------------- /dataset/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | import os 15 | 16 | from torchvision.datasets import MNIST 17 | import torch.nn.functional as F 18 | 19 | from .dataset import ClassificationTaskDataset 20 | from .expansion import ClassificationTaskExpander 21 | 22 | class SplitMNISTTask: 23 | """SplitMNISTTask generates Split MNIST tasks given two classes 24 | 25 | Parameters 26 | ---------- 27 | mnist_dataset : MNISTDataset 28 | """ 29 | def __init__(self, mnist_dataset): 30 | self.mnist_dataset = mnist_dataset 31 | 32 | def generate(self, classes=(0, 1), transform=None, target_transform=None): 33 | """Generate tasks given the classes 34 | 35 | Parameters 36 | ---------- 37 | classes : tuple of ints (0-9) 38 | two classes to generate split MNIST 39 | (possible to accept more than two classes) 40 | transform : callable (default None) 41 | Optional transform to be applied on a sample. 42 | target_transform : callable (default None) 43 | Optional transform to be applied on a label. 
44 | 
45 |         Returns
46 |         -------
47 |         Task
48 |         """
49 |         assert isinstance(classes, tuple)
50 |         assert all(isinstance(c, int) and c >= 0 and c <= 9 for c in classes)
51 |         assert len(set(classes)) == len(classes)
52 | 
53 |         task_expander = ClassificationTaskExpander()
54 |         return task_expander(self.mnist_dataset,
55 |                              {c : new_c for new_c, c in enumerate(classes)},
56 |                              label_names={classes.index(old_c) : name for old_c, name in self.mnist_dataset.label_names_map.items() if old_c in classes},
57 |                              task_id=classes,
58 |                              task_name='Split MNIST {}'.format(classes),
59 |                              transform=transform,
60 |                              target_transform=target_transform)
61 | 
62 | class MNISTDataset(ClassificationTaskDataset):
63 |     """MNIST Dataset
64 | 
65 |     Parameters
66 |     ----------
67 |     path : str (default None)
68 |         path to dataset (should contain images folder in same directory)
69 |         if None, search using DATA environment variable
70 |     train : bool (default True)
71 |         if True, load train split otherwise load test split
72 |     download: bool (default False)
73 |         if True, downloads the dataset from the internet and
74 |         puts it in path directory; otherwise if dataset is already downloaded,
75 |         it is not downloaded again
76 |     metadata : dict (default empty)
77 |         extra arbitrary metadata
78 |     transform : callable (default None)
79 |         Optional transform to be applied on a sample.
80 |     target_transform : callable (default None)
81 |         Optional transform to be applied on a label.
82 |     expand : bool (default False); if True, pad images to 32x32 and expand them to 3 channels
83 |     """
84 |     def __init__(self, path=None, train=True, download=False,
85 |                  metadata={}, transform=None, target_transform=None, expand=False):
86 |         if path is not None:
87 |             assert isinstance(path, str)
88 |         path = os.path.join(os.environ['DATA'], 'mnist') if path is None else path
89 |         assert isinstance(train, bool)
90 | 
91 |         self.mnist = MNIST(path, train=train, transform=transform,
92 |                            target_transform=target_transform, download=download)
93 |         data = self.mnist.data  # .data/.targets already reflect the train flag (train_data/test_data are deprecated)
94 |         labels = self.mnist.targets
95 | 
96 |         if expand:
97 |             data = F.pad(data, [2,2,2,2])
98 |             data = data.view(-1, 32, 32, 1).expand([data.size(0), 32, 32, 3])
99 | 
100 |         super(MNISTDataset, self).__init__([x for x in data],
101 |                                            [int(x) for x in labels],
102 |                                            label_names={l: str(l) for l in range(10)},
103 |                                            root=path,
104 |                                            task_id=None,
105 |                                            task_name='MNIST',
106 |                                            metadata=metadata,
107 |                                            transform=transform,
108 |                                            target_transform=target_transform)
109 | 
-------------------------------------------------------------------------------- /datasets.py: --------------------------------------------------------------------------------
1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"). You
4 | # may not use this file except in compliance with the License. A copy of
5 | # the License is located at
6 | #
7 | #     http://aws.amazon.com/apache2.0/
8 | #
9 | # or in the "license" file accompanying this file. This file is
10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11 | # ANY KIND, either express or implied. See the License for the specific
12 | # language governing permissions and limitations under the License.
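A quick sketch (untested) of how the Split MNIST pieces above compose. The path is an assumption; `download=True` fetches MNIST on first use.

```python
# Minimal sketch; path is an assumption.
from dataset.mnist import MNISTDataset, SplitMNISTTask

mnist_train = MNISTDataset(path='/data/mnist', train=True, download=True)
# Build the binary 3-vs-5 task; the old labels are remapped to 0 and 1.
task_3_5 = SplitMNISTTask(mnist_train).generate(classes=(3, 5))
```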
13 | 14 | 15 | import collections 16 | import torchvision.transforms as transforms 17 | import os 18 | import json 19 | 20 | 21 | try: 22 | from IPython import embed 23 | except: 24 | pass 25 | 26 | _DATASETS = {} 27 | 28 | Dataset = collections.namedtuple( 29 | 'Dataset', ['trainset', 'testset']) 30 | 31 | 32 | def _add_dataset(dataset_fn): 33 | _DATASETS[dataset_fn.__name__] = dataset_fn 34 | return dataset_fn 35 | 36 | 37 | def _get_transforms(augment=True, normalize=None): 38 | if normalize is None: 39 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 40 | std=[0.229, 0.224, 0.225]) 41 | 42 | basic_transform = [transforms.ToTensor(), normalize] 43 | 44 | transform_train = [] 45 | if augment: 46 | transform_train += [ 47 | transforms.RandomResizedCrop(224), 48 | transforms.RandomHorizontalFlip(), 49 | ] 50 | else: 51 | transform_train += [ 52 | transforms.Resize(256), 53 | transforms.CenterCrop(224), 54 | ] 55 | transform_train += basic_transform 56 | transform_train = transforms.Compose(transform_train) 57 | 58 | transform_test = [ 59 | transforms.Resize(256), 60 | transforms.CenterCrop(224), 61 | ] 62 | transform_test += basic_transform 63 | transform_test = transforms.Compose(transform_test) 64 | 65 | return transform_train, transform_test 66 | 67 | 68 | def _get_mnist_transforms(augment=True, invert=False, transpose=False): 69 | transform = [ 70 | transforms.ToTensor(), 71 | ] 72 | if invert: 73 | transform += [transforms.Lambda(lambda x: 1. - x)] 74 | if transpose: 75 | transform += [transforms.Lambda(lambda x: x.transpose(2, 1))] 76 | transform += [ 77 | transforms.Normalize((.5,), (.5,)), 78 | transforms.Lambda(lambda x: x.expand(3, 32, 32)) 79 | ] 80 | 81 | transform_train = [] 82 | transform_train += [transforms.Pad(padding=2)] 83 | if augment: 84 | transform_train += [transforms.RandomCrop(32, padding=4)] 85 | transform_train += transform 86 | transform_train = transforms.Compose(transform_train) 87 | 88 | transform_test = [] 89 | transform_test += [transforms.Pad(padding=2)] 90 | transform_test += transform 91 | transform_test = transforms.Compose(transform_test) 92 | 93 | return transform_train, transform_test 94 | 95 | 96 | def _get_cifar_transforms(augment=True): 97 | transform = [ 98 | transforms.ToTensor(), 99 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 100 | ] 101 | transform_train = [] 102 | if augment: 103 | transform_train += [ 104 | transforms.Pad(padding=4, fill=(125, 123, 113)), 105 | transforms.RandomCrop(32, padding=0), 106 | transforms.RandomHorizontalFlip()] 107 | transform_train += transform 108 | transform_train = transforms.Compose(transform_train) 109 | transform_test = [] 110 | transform_test += transform 111 | transform_test = transforms.Compose(transform_test) 112 | return transform_train, transform_test 113 | 114 | 115 | def set_metadata(trainset, testset, config, dataset_name): 116 | trainset.metadata = { 117 | 'dataset': dataset_name, 118 | 'task_id': config.task_id, 119 | 'task_name': trainset.task_name, 120 | } 121 | testset.metadata = { 122 | 'dataset': dataset_name, 123 | 'task_id': config.task_id, 124 | 'task_name': testset.task_name, 125 | } 126 | return trainset, testset 127 | 128 | 129 | @_add_dataset 130 | def inat2018(root, config): 131 | from dataset.inat import iNat2018Dataset 132 | transform_train, transform_test = _get_transforms() 133 | trainset = iNat2018Dataset(root, split='train', transform=transform_train, task_id=config.task_id) 134 | testset = iNat2018Dataset(root, split='val', 
transform=transform_test, task_id=config.task_id)
135 |     trainset, testset = set_metadata(trainset, testset, config, 'inat2018')
136 |     return trainset, testset
137 | 
138 | 
139 | def load_tasks_map(tasks_map_file):
140 |     assert os.path.exists(tasks_map_file), tasks_map_file
141 |     with open(tasks_map_file, 'r') as f:
142 |         tasks_map = json.load(f)
143 |     tasks_map = {int(k): int(v) for k, v in tasks_map.items()}
144 |     return tasks_map
145 | 
146 | 
147 | @_add_dataset
148 | def cub_inat2018(root, config):
149 |     """This meta-task is the concatenation of CUB-200 (first 25 tasks) and iNat (last 207 tasks).
150 | 
151 |     - The first 10 tasks are classification of the animal species inside one of 10 orders of birds in CUB-200
152 |       (considering all orders except passeriformes).
153 |     - The next 15 tasks are classification of species inside the 15 families of the order of passeriformes.
154 |     - The remaining 207 tasks are classification of the species inside each of 207 families in iNat.
155 | 
156 |     As noted above, for CUB-200 the first 10 tasks are classification of species inside an order, rather than inside
157 |     a family as done in iNat (recall order > family > species). This is done because CUB-200 has very few images
158 |     in each family of birds (except for the families of passeriformes). Hence, we go up one step in the taxonomy
159 |     and consider classification inside orders rather than families.
160 |     """
161 |     NUM_CUB = 25
162 |     NUM_CUB_ORDERS = 10
163 |     NUM_INAT = 207
164 |     assert 0 <= config.task_id < NUM_CUB + NUM_INAT
165 |     transform_train, transform_test = _get_transforms()
166 |     if 0 <= config.task_id < NUM_CUB:
167 |         # CUB
168 |         from dataset.cub import CUBTasks, CUBDataset
169 |         tasks_map_file = os.path.join(root, 'cub/CUB_200_2011', 'final_tasks_map.json')
170 |         tasks_map = load_tasks_map(tasks_map_file)
171 |         task_id = tasks_map[config.task_id]
172 | 
173 |         if config.task_id < NUM_CUB_ORDERS:
174 |             # CUB orders
175 |             train_tasks = CUBTasks(CUBDataset(root, split='train'))
176 |             trainset = train_tasks.generate(task_id=task_id,
177 |                                             use_species_names=True,
178 |                                             transform=transform_train)
179 |             test_tasks = CUBTasks(CUBDataset(root, split='test'))
180 |             testset = test_tasks.generate(task_id=task_id,
181 |                                           use_species_names=True,
182 |                                           transform=transform_test)
183 |         else:
184 |             # CUB passeriformes families
185 |             train_tasks = CUBTasks(CUBDataset(root, split='train'))
186 |             trainset = train_tasks.generate(task_id=task_id,
187 |                                             task='family',
188 |                                             taxonomy_file='passeriformes.txt',
189 |                                             use_species_names=True,
190 |                                             transform=transform_train)
191 |             test_tasks = CUBTasks(CUBDataset(root, split='test'))
192 |             testset = test_tasks.generate(task_id=task_id,
193 |                                           task='family',
194 |                                           taxonomy_file='passeriformes.txt',
195 |                                           use_species_names=True,
196 |                                           transform=transform_test)
197 |     else:
198 |         # iNat2018
199 |         from dataset.inat import iNat2018Dataset
200 |         tasks_map_file = os.path.join(root, 'inat2018', 'final_tasks_map.json')
201 |         tasks_map = load_tasks_map(tasks_map_file)
202 |         task_id = tasks_map[config.task_id - NUM_CUB]
203 | 
204 |         trainset = iNat2018Dataset(root, split='train', transform=transform_train, task_id=task_id)
205 |         testset = iNat2018Dataset(root, split='val', transform=transform_test, task_id=task_id)
206 |     trainset, testset = set_metadata(trainset, testset, config, 'cub_inat2018')
207 |     return trainset, testset
208 | 
209 | 
210 | @_add_dataset
211 | def imat2018fashion(root, config):
212 |     NUM_IMAT = 228
213 |     assert 0 <= config.task_id < NUM_IMAT
214 |     from dataset.imat import
iMat2018FashionDataset, iMat2018FashionTasks
215 |     transform_train, transform_test = _get_transforms()
216 |     train_tasks = iMat2018FashionTasks(iMat2018FashionDataset(root, split='train'))
217 |     trainset = train_tasks.generate(task_id=config.task_id,
218 |                                     transform=transform_train)
219 |     test_tasks = iMat2018FashionTasks(iMat2018FashionDataset(root, split='validation'))
220 |     testset = test_tasks.generate(task_id=config.task_id,
221 |                                   transform=transform_test)
222 |     trainset, testset = set_metadata(trainset, testset, config, 'imat2018fashion')
223 |     return trainset, testset
224 | 
225 | 
226 | @_add_dataset
227 | def split_mnist(root, config):
228 |     assert isinstance(config.task_id, tuple)
229 |     from dataset.mnist import MNISTDataset, SplitMNISTTask
230 |     transform_train, transform_test = _get_mnist_transforms()
231 |     train_tasks = SplitMNISTTask(MNISTDataset(root, train=True))
232 |     trainset = train_tasks.generate(classes=config.task_id, transform=transform_train)
233 |     test_tasks = SplitMNISTTask(MNISTDataset(root, train=False))
234 |     testset = test_tasks.generate(classes=config.task_id, transform=transform_test)
235 |     trainset, testset = set_metadata(trainset, testset, config, 'split_mnist')
236 |     return trainset, testset
237 | 
238 | 
239 | @_add_dataset
240 | def split_cifar(root, config):
241 |     assert 0 <= config.task_id < 11
242 |     from dataset.cifar import CIFAR10Dataset, CIFAR100Dataset, SplitCIFARTask
243 |     transform_train, transform_test = _get_cifar_transforms()
244 |     train_tasks = SplitCIFARTask(CIFAR10Dataset(root, train=True), CIFAR100Dataset(root, train=True))
245 |     trainset = train_tasks.generate(task_id=config.task_id, transform=transform_train)
246 |     test_tasks = SplitCIFARTask(CIFAR10Dataset(root, train=False), CIFAR100Dataset(root, train=False))
247 |     testset = test_tasks.generate(task_id=config.task_id, transform=transform_test)
248 |     trainset, testset = set_metadata(trainset, testset, config, 'split_cifar')
249 |     return trainset, testset
250 | 
251 | 
252 | @_add_dataset
253 | def cifar10_mnist(root, config):
254 |     from dataset.cifar import CIFAR10Dataset
255 |     from dataset.mnist import MNISTDataset
256 |     from dataset.expansion import UnionClassificationTaskExpander
257 |     transform_train, transform_test = _get_cifar_transforms()
258 |     trainset = UnionClassificationTaskExpander(merge_duplicate_images=False)(
259 |         [CIFAR10Dataset(root, train=True), MNISTDataset(root, train=True, expand=True)], transform=transform_train)
260 |     testset = UnionClassificationTaskExpander(merge_duplicate_images=False)(
261 |         [CIFAR10Dataset(root, train=False), MNISTDataset(root, train=False, expand=True)], transform=transform_test)
262 |     return trainset, testset
263 | 
264 | 
265 | @_add_dataset
266 | def cifar10(root, config=None):  # config is unused; accepted so get_dataset can call every entry uniformly
267 |     from torchvision.datasets import CIFAR10
268 |     transform = transforms.Compose([
269 |         transforms.Resize(224),
270 |         transforms.ToTensor(),
271 |         transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
272 |     ])
273 |     trainset = CIFAR10(root, train=True, transform=transform, download=True)
274 |     testset = CIFAR10(root, train=False, transform=transform)
275 |     return trainset, testset
276 | 
277 | 
278 | @_add_dataset
279 | def cifar100(root, config=None):
280 |     from torchvision.datasets import CIFAR100
281 |     transform = transforms.Compose([
282 |         transforms.Resize(224),
283 |         transforms.ToTensor(),
284 |         transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
285 |     ])
286 |     trainset = CIFAR100(root, train=True, transform=transform, download=True)
287 | 
testset = CIFAR100(root, train=False, transform=transform)
288 |     return trainset, testset
289 | 
290 | 
291 | @_add_dataset
292 | def mnist(root, config=None):
293 |     from torchvision.datasets import MNIST
294 |     transform = transforms.Compose([
295 |         lambda x: x.convert("RGB"),
296 |         transforms.Resize(224),
297 |         transforms.ToTensor(),
298 |         # transforms.Normalize((0.5, 0.5, 0.5), (1., 1., 1.)),
299 |     ])
300 |     trainset = MNIST(root, train=True, transform=transform, download=True)
301 |     testset = MNIST(root, train=False, transform=transform)
302 |     return trainset, testset
303 | 
304 | 
305 | @_add_dataset
306 | def letters(root, config=None):
307 |     from torchvision.datasets import EMNIST
308 |     transform = transforms.Compose([
309 |         lambda x: x.convert("RGB"),
310 |         transforms.Resize(224),
311 |         transforms.ToTensor(),
312 |         # transforms.Normalize((0.5, 0.5, 0.5), (1., 1., 1.)),
313 |     ])
314 |     trainset = EMNIST(root, train=True, split='letters', transform=transform, download=True)
315 |     testset = EMNIST(root, train=False, split='letters', transform=transform)
316 |     return trainset, testset
317 | 
318 | 
319 | @_add_dataset
320 | def kmnist(root, config=None):
321 |     from torchvision.datasets import KMNIST
322 |     transform = transforms.Compose([
323 |         lambda x: x.convert("RGB"),
324 |         transforms.Resize(224),
325 |         transforms.ToTensor(),
326 |     ])
327 |     trainset = KMNIST(root, train=True, transform=transform, download=True)
328 |     testset = KMNIST(root, train=False, transform=transform)
329 |     return trainset, testset
330 | 
331 | 
332 | @_add_dataset
333 | def stl10(root, config=None):
334 |     from torchvision.datasets import STL10
335 |     transform = transforms.Compose([
336 |         transforms.Resize(224),
337 |         transforms.ToTensor(),
338 |         transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
339 |     ])
340 |     trainset = STL10(root, split='train', transform=transform, download=True)
341 |     testset = STL10(root, split='test', transform=transform)
342 |     trainset.targets = trainset.labels
343 |     testset.targets = testset.labels
344 |     return trainset, testset
345 | 
346 | 
347 | def get_dataset(root, config=None):
348 |     return _DATASETS[config.name](os.path.expanduser(root), config)
349 | 
350 | 
351 | 
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License"). You
4 | # may not use this file except in compliance with the License. A copy of
5 | # the License is located at
6 | #
7 | #     http://aws.amazon.com/apache2.0/
8 | #
9 | # or in the "license" file accompanying this file. This file is
10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11 | # ANY KIND, either express or implied. See the License for the specific
12 | # language governing permissions and limitations under the License.
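datasets.py above exposes everything through a small registry: `@_add_dataset` records each constructor in `_DATASETS`, and `get_dataset` looks it up by `config.name` and calls it with `(root, config)`. A minimal sketch (untested) of driving the registry directly, using an OmegaConf config like the ones Hydra builds from conf/config.yaml:

```python
# Minimal sketch; the config fields mirror what get_dataset and split_cifar expect.
from omegaconf import OmegaConf
import datasets

config = OmegaConf.create({'name': 'split_cifar', 'task_id': 0})
trainset, testset = datasets.get_dataset('~/data', config)
```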
13 | 14 | 15 | #!/usr/bin/env python3.6 16 | import pickle 17 | 18 | import hydra 19 | import logging 20 | 21 | from datasets import get_dataset 22 | from models import get_model 23 | 24 | from task2vec import Task2Vec 25 | from omegaconf import DictConfig, OmegaConf 26 | 27 | 28 | @hydra.main(config_path="conf/config.yaml") 29 | def main(cfg: DictConfig): 30 | logging.info(cfg.pretty()) 31 | train_dataset, test_dataset = get_dataset(cfg.dataset.root, cfg.dataset) 32 | if hasattr(train_dataset, 'task_name'): 33 | print(f"======= Embedding for task: {train_dataset.task_name} =======") 34 | probe_network = get_model(cfg.model.arch, pretrained=cfg.model.pretrained, 35 | num_classes=train_dataset.num_classes) 36 | probe_network = probe_network.to(cfg.device) 37 | embedding = Task2Vec(probe_network, **cfg.task2vec).embed(train_dataset) 38 | embedding.meta = OmegaConf.to_container(cfg, resolve=True) 39 | embedding.meta['task_name'] = getattr(train_dataset, 'task_name', None) 40 | with open('embedding.p', 'wb') as f: 41 | pickle.dump(embedding, f) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | 15 | import torch.utils.model_zoo as model_zoo 16 | 17 | import torchvision.models.resnet as resnet 18 | import torch 19 | 20 | from task2vec import ProbeNetwork 21 | 22 | _MODELS = {} 23 | 24 | 25 | def _add_model(model_fn): 26 | _MODELS[model_fn.__name__] = model_fn 27 | return model_fn 28 | 29 | 30 | class ResNet(resnet.ResNet, ProbeNetwork): 31 | 32 | def __init__(self, block, layers, num_classes=1000): 33 | super(ResNet, self).__init__(block, layers, num_classes) 34 | # Saves the ordered list of layers. We need this to forward from an arbitrary intermediate layer. 35 | self.layers = [ 36 | self.conv1, self.bn1, self.relu, 37 | self.maxpool, self.layer1, self.layer2, 38 | self.layer3, self.layer4, self.avgpool, 39 | lambda z: torch.flatten(z, 1), self.fc 40 | ] 41 | 42 | @property 43 | def classifier(self): 44 | return self.fc 45 | 46 | # @ProbeNetwork.classifier.setter 47 | # def classifier(self, val): 48 | # self.fc = val 49 | 50 | # Modified forward method that allows to start feeding the cached activations from an intermediate 51 | # layer of the network 52 | def forward(self, x, start_from=0): 53 | """Replaces the default forward so that we can forward features starting from any intermediate layer.""" 54 | for layer in self.layers[start_from:]: 55 | x = layer(x) 56 | return x 57 | 58 | 59 | @_add_model 60 | def resnet18(pretrained=False, num_classes=1000): 61 | """Constructs a ResNet-18 model. 
62 |     Args:
63 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
64 |     """
65 |     model: ProbeNetwork = ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
66 |     if pretrained:
67 |         state_dict = model_zoo.load_url(resnet.model_urls['resnet18'])
68 |         state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
69 |         model.load_state_dict(state_dict, strict=False)
70 |     return model
71 | 
72 | 
73 | @_add_model
74 | def resnet34(pretrained=False, num_classes=1000):
75 |     """Constructs a ResNet-34 model.
76 |     Args:
77 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
78 |     """
79 |     model = ResNet(resnet.BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
80 |     if pretrained:
81 |         state_dict = model_zoo.load_url(resnet.model_urls['resnet34'])
82 |         state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
83 |         model.load_state_dict(state_dict, strict=False)
84 |     return model
85 | 
86 | 
87 | def get_model(model_name, pretrained=False, num_classes=1000):
88 |     try:
89 |         return _MODELS[model_name](pretrained=pretrained, num_classes=num_classes)
90 |     except KeyError:
91 |         raise ValueError(f"Architecture {model_name} not implemented.")
92 | 
-------------------------------------------------------------------------------- /plot_distance_cub_inat.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License"). You
6 | # may not use this file except in compliance with the License. A copy of
7 | # the License is located at
8 | #
9 | #     http://aws.amazon.com/apache2.0/
10 | #
11 | # or in the "license" file accompanying this file. This file is
12 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
13 | # ANY KIND, either express or implied. See the License for the specific
14 | # language governing permissions and limitations under the License.
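The probe networks above are ordinary torchvision ResNets with one twist: `forward(x, start_from=...)` can resume the forward pass from any entry of `self.layers`, which is what lets Task2Vec re-run only the head on cached activations. A small sketch (untested) with no pretrained weights:

```python
# Minimal sketch; shapes follow the ResNet-18/BasicBlock definition above.
import torch
from models import get_model

probe = get_model('resnet18', pretrained=False, num_classes=10)
x = torch.randn(2, 3, 224, 224)
logits = probe(x)  # runs through the full self.layers list
assert logits.shape == (2, 10)

# start_from re-enters the network at an intermediate layer, e.g. feeding
# cached pre-classifier features straight into fc (index 10 in self.layers):
feats = torch.randn(2, 512)
logits_again = probe(feats, start_from=10)
```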
15 | 
16 | import os
17 | import json
18 | import sys
19 | from io import BytesIO
20 | import argparse
21 | import numpy as np
22 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox
23 | from scipy.cluster import hierarchy
24 | from scipy.spatial.distance import squareform, pdist
25 | import task_similarity
26 | import glob
27 | import matplotlib as mpl
28 | mpl.use('Agg')
29 | import matplotlib.pyplot as plt
30 | import seaborn as sns
31 | 
32 | 
33 | CATEGORIES_JSON_FILE = 'inat2018/categories.json'
34 | ICONS_PATH = './static/iconic_taxa/'
35 | 
36 | parser = argparse.ArgumentParser(description='Plot Task2Vec distance matrices for the CUB/iNat2018 tasks')
37 | parser.add_argument('root', default='multirun/variational', type=str)
38 | parser.add_argument('--data-root', default='~/data', type=str)
39 | parser.add_argument('--distance', default='cosine', type=str,
40 |                     help='distance to use')
41 | parser.add_argument('--max-tasks', default=120, type=int,
42 |                     help='number of tasks to consider')
43 | args = parser.parse_args()
44 | 
45 | # Assumes task IDs are mapped as follows:
46 | # CUB: 0-24 (0-9 are orders, 10-24 are Passeriformes families)
47 | # iNat2018: 25-231
48 | # As in cub_inat2018 in datasets.py
49 | CUB = 'CUB'
50 | INAT = 'iNat'
51 | CUB_NUM_TASKS = 25
52 | 
53 | ADDITIONAL_TAXONOMY_DATA = [
54 |     {
55 |         'kingdom': 'Animalia ',
56 |         'supercategory': 'Animalia ',
57 |         'phylum': 'Chordata',
58 |         'class': 'Aves',
59 |         'order': 'Apodiformes',
60 |     }
61 | ]
62 | 
63 | 
64 | def invert_icon(img):
65 |     img = (1. - img)
66 |     return img
67 | 
68 | 
69 | def get_image(e):
70 |     base = os.path.join(ICONS_PATH, "{}-200px.png")
71 |     possible_names = [base.format(x) for x in
72 |                       [e.meta.get('class'), e.meta.get('phylum'), e.meta.get('kingdom'), 'unknown']]
73 |     for filename in possible_names:
74 |         if os.path.exists(filename):
75 |             img = plt.imread(filename, format='png')
76 |             return invert_icon(img) if e.dataset == CUB else img
77 |     raise FileNotFoundError()
78 | 
79 | 
80 | def average_top_k_tax_distance(distance_matrix, from_embeddings, to_embeddings=None, k=2):
81 |     assert k > 0, k
82 | 
83 |     if to_embeddings is None:
84 |         to_embeddings = from_embeddings
85 | 
86 |     assert distance_matrix.shape[0] == len(from_embeddings)
87 |     assert distance_matrix.shape[1] == len(to_embeddings)
88 | 
89 |     tax_distance = []
90 |     for i in range(len(from_embeddings)):
91 |         top_matches = distance_matrix[i].argsort()[:k]
92 |         tax_distance.append(
93 |             np.mean([taxonomy_distance(from_embeddings[i], to_embeddings[j]) for j in top_matches])
94 |         )
95 |     return np.mean(tax_distance)
96 | 
97 | 
98 | def plot_changing_k(ax, distance_matrix, from_embeddings, to_embeddings, **kwargs):
99 |     x = [1, 3, 5, 10, 20, 30, 50, 100, 200, len(to_embeddings)]
100 |     x = [v for v in x if v <= len(to_embeddings)]
101 |     y = []
102 |     for k in x:
103 |         y.append(average_top_k_tax_distance(distance_matrix, from_embeddings, to_embeddings, k=k))
104 |     ax.plot(x, y, **kwargs)
105 |     ax.set_xlabel('Size k of neighborhood')
106 |     ax.set_ylabel('Avg. top-k tax. distance')
107 | 
108 | 
109 | def sort_distance_matrix(distance_matrix, embeddings, names, method='complete'):
110 |     assert method in ['ward', 'single', 'average', 'complete']
111 |     np.fill_diagonal(distance_matrix, 0.)
112 | cond_distance_matrix = squareform(distance_matrix, checks=False) 113 | linkage_matrix = hierarchy.linkage(cond_distance_matrix, method=method, optimal_ordering=True) 114 | res_order = hierarchy.leaves_list(linkage_matrix) 115 | distance_matrix = distance_matrix[res_order][:, res_order] 116 | embeddings = [embeddings[i] for i in res_order] 117 | names = [names[i] for i in res_order] 118 | np.fill_diagonal(distance_matrix, np.nan) 119 | return distance_matrix, embeddings, names, res_order 120 | 121 | 122 | def draw_figure_to_plt(distance_matrix, embeddings, names, label_size=14): 123 | fig = plt.figure(figsize=(15 / 25. * len(embeddings), 15 / 25. * len(embeddings))) 124 | ax = plt.gca() 125 | 126 | plt.imshow(distance_matrix, cmap='viridis_r') 127 | ax.set_xticks(np.arange(len(embeddings))) 128 | ax.set_yticks(np.arange(len(embeddings))) 129 | 130 | ax.set_xticklabels(names) 131 | ax.set_yticklabels(names) 132 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 133 | rotation_mode="anchor") 134 | 135 | try: 136 | for i, e in enumerate(embeddings): 137 | arr_img = get_image(e) 138 | imagebox = OffsetImage(arr_img, zoom=0.18) 139 | imagebox.image.axes = ax 140 | xy = (i, i) 141 | ab = AnnotationBbox(imagebox, xy, frameon=False) 142 | ax.add_artist(ab) 143 | except FileNotFoundError: 144 | print("Could not find an icon for a taxonomy entry. Have you downloaded the iconic_taxa directory in ./static?") 145 | 146 | plt.tick_params(axis='both', which='major', labelsize=label_size) 147 | plt.tight_layout() 148 | 149 | 150 | def taxonomy_distance(e0, e1): 151 | for i, k in enumerate(['order', 'class', 'phylum', 'kingdom']): 152 | if e0.meta[k] == e1.meta[k]: 153 | return i 154 | return i + 1 155 | 156 | 157 | def add_class_information(embeddings): 158 | # load taxonomy 159 | with open(os.path.join(os.path.expanduser(args.data_root), CATEGORIES_JSON_FILE), 'r') as f: 160 | categories = json.load(f) 161 | categories.extend(ADDITIONAL_TAXONOMY_DATA) 162 | 163 | category_map = {c['order']: c for c in categories} 164 | category_map.update({c['family']: c for c in categories if 'family' in c}) 165 | 166 | for e in embeddings: 167 | c = category_map[e.task_name] 168 | e.meta['order'] = c['order'].lower() 169 | e.meta['class'] = c['class'].lower() 170 | e.meta['phylum'] = c['phylum'].lower() 171 | e.meta['kingdom'] = c['kingdom'].lower() 172 | e.meta['supercategory'] = c['supercategory'].lower() 173 | 174 | 175 | def main(): 176 | os.makedirs('./plots', exist_ok=True) 177 | 178 | files = glob.glob(os.path.join(args.root, '*', 'embedding.p')) 179 | 180 | # get embeddings 181 | embeddings = [task_similarity.load_embedding(file) for file in files] 182 | embeddings.sort(key=lambda x: x.meta['dataset']['task_id']) 183 | for e in embeddings: 184 | e.task_id = e.meta['dataset']['task_id'] 185 | e.task_name = e.meta['task_name'] 186 | e.dataset = CUB if e.task_id < CUB_NUM_TASKS else INAT 187 | 188 | # get distance matrix 189 | distance_matrix = task_similarity.pdist(embeddings, distance=args.distance) 190 | add_class_information(embeddings) 191 | 192 | # construct names to display on plot 193 | for e in embeddings: 194 | assert hasattr(e, 'task_name') 195 | assert 'class' in e.meta, e.task_name 196 | assert hasattr(e, 'dataset'), e.task_name 197 | 198 | names = [f"[{e.dataset}] {e.task_name} ({e.meta['class']})" 199 | for e in embeddings] 200 | embeddings = np.array(embeddings) 201 | 202 | # === Plot comparison between taxonomical distance
and task2vec distance === 203 | # Compute all taxonomical distances 204 | tax_distance_matrix = np.zeros_like(distance_matrix) 205 | for i, e0 in enumerate(embeddings): 206 | for j, e1 in enumerate(embeddings): 207 | tax_distance_matrix[i, j] = taxonomy_distance(e0, e1) 208 | 209 | plt.close('all') 210 | sns.set_style('whitegrid') 211 | 212 | fig = plt.figure(figsize=(3.3, 3.3)) 213 | ax = fig.gca() 214 | # Plot how the average task2vec and taxonomical distance change as the neighborhood changes 215 | np.fill_diagonal(distance_matrix, np.inf) 216 | np.fill_diagonal(tax_distance_matrix, np.inf) 217 | plot_changing_k(ax, distance_matrix, embeddings, embeddings, label='Task2Vec distance') 218 | plot_changing_k(ax, tax_distance_matrix, embeddings, embeddings, label='Taxonomy distance') 219 | np.fill_diagonal(distance_matrix, 0) 220 | np.fill_diagonal(tax_distance_matrix, 0) 221 | ax.legend() 222 | # save figure 223 | fig.savefig('plots/embedding_distance_vs_taxonomy.pdf', bbox_inches='tight') 224 | 225 | # === Plot the clustered distance matrix === 226 | np.fill_diagonal(distance_matrix, 0.) 227 | 228 | embeddings = embeddings[:50] 229 | names = names[:50] 230 | distance_matrix = distance_matrix[:50, :50] 231 | # sort distance matrix 232 | distance_matrix, embeddings, names, _ = sort_distance_matrix(distance_matrix, embeddings, names, method='complete') 233 | # draw figure 234 | plt.close('all') 235 | sns.set_style('white') 236 | draw_figure_to_plt(distance_matrix, embeddings, names, label_size=22) 237 | # save figure 238 | plt.savefig('plots/task2vec_distance_matrix.pdf', bbox_inches='tight') 239 | 240 | 241 | if __name__ == '__main__': 242 | main() 243 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | seaborn 2 | scipy 3 | matplotlib 4 | omegaconf 5 | fastcluster 6 | torch 7 | torchvision 8 | numpy 9 | pandas 10 | hydra-core 11 | scikit-learn 12 | tqdm 13 | -------------------------------------------------------------------------------- /scripts/download_cub.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"). You 6 | # may not use this file except in compliance with the License. A copy of 7 | # the License is located at 8 | # 9 | # http://aws.amazon.com/apache2.0/ 10 | # 11 | # or in the "license" file accompanying this file. This file is 12 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 13 | # ANY KIND, either express or implied. See the License for the specific 14 | # language governing permissions and limitations under the License.
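# Example invocation (run from the repository root, since TASK2VEC_REPO below defaults to "./"):
#     ./scripts/download_cub.sh ~/data
# downloads the CUB images and lists into ~/data/cub and copies the task-definition JSONs there.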
15 | 16 | if [ $# -ne 1 ]; then 17 | echo "Usage: ./download_cub.sh DATA_ROOT" 18 | exit 1 19 | fi 20 | 21 | DATA_DIR=$1 22 | TASK2VEC_REPO="./" 23 | 24 | 25 | mkdir -p "$DATA_DIR"/cub 26 | cd "$DATA_DIR"/cub || exit 1 27 | 28 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz 29 | tar -xzf images.tgz 30 | wget http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz 31 | tar -xzf lists.tgz 32 | mv lists/* CUB_200_2011/ 33 | cp $TASK2VEC_REPO/support_files/cub/*.json "$DATA_DIR"/cub/CUB_200_2011 -------------------------------------------------------------------------------- /scripts/download_inat2018.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"). You 6 | # may not use this file except in compliance with the License. A copy of 7 | # the License is located at 8 | # 9 | # http://aws.amazon.com/apache2.0/ 10 | # 11 | # or in the "license" file accompanying this file. This file is 12 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 13 | # ANY KIND, either express or implied. See the License for the specific 14 | # language governing permissions and limitations under the License. 15 | 16 | if [ $# -ne 1 ]; then 17 | echo "Usage: ./download_inat2018.sh DATA_ROOT" 18 | exit 1 19 | fi 20 | 21 | DATA_DIR=$1 22 | TASK2VEC_REPO="./" 23 | 24 | mkdir -p "$DATA_DIR"/inat2018 25 | cd "$DATA_DIR"/inat2018 || exit 1 26 | 27 | # For alternative ways of downloading the files check https://github.com/visipedia/inat_comp/tree/master/2018 28 | FILES="train_val2018.tar.gz train2018.json.tar.gz val2018.json.tar.gz test2018.tar.gz test2018.json.tar.gz" 29 | for FILE in $FILES ; do 30 | echo "Downloading $FILE..." 31 | wget https://ml-inat-competition-datasets.s3.amazonaws.com/2018/$FILE 32 | echo "Extracting $FILE..."
33 | tar -xzf $FILE 34 | done 35 | echo "Downloading unobfuscated category names" 36 | wget http://www.vision.caltech.edu/~gvanhorn/datasets/inaturalist/fgvc5_competition/categories.json.tar.gz 37 | tar -xzf categories.json.tar.gz 38 | 39 | # Auxiliary files needed to generate the task list in the paper: 40 | cp $TASK2VEC_REPO/support_files/inat2018/*.json ./ 41 | -------------------------------------------------------------------------------- /small_datasets_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from task2vec import Task2Vec\n", 10 | "from models import get_model\n", 11 | "import datasets\n", 12 | "import task_similarity" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "dataset_names = ('stl10', 'mnist', 'cifar10', 'cifar100', 'letters', 'kmnist')\n", 22 | "# Change `root` with the directory you want to use to download the datasets\n", 23 | "dataset_list = [datasets.__dict__[name](root='./data')[0] for name in dataset_names] " 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "embeddings = []\n", 33 | "for name, dataset in zip(dataset_names, dataset_list):\n", 34 | " print(f\"Embedding {name}\")\n", 35 | " probe_network = get_model('resnet34', pretrained=True, num_classes=int(max(dataset.targets)+1)).cuda()\n", 36 | " embeddings.append( Task2Vec(probe_network, max_samples=1000, skip_layers=6).embed(dataset) )" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAsgAAALICAYAAABiqwZ2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3df5jlZ13f/9c7u8JKCJDAF9QkmqCB+AtBR8DfgAQRvUgVgRgRVHSpVoGvP2pAhS3I1yBCQZsiKwICV0uFUhpLQkwRQsuvZkAkhUqNwZBEEJAAAQyQ5P3945yVO5NN5sxmd87MZx6P65pr5nw+5/7sPWc3k+fcc58z1d0BAABmjlr2BAAAYCsRyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAw2L2RO5921CO31K/du/CGV9ey5wAAwLRsKJBTFpwBAJi2DQVy7dp1pOYBAABbwqQDuaoemuQFSXYleXF3n32Q+zwqyb4kneSvuvvMTZ0kcEhWVlbOSrJn2fMAtqVrV1dXb9IEcMDGtljs2j5bLKpqV5JzkpyW5MokF1fVud39/uE+pyR5SpLv6u6rq+quy5ktcAj2rK6u7lv2JIDtZ2VlZd+y58DWNuUV5PsmubS7L0uSqnpVktOTvH+4z88lOae7r06S7v7ops8SAIAtZYMryFsrkKtqb5K9w6H93b1//vHxSa4Yzl2Z5H5rLnGP+XXemtk2jH3d/YYjNF0AALaBbb2CPI/h/eve8ebtTnJKkgckOSHJW6rqm7v7k4dhegAAbEOT3YOc5KokJw63T5gfG12Z5J3d/cUkH6yq/5tZMF+8OVMEAGCr2dZbLNZxcZJTqurkzML4jCRrX6HidUl+PMlLq+oumW25uGxTZwkAwJaysUA+avusIHf3dVX1i0kuyGx/8Uu6+31V9Ywkq9197vzcQ6rq/UmuT/Jr3f2Py5s1AADLNuUtFunu85Kct+bY04aPO8kvz98AAGCjK8jbaosFAABs2KRXkAEAYKMmuwcZAAAOhRVkAAAYbCiQe1cdqXkAAMCWYIsFAAAMNraCLJABAJi4De5BtsUCAIBps4IMAAADK8gAADDY4AqyQAYAYNq8zBsAAAysIAMAwGCDr4MskAEAmDZbLAAAYGCLBQAADAQyAAAMNhjIR2oaAACwNdiDDAAAAyvIAAAwsAcZAAAGVpABAGBgBRkAAAYbfJLekZoGAABsDbZYAADAwBYLAAAYWEEGAICBQAYAgIFABgCAwYYCObYgAwAwcZ6kBwAAA1ssAABgsKHk7aO21tt6quqhVfWBqrq0qs66hfs9oqq6qlY28ngAADA9k11BrqpdSc5JclqSK5NcXFXndvf719zvmCRPSvLOzZ8lAABbzZRXkO+b5NLuvqy7v5DkVUlOP8j9npnk2Umu3chjAQDANG1sTfiorfVWVXuranV42zvM9vgkVwy3r5wf+2dV9a1JTuzu12/ocQAAYLK29RaL7t6fZP+hjK2qo5I8L8lPHc45AQCwvW0skLfXq7xdleTE4fYJ82MHHJPkm5K8uaqS5CuSnFtVD+/u1U2bJQAAW8q2XkFex8VJTqmqkzML4zOSnHngZHd/KsldDtyuqjcn+VVxDACws002kLv7uqr6xSQXJNmV5CXd/b6qekaS1e4+d7kzBABgK9rYr5reRoGcJN19XpLz1hx72s3c9wGbMScAALa2ya4gAwDAoRDIAAAwEMgAADCY9B5kAADYKCvIAAAwmPIvCgEAgA2zxQIAAAa2WAAAwEAgAwDAYIOB3EdqHgAAsCXYgwwAAANbLAAAYLDBFWRbLAAAmDavgwwAAANbLAAAYGCLBQAADLyKBQAADKwgAwDAwC8KAQCAwYYCuWyxAABg4myxAACAwcYCuQQyAADTtsEtFgIZALablZWVs5LsWfY8tpCTVlZW9i17ElvItaurq2cvexJbiUAGgOnbs7q6um/Zk2Br8s3CTXmSHgAADDYYyDccqXkAAMCWsKFAPsoWCwAAJs4eZAAAGGxsBdnLvAEAMHEbW0EWyAAATNyGAnmXLRYAAEzcBp+k51UsAACYtkmvIFfVQ5O8IMmuJC/u7rPXnP/lJD+b5LokH0vyM919+aZPFACALWNjgVzbZwW5qnYlOSfJaUmuTHJxVZ3b3e8f7vaXSVa6+3NV9fNJfjfJozd/tgAAbBVTXkG+b5JLu/uyJKmqVyU5Pck/B3J3v2m4/zuSPGZTZwgAwJazwUDeWivIVbU3yd7h0P7u3j//+PgkVwznrkxyv1u43OOTnH94ZwgAwHazrQN5HsP7173jOqrqMUlWknzfrZ4UAADb2mT3ICe5KsmJw+0T5sdupKoenOQ3knxfd39+k+YGAMAWtaFA3r3FVpDXcXGSU6rq5MzC+IwkZ453qKr7JHlRkod290c3f4oAAGw1kw3k7r6uqn4xyQWZvczbS7r7fVX1jCSr3X1ukuckuX2SV1dVknyoux++tEkDALB0Gwvk7bXFIt19XpLz1hx72vDxgzd9UgAAbGmTXUEGAIBDMekVZAAA2KgNriBff6TmAQAAW4IVZAAAGFhBBgCAwYYC+cusIMNSrKysnJVkz7LnscWctLKysm/Zk9hCrl1dXT172ZMAmAIryLA97FldXd237EmwdflmAeDw2eAKskAGAGDaNhbIXgcZAICJ2+CrWFhBBgBg2myxAACAgUAGAIDBBvcgX3ek5gEAAFuCFWQAABgIZHacbfpLN7brL8XwyysA2HYEMjuRX7qxSbZp1AOww20okG9T9iADADBtVpABAGAgkAEAYLDBQLbFAgCAadvgHmQryAAATJsVZADYAC8VuWm8TCRLs7FAjhVkAHY8LxW5CbZh0DMhtlgAAMDAFgsAABhsMJBvOFLzAACALWFjWyzsQQYAYOKsIAMAwGBDgbwrfaTmAQAAW8IGV5AFMgAA07bBQD5S0wAAgK1hg78oRCEDADBtG1xBPupIzeOIqKqHJnlBkl1JXtzdZ685f9skL0/ybUn+Mcmju/vvNnueAABsHRtcQd4+gVxVu5Kck+S0JFcmubiqzu3u9w93e3ySq7v766rqjCTPTvLozZ8tAABbxQZXkHcdqXkcCfdNcml3X5YkVfWqJKcnGQP59CT75h+/Jsm/q6rqbs9GBADYoaYcyMcnuWK4fWWS+93cfbr7uqr6VJI7J/n4pswQAIAtZ0OBvPsrLt1Sz9Krqr1J9g6H9nf3/mXNBwCA7W9DgbzVzGP45oL4qiQnDrdPmB872H2urKrdSe6Y2ZP1AADYobbPs+427uIkp1TVyVV1myRnJDl3zX3OTfK4+cc/luQv7D8GANjZaso9WFUPS/L8zF7m7SXd/ayqekaS1e4+t6r2JHlFkvsk+USSMw48qY/F3PCRe0z3H9AW8tsfP3XZU9gxPvDZuy17CjvGX7/4G5Y9hR1h97W+TG+W4y68dNlT2FHO/8i/P2Jbf7f1Fov1dPd5Sc5bc+xpw8fXJnnkZs8LAICta8pbLAAAYMMEMgAADA
QyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAJNTVa9Y5NjB7D7802GrWFlZOSvJniN1/dXV1X1H6toAALekqvYm2Tsc2t/d+4fb37jm/ruSfNsi1xbI07ZHxAIAUzSP4f1rj1fVU5I8NcmXV9WnDxxO8oWD3f9gbLEAAGAyuvt3uvuYJM/p7jvM347p7jt391MWuYZABgBgiv5bVR2dJFX1mKp6XlV9zSIDBTIAAFP0wiSfq6pvSfIrSf42ycsXGSiQAQCYouu6u5OcnuTfdfc5SY5ZZKAn6QEAMEXXzJ+w95gk31tVRyX5skUGWkEGAGCKHp3k80ke390fSXJCkucsMtAKMgAAkzOP4ucNtz+UBfcgC2QAACajqv5nd393VV2TpMdTSbq777DeNQQyAACT0d3fPX+/0BPyDkYgAwAwSfNfL323DM0732pxiwQyAACTU1W/lOTpSf4hyQ3zw53kXuuNFcgAAEzRk5Lcs7v/caMDvcwbAABTdEWSTx3KQCvIAABM0WVJ3lxVr8/s9ZCTJN39vJsfMiOQAQCYog/N374sC/4GvQMEMgAAU3RekqcmOSlfat5O8oz1BgpkAACm6JVJfjXJ/86XXsViIQIZAIAp+lh3/9mhDBTIAABM0dOr6sVJ3pgbP0nvtesNFMgAAEzRTyc5NbMn6I2/KEQgAwCwI317d9/zUAb6RSEAAEzR26rqGw5loBVkAACm6P5J3lNVH8xsD3Il6e6+13oDBTIAAFP00EMdKJABAJic7r78UMfagwwAAAOBDAAAA4EMAAADgQwAAAOBDAAAA4EMAAADgQwAAAOBDAAAA4EMAAADgQwAAAOBDAAAA4EMAAADgQwAAAOBDAAAA4EMAAADgQwAAAOBDAAAA4EMAACD3cueANvbb3/81GVPYUf4zbv89bKnsGM8/p/uvOwp7Bi7vtjLnsKOcJtrrl/2FHaM6z76sWVPgcNEIG9TKysrZyXZs87dTtqEqQAATIpA3r72rK6u7rulO6ysrNzieQAAbsoeZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAJqmqHllVx8w//s2qem1Vfet643Yf+amxRNeurKzsO1IXX11dPWLXBgC4JVW1N8ne4dD+7t6/5m6/1d2vrqrvTvLgJM9J8sIk97ulawvkCVtdXT172XMAADgS5jG8NojXun7+/ocyC+jXV9Vvr3dtWywAAJiqq6rqRUkeneS8qrptFuhfgQwAwFQ9KskFSX6guz+Z5Lgkv7beIFssAACYnKraleTd3X3qgWPd/eEkH15vrBVkAAAmp7uvT/KBqvrqjY61ggwAwFQdm+R9VfW/knz2wMHufvgtDRLIAABM1W8dyiCBDADAJHX3RVX1NUlO6e7/XlW3S7JrvXH2IAMAMElV9XNJXpPkRfNDxyd53XrjBDIAAFP1r5J8V5JPJ0l3/02Su643SCADADBVn+/uLxy4UVW7k/R6gwQyAABTdVFVPTXJl1fVaUleneTP1hskkAEAmKqzknwsySVJnpDkvO7+jfUGeRULAACm6pe6+wVJ/ujAgap60vzYzbKCDADAVD3uIMd+ar1BVpABAJiUqvrxJGcmObmqzh1OHZPkE+uNF8gAAEzN25J8OMldkjx3OH5NkveuN1ggAwAwKd19eZLLq+ot3X3ReK6qnp3k129pvD3IAABM1WkHOfaD6w2yggwAwKRU1c8n+YUkX1tV45aKY5K8db3xAhkAgKn5D0nOT/I7mb0W8gHXdPe6T9KzxQIAgEnp7k919991948nOTHJg+b7ko+qqpPXG3/IK8grKytnJdlzqOO51U5a9gQAALayqnp6kpUk90zy0iS3SfLKJN91S+NuzRaLPaurq/tuxXhuhZWVlX3LngMAwBb3I0nuk+TdSdLdf19Vx6w3yBYLAACm6gvd3Uk6Sarq6EUGCWQAAKbqT6vqRUnuVFU/l+S/J/mj9QZ5FQsAACapu3+vqk5L8unM9iE/rbsvXG+cQAYAYLLmQbxuFI8EMgAAk1JV12S+73jtqSTd3Xe4pfECGQCASenudV+p4pZ4kh4AAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADHYvewJsbx/47N2WPYUd4fH/dOdlT2HH+OMT37rsKewYK3vutewp7AhH3X7XsqewYxx9zDHLngKHiRVkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAACanqp60yLGDEcgAAEzR4w5y7KcWGbj78M6DTXTtysrKvmVOYHV1dal/PgCwc1XV3iR7h0P7u3t/Vf14kjOTnFxV5w7n75DkE4tcWyBvU6urq2cvew4AAMvS3fuT7D/Iqbcl+XCSuyR57nD8miTvXeTaAhkAgMno7suTXF5VD07yT919Q1XdI8mpSS5Z5Br2IAMAMEVvSbKnqo5P8udJfjLJyxYZKJABAJii6u7PJfnRJP++ux+Z5BsXGSiQAQCYoqqq70jyE0lePz+2a5GBAhkAgCl6cpKnJPkv3f2+qrp7kjctMtCT9AAAmJzuvijJRcPty5I8cZGxAhkAgMmoqud395Or6s+S9Nrz3f3w9a4hkAEAmJJXzN//3qFeQCADADAZ3f2u+fuL1rvvzRHIAABMTlV9V5J9Sb4ms+atJN3dd19vrEAGAGCK/jjJ/5vkXUmu38hAgQwAwBR9qrvPP5SBAhkAgCl6U1U9J8lrk3z+wMHufvd6AwUyAABTdL/5+5XhWCd50HoDBTIAAJPT3Q881LECGQCAyamqO
yV5bJKTMjRvd6/72/QEMgAAU3ReknckuSTJDRsZKJABAJiiPd39y4cy8KjDPRMAANgCXlFVP1dVX1lVxx14W2SgFWQAAKboC0mek+Q3Mnv1iszf+016AADsSL+S5Ou6++MbHWiLBQAAU3Rpks8dykAryAAATNFnk7ynqt6UG/8mPS/zBgDAjvTWJK9bc+yYRQbaYgEAwBSdmeRd3f0n3f0nmT1p7ycXGWgFGQCAKfqxJK+pqjOTfE9mv1XvIYsMFMgAAExOd19WVWdkts3iQ0ke0t3/tMhYgQwAwGRU1SX50useJ8lxSXYleWdVpbvvtd41BDIAAFPyw7f2AgIZAIDJ6O7Lb+01vIoFAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMdi97Amxvf/3ib1j2FHaEXV/sZU9hx1jZc69lT2HHWN33wmVPYUd41WeOXfYUdoynPOgRy54Ch4kVZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGOxe9gQAAOBwq6q7JTl+fvOq7v6HRccKZAAAJqOq7p3kD5PcMclV88MnVNUnk/xCd797vWsI5B1qZWXlrCR7bs01VldX9x2e2QAAbExV7U2ydzi0v7v3J3lZkid09zvX3P/+SV6a5FvWu7ZA3rn2CFwAYLuax/D+g5w6em0cz+//jqo6epFrC2QAAKbk/Kp6fZKXJ7lifuzEJI9N8oZFLiCQAQCYjO5+YlX9YJLTMzxJL8k53X3eItcQyAAATEp3n5/k/EMd73WQAQDYEarqYHuWb8IKMgAAk1FVx93cqSQPW+QaAhkAgCn5WJLLMwviA3p++66LXEAgAwAwJZcl+f7u/tDaE1V1xUHufxP2IAMAMCXPT3LszZz73UUuIJABAJiM7j4nySVV9Z0HOfcHi1xDIAMAMCndfUOScw51vEAGAGCK3lhVj6iqWv+uNyaQAQCYoickeXWSz1fVp6vqmqr69CIDvYoFAACT093HHOpYgQwAwCRV1bFJTkmy58Cx7n7LeuMEMgAAk1NVP5vkSUlOSPKeJPdP8vYkD1pvrD3IAABM0ZOSfHuSy7v7gUnuk+STiwwUyAAATNG13X1tklTVbbv7r5Pcc5GBtlgAADBFV1bVnZK8LsmFVXV1kssXGSiQAQCYjKo6ubs/2N0/Mj+0r6relOSOSd6wyDVssQAAYEpekyRV9cYDB7r7ou4+t7u/sMgFrCADADAlR1XVU5Pco6p+ee3J7n7euhc4ItMCAIDlOCPJ9ZktBB9zkLd1WUEGAGAyuvsDSZ5dVe/t7vMP5RoCGQCAyaiqx3T3K5N8Q1V9/drzi2yxEMgAAEzJ0fP3tz/IuV7kAgIZAIDJ6O4XzT+8e5Indfcnk6Sqjk3y3EWu4Ul6AABM0b0OxHGSdPfVmf266XUJZAAApuio+apxkqSqjsuCuydssQAAYIqem+TtVfXq+e1HJnnWIgMFMgAAk9PdL6+q1SQPmh/60e5+/yJjBTIAAJM0D+KFonhkDzIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMBDIAAAwEMgAADAQyAAAMdi97AlO1srJyVpI9y57HLTjpcFxk97V9OC7DOm5zzfXLnsKOcdTtdy17CjvGqz5z7LKnsCOccfurlz2FHeO3bnfdsqfAYSKQj5w9q6ur+5Y9iZuzsrKyb9lzAADYimyxAACAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgIFABgCAgUAGAICBQAYAgMHuZU8AAAAOp6q6Y5KHJjl+fuiqJBd09ycXGW8FGQCAyaiqxyZ5d5IHJLnd/O2BSd41P7cuK8g717UrKyv7bs0FVldXb9V4AIBDVVV7k+wdDu3v7v1JfiPJt61dLa6qY5O8M8nL17u2QN6hVldXz172HAAADtU8hvcf5FQl6YMcv2F+bl0CGQCAKXlWkndX1Z8nuWJ+7KuTnJbkmYtcwB5kAAAmo7v/JMlKkouSfH7+9uYkK939skWuYQUZAIBJ6e6rq+pNGV7ForuvXnS8QAYAYDKq6t5J/jDJHZNcmdm+4xOq6pNJfqG7373eNQQyAABT8rIkT+jud44Hq+r+SV6a5FvWu4A9yAAATMnRa+M4Sbr7HUmOXuQCVpABAJiS86vq9Zm93vGBV7E4Mcljk7xhkQsIZAAAJqO7n1hVP5jk9Nz4V02f093nLXINgQwAwKR09/lJzj/U8fYgAwCwI1TVwX7z3k1YQQYAYDKq6ribO5XkYYtcQyADADAlH0tyeWZBfEDPb991kQsIZAAApuSyJN/f3R9ae6KqrjjI/W/CHmQAAKbk+UmOvZlzv7vIBQQyAACT0d3nJLmkqr7zIOf+YJFrCGQAACalu29Ics6hjhfIAABM0Rur6hFVVevf9cYEMgAAU/SEJK9O8vmq+nRVXVNVn15koFexAABgcrr7mEMdK5ABAJikqjo2ySlJ9hw41t1vWW/crQnka1dWVvbdivFTd9KyJwAAsFNV1c8meVKSE5K8J8n9k7w9yYPWG3vIgby6unr2oY7dCXzzAACwVE9K8u1J3tHdD6yqU5P8f4sM9CQ9AACm6NruvjZJquq23f3XSe65yEB7kAEAmKIrq+pOSV6X5MKqujrJ5YsMFMgAAExGVZ3c3R/s7h+ZH9pXVW9Kcsckb1jkGrZYAAAwJa9Jkqp644ED3X1Rd5/b3V9Y5AJWkAEAmJKjquqpSe5RVb+89mR3P2/dCxyRaQEAwHKckeT6zBaCjznI27qsIAMAMBnd/YEkz66q93b3+YdyDYEMAMBkVNVjuvuVSb6hqr5+7flFtlgIZAAApuTo+fvbH+RcL3IBgQwAwGR094vmH949yZO6+5NJUlXHJnnuItfwJD0AAKboXgfiOEm6++ok91lkoEAGAGCKjpqvGidJquq4LLh7whYLAACm6LlJ3l5Vr57ffmSSZy0yUCADADA53f3yqlpN8qD5oR/t7vcvMlYgAwAwSfMgXiiKR/YgAwDAQCADAMBAIAMAwEAgAwDAQCAD
AMDAq1hwqxx34aXLnsKOcN1HP7bsKewYRx9zzLKnsGM85UGPWPYUdoSnJNl9u+uWPY0d4W8e8LJlT2GHeeoRu7IVZACYMHEMGyeQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGAgkAEAYCCQAQBgIJABAGCwe9kTmLBrV1ZW9i17EkfS6urqvmXPAQDYmapqb5K9w6H93b2/qu6Y5ClJ/kWSu2p58gQAAAoDSURBVCbpJB9N8l+TnN3dn1zv2gL5CFldXT172XMAAJiq7t6fZP9BTv1pkr9I8oDu/kiSVNVXJHnc/NxD1ru2LRYAAEzJSd397ANxnCTd/ZHufnaSr1nkAgIZAIApubyq/nVV3e3Agaq6W1X9epIrFrmAQAYAYEoeneTOSS6qqk9U1SeSvDnJcUketcgF7EEGAGAyuvvqJL8+fzskVpABANgRquqnF7mfQAYAYKf4N4vcyRYLAAAmo6ree3OnktztZs7diEAGAGBK7pbkB5JcveZ4JXnbIhcQyAAATMl/S3L77n7P2hNV9eZFLiCQAQCYjO5+/C2cO3ORawhkAAAmo6qOu6Xz3f2J9a4hkAEAmJJ3JenM9hyv1Unuvt4FBDIAAJPR3Sff2mt4HWQAACanqt64yLGDsYIMAMBkVNWeJEcnuUtVHZsvbbW4Q5LjF7mGQAYAYEqekOTJSb4qs/3Ildne42uS/MEiF7DFAgCAyejuF8z3IT8ryb3nH780yWVJ3r7INQQyAABT9GPd/emq+u4kD0ry4iQvXGSgQAYAYIqun7//oSR/1N2vT3KbRQYKZAAApuiqqnpRkkcnOa+qbpsF21cgAwAwRY9KckGSH+juTyY5LsmvLTLQq1gAADA53f25JK8dbn84yYcXGWsFGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABgIZAAAGAhkAAAYCGQAABtXdy54DbKqq2tvd+5c9j53AY715PNabx2O9OTzOm8djfVNWkNmJ9i57AjuIx3rzeKw3j8d6c3icN4/Heg2BDAAAA4EMAAADgcxOZJ/V5vFYbx6P9ebxWG8Oj/Pm8Viv4Ul6AAAwsIIMAAADgQwAAAOBzI5WVW9b5/xTN2suW1VVfWad83eqql8Ybp9UVWce+ZlNy/xx+9+H+ZorVfX76/yZO/7vqqr+ZVU9dv7xqVX1nqr6y6r62g1c4xer6tKq6qq6y3C8qur35+feW1XfeiQ+h61qWY9tVT2uqv5m/va4w/tZbU9V9eSqut1w++8OPJ5V9ZKq+ujar0FVdVxVXTh/HC+sqmM3e97LIpDZ0br7O9e5y44P5AXcKckvDLdPSrKh6Kqq3YdzQsx092p3P/EW7nJSNvh3NUXd/Yfd/fL5zX+R5DXdfZ/u/ttFxlfVriRvTfLgJJevOf2DSU6Zv+1N8sLDM+vtYRmPbVUdl+TpSe6X5L5Jnr6Twu4WPDnJ7W7m3MuSPPQgx89K8sbuPiXJG+e3dwT/U2Jbq6qTkrwhyTuSfGeSi5O8NMm/SXLXJD+R5GFJvjrJ3efvn9/dvz8f/5nuvn1VfWWS/5TkDpn9d/HzSX4oyZdX1XuSvK+7f2LzPrOtqap+Lcmjktw2yX/p7qcnOTvJ184fpwuTfE+Sr5/f/pMkvz+/zwPm487p7hdV1QOSPDPJ1UlOrar7JPnTJCck2ZXkmd39nzbx09syquruSf5zkv+Q5LuSHJ1ZBPxektsk+ckkn0/ysO7+RFW9Ock7kzwws29YHt/d/2P+GP9qd/9wVX1fkhfM/4hO8r2Z/b38899Vd//bTfoUl2q+ovmrmT0O703yt0k+k+T9mUXE9VX1/d39wKp6XZITk+xJ8oIDv21s/pOVF2UWbv+qu//n/PjaP+70JC/v2TPi3zH/ictXdveHj/TnuQxb4bHN7GvNhd39ifm4CzOLv/94xD7xLaaqjs6Nv56+OslXJXlTVX28ux843r+73zL//+lap2f2eCazr+dvTvLrR2TSW4xAZgq+Lskjk/xMZoF8ZpLvTvLwzFaA35Pk1Mzi4ZgkH6iqF3b3F4drnJnkgu5+1nzF4nbzwPjF7r73Jn4uW1ZVPSSzSLtvkkpyblV9b2YrCt904HEao2x+e2+ST3X3t1fVbZO8tar+fH7Zb52P/WBVPSLJ33f3D83H3XETP70to6rumeRVSX4qyX2SfNP8/Z4klyb59e6+T1X92ySPTfL8+dDd3X3fqnpYZqtnD15z6V/NLDbeWlW3T3JtZn93//x3tRNU1Tcm+c0k39ndH5+vNj4xSbr7vKr6wySf6e7fmw/5mfk3IV+e5OKq+s/d/Y+ZfdPyzu7+lXX+yOOTXDHcvnJ+bHKBvIUe25s7vpM8NDf9evrTSR7Y3R/fwHXuNnwz95Ekdzu809y6bLFgCj7Y3Zd09w1J3pfZj4M6ySWZ/Qg5SV7f3Z+ff2H4aG76H/nFSX66qvYl+ebuvmZzpr6tPGT+9pdJ3p3ZNx2nLDjusfNVyncmufMw7n919wfnH1+S5LSqenZVfU93f+qwzn57+H+S/NckP9HdfzU/9qbuvqa7P5bkU0n+bH58/PedJK+dv3/XmuMHvDXJ86rqiUnu1N3XHea5bxcPSvLqA5FwYJXxFjyxqv4qs59SnZgv/du9PrNVfr7EY7t1HPavp/P/r+6Y1wYWyEzB54ePbxhu35Av/ZRkvM/1WfPTk+5+S2Y/cr4qycsOPKmEG6kkv9Pd956/fV13//GC435pGHdydx9YQf7sgTt19//NbEX5kiS/XVVPO+yfwdb3qSQfyuwnIAcs8u97vN9N/n0nSXefneRnk3x5Zqv4px6mOU/W/KchD07yHd39LZl9c7hnfvra7r5+gctclVn8HXDC/NiOdoQf2x3/mB/Gr6f/MN+2kvn7jx6mKW55AhmSVNXXJPmH7v6jJC/O7AtLknyxqr5seTPbUi5I8jPzH8+nqo6vqrsmuSazrSsHrL19QZKfP/A4VtU95vvjbqSqvirJ57r7lUmeky/9HewkX0jyI5mtuB/WJ89V1dfOf9Ly7Mx+YnJqbvp3tRP8RZJHVtWdk39+QtfNuWOSq7v7c/NvKO5/CH/euZn9fVZV3T+z7UaT214xt1Ue2wuSPKSqjp0/Oe8h82M7xs18PT2U/97PTXLgVUAel9lPuHYEe5Bh5gFJfq2qvpjZE0oOrCDvT/Leqnr3Tn+SXnf/eVV9fZK3z58s85kkj+nuv62qt85fHuj8zPZ
9Xz//0enLMnti2ElJ3l2zgR/L7Nnsa31zkudU1Q1JvpjZEyV3nO7+bFX9cGZPeHzFYbz0k6vqgZmtPL8vs7+rGzL8Xe2EJ+l19/uq6llJLqqq6zNbufy7m7n7G5L8y6r6P0k+kNlWgIOab13510m+IrOvGed1988mOS+zJwpfmuRzme0DnaSt8tjO9zU/M7NvBJPkGQts95iag309/Y4kb6iqv1/7JL2q+o+Z/X/wLlV1ZZKnz39CeHaSP62qx2f2KiKP2sTPYan8qmkAABjYYgEAAAOBDAAAA4EMAAADgQwAAAOBDAAAA4EMAAADgQwAAIP/HxmSBciSJu94AAAAAElFTkSuQmCC\n", 47 | "text/plain": [ 48 | "
" 49 | ] 50 | }, 51 | "metadata": { 52 | "needs_background": "light" 53 | }, 54 | "output_type": "display_data" 55 | } 56 | ], 57 | "source": [ 58 | "task_similarity.plot_distance_matrix(embeddings, dataset_names)" 59 | ] 60 | } 61 | ], 62 | "metadata": { 63 | "kernelspec": { 64 | "display_name": "Python 3", 65 | "language": "python", 66 | "name": "python3" 67 | }, 68 | "language_info": { 69 | "codemirror_mode": { 70 | "name": "ipython", 71 | "version": 3 72 | }, 73 | "file_extension": ".py", 74 | "mimetype": "text/x-python", 75 | "name": "python", 76 | "nbconvert_exporter": "python", 77 | "pygments_lexer": "ipython3", 78 | "version": "3.7.3" 79 | }, 80 | "pycharm": { 81 | "stem_cell": { 82 | "cell_type": "raw", 83 | "metadata": { 84 | "collapsed": false 85 | }, 86 | "source": [] 87 | } 88 | } 89 | }, 90 | "nbformat": 4, 91 | "nbformat_minor": 4 92 | } 93 | -------------------------------------------------------------------------------- /static/distance_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/aws-cv-task2vec/c5795e55ba773f9845498091a90eee2fcba5da31/static/distance_matrix.png -------------------------------------------------------------------------------- /static/taxonomical distance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/aws-cv-task2vec/c5795e55ba773f9845498091a90eee2fcba5da31/static/taxonomical distance.png -------------------------------------------------------------------------------- /support_files/cub/final_tasks_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": 0, 3 | "1": 2, 4 | "2": 3, 5 | "3": 4, 6 | "4": 5, 7 | "5": 6, 8 | "6": 7, 9 | "7": 8, 10 | "8": 9, 11 | "9": 10, 12 | "10": 1, 13 | "11": 2, 14 | "12": 4, 15 | "13": 5, 16 | "14": 6, 17 | "15": 8, 18 | "16": 9, 19 | "17": 10, 20 | "18": 14, 21 | "19": 15, 22 | "20": 17, 23 | "21": 18, 24 | "22": 19, 25 | "23": 20, 26 | "24": 21 27 | } 28 | -------------------------------------------------------------------------------- /support_files/inat2018/final_tasks_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": 0, 3 | "1": 1, 4 | "2": 2, 5 | "3": 3, 6 | "4": 4, 7 | "5": 5, 8 | "6": 6, 9 | "7": 7, 10 | "8": 8, 11 | "9": 9, 12 | "10": 10, 13 | "11": 11, 14 | "12": 12, 15 | "13": 13, 16 | "14": 14, 17 | "15": 15, 18 | "16": 16, 19 | "17": 17, 20 | "18": 18, 21 | "19": 19, 22 | "20": 20, 23 | "21": 21, 24 | "22": 22, 25 | "23": 23, 26 | "24": 24, 27 | "25": 25, 28 | "26": 26, 29 | "27": 27, 30 | "28": 28, 31 | "29": 29, 32 | "30": 30, 33 | "31": 31, 34 | "32": 32, 35 | "33": 33, 36 | "34": 34, 37 | "35": 35, 38 | "36": 36, 39 | "37": 37, 40 | "38": 38, 41 | "39": 39, 42 | "40": 40, 43 | "41": 41, 44 | "42": 42, 45 | "43": 43, 46 | "44": 44, 47 | "45": 45, 48 | "46": 46, 49 | "47": 47, 50 | "48": 48, 51 | "49": 49, 52 | "50": 50, 53 | "51": 51, 54 | "52": 52, 55 | "53": 53, 56 | "54": 54, 57 | "55": 55, 58 | "56": 56, 59 | "57": 57, 60 | "58": 58, 61 | "59": 59, 62 | "60": 60, 63 | "61": 61, 64 | "62": 62, 65 | "63": 63, 66 | "64": 64, 67 | "65": 65, 68 | "66": 66, 69 | "67": 67, 70 | "68": 68, 71 | "69": 69, 72 | "70": 70, 73 | "71": 71, 74 | "72": 72, 75 | "73": 73, 76 | "74": 74, 77 | "75": 75, 78 | "76": 76, 79 | "77": 77, 80 | "78": 78, 81 | "79": 79, 82 | "80": 80, 83 | "81": 81, 84 | "82": 82, 85 | "83": 83, 86 | "84": 84, 87 | "85": 85, 88 
| "86": 86, 89 | "87": 87, 90 | "88": 88, 91 | "89": 89, 92 | "90": 90, 93 | "91": 91, 94 | "92": 92, 95 | "93": 93, 96 | "94": 94, 97 | "95": 95, 98 | "96": 96, 99 | "97": 97, 100 | "98": 98, 101 | "99": 99, 102 | "100": 100, 103 | "101": 102, 104 | "102": 103, 105 | "103": 104, 106 | "104": 105, 107 | "105": 106, 108 | "106": 107, 109 | "107": 108, 110 | "108": 109, 111 | "109": 110, 112 | "110": 111, 113 | "111": 112, 114 | "112": 113, 115 | "113": 114, 116 | "114": 115, 117 | "115": 116, 118 | "116": 117, 119 | "117": 118, 120 | "118": 119, 121 | "119": 121, 122 | "120": 122, 123 | "121": 123, 124 | "122": 124, 125 | "123": 125, 126 | "124": 126, 127 | "125": 127, 128 | "126": 128, 129 | "127": 129, 130 | "128": 130, 131 | "129": 131, 132 | "130": 132, 133 | "131": 133, 134 | "132": 134, 135 | "133": 135, 136 | "134": 136, 137 | "135": 137, 138 | "136": 138, 139 | "137": 139, 140 | "138": 140, 141 | "139": 141, 142 | "140": 143, 143 | "141": 144, 144 | "142": 145, 145 | "143": 146, 146 | "144": 147, 147 | "145": 148, 148 | "146": 149, 149 | "147": 150, 150 | "148": 151, 151 | "149": 152, 152 | "150": 154, 153 | "151": 155, 154 | "152": 156, 155 | "153": 157, 156 | "154": 158, 157 | "155": 159, 158 | "156": 160, 159 | "157": 161, 160 | "158": 163, 161 | "159": 164, 162 | "160": 167, 163 | "161": 168, 164 | "162": 169, 165 | "163": 170, 166 | "164": 171, 167 | "165": 172, 168 | "166": 173, 169 | "167": 174, 170 | "168": 175, 171 | "169": 176, 172 | "170": 177, 173 | "171": 178, 174 | "172": 179, 175 | "173": 180, 176 | "174": 181, 177 | "175": 182, 178 | "176": 183, 179 | "177": 184, 180 | "178": 185, 181 | "179": 186, 182 | "180": 187, 183 | "181": 188, 184 | "182": 189, 185 | "183": 190, 186 | "184": 192, 187 | "185": 193, 188 | "186": 194, 189 | "187": 195, 190 | "188": 196, 191 | "189": 197, 192 | "190": 198, 193 | "191": 199, 194 | "192": 200, 195 | "193": 201, 196 | "194": 202, 197 | "195": 203, 198 | "196": 204, 199 | "197": 205, 200 | "198": 206, 201 | "199": 207, 202 | "200": 209, 203 | "201": 210, 204 | "202": 211, 205 | "203": 214, 206 | "204": 215, 207 | "205": 216, 208 | "206": 220, 209 | "207": 242 210 | } -------------------------------------------------------------------------------- /task2vec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
13 | 14 | import itertools 15 | import math 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import numpy as np 22 | from tqdm.auto import tqdm 23 | import logging 24 | import variational 25 | from torch.utils.data import DataLoader, Dataset 26 | from torch.optim.optimizer import Optimizer 27 | from utils import AverageMeter, get_error, get_device, set_batchnorm_mode 28 | 29 | 30 | class Embedding: 31 | def __init__(self, hessian, scale, meta=None): 32 | self.hessian = np.array(hessian) 33 | self.scale = np.array(scale) 34 | self.meta = meta 35 | 36 | 37 | class ProbeNetwork(ABC, nn.Module): 38 | """Abstract class that all probe networks should inherit from. 39 | 40 | This is a standard torch.nn.Module but needs to expose a classifier property that returns the final classification 41 | module (e.g., the last fully connected layer). 42 | """ 43 | 44 | @property 45 | @abstractmethod 46 | def classifier(self): 47 | raise NotImplementedError("Override the classifier property to return the submodules of the network that" 48 | " should be interpreted as the classifier") 49 | 50 | @classifier.setter 51 | @abstractmethod 52 | def classifier(self, val): 53 | raise NotImplementedError("Override the classifier setter to set the submodules of the network that" 54 | " should be interpreted as the classifier") 55 | 56 | 57 | class Task2Vec: 58 | 59 | def __init__(self, model: ProbeNetwork, skip_layers=0, max_samples=None, classifier_opts=None, 60 | method='montecarlo', method_opts=None, loader_opts=None, bernoulli=False): 61 | if classifier_opts is None: 62 | classifier_opts = {} 63 | if method_opts is None: 64 | method_opts = {} 65 | if loader_opts is None: 66 | loader_opts = {} 67 | assert method in ('variational', 'montecarlo') 68 | assert skip_layers >= 0 69 | 70 | self.model = model 71 | # Fix batch norm running statistics (i.e., put batch_norm layers in eval mode) 72 | self.model.train(); set_batchnorm_mode(self.model, train=False) 73 | self.device = get_device(self.model) 74 | self.skip_layers = skip_layers 75 | self.max_samples = max_samples 76 | self.classifier_opts = classifier_opts 77 | self.method = method 78 | self.method_opts = method_opts 79 | self.loader_opts = loader_opts 80 | self.bernoulli = bernoulli 81 | self.loss_fn = nn.CrossEntropyLoss() if not self.bernoulli else nn.BCEWithLogitsLoss() 82 | self.loss_fn = self.loss_fn.to(self.device) 83 | 84 | def embed(self, dataset: Dataset): 85 | # Cache the last layer features (needed to train the classifier) and (if needed) the intermediate layer features 86 | # so that we can skip the initial layers when computing the embedding 87 | if self.skip_layers > 0: 88 | self._cache_features(dataset, indexes=(self.skip_layers, -1), loader_opts=self.loader_opts, 89 | max_samples=self.max_samples) 90 | else: 91 | self._cache_features(dataset, max_samples=self.max_samples) 92 | # Fits the last layer classifier using cached features 93 | self._fit_classifier(**self.classifier_opts) 94 | 95 | if self.skip_layers > 0: 96 | dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features, 97 | self.model.layers[-1].targets) 98 | self.compute_fisher(dataset) 99 | embedding = self.extract_embedding(self.model) 100 | return embedding 101 | 102 | def montecarlo_fisher(self, dataset: Dataset, epochs: int = 1): 103 | logging.info("Using montecarlo Fisher") 104 | if self.skip_layers > 0: 105 | dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features, 106 |
self.model.layers[-1].targets) 107 | data_loader = _get_loader(dataset, **self.loader_opts) 108 | device = get_device(self.model) 109 | logging.info("Computing Fisher...") 110 | 111 | for p in self.model.parameters(): 112 | p.grad2_acc = torch.zeros_like(p.data) 113 | p.grad_counter = 0 114 | for k in range(epochs): 115 | logging.info(f"\tepoch {k + 1}/{epochs}") 116 | for i, (data, target) in enumerate(tqdm(data_loader, leave=False, desc="Computing Fisher")): 117 | data = data.to(device) 118 | output = self.model(data, start_from=self.skip_layers) 119 | # The gradients used to compute the FIM need to be for y sampled from 120 | # the model distribution y ~ p_w(y|x), not for y from the dataset 121 | if self.bernoulli: 122 | target = torch.bernoulli(torch.sigmoid(output)).detach() 123 | else: 124 | target = torch.multinomial(F.softmax(output, dim=-1), 1).detach().view(-1) 125 | loss = self.loss_fn(output, target) 126 | self.model.zero_grad() 127 | loss.backward() 128 | for p in self.model.parameters(): 129 | if p.grad is not None: 130 | p.grad2_acc += p.grad.data ** 2 131 | p.grad_counter += 1 132 | for p in self.model.parameters(): 133 | if p.grad_counter == 0: 134 | del p.grad2_acc 135 | else: 136 | p.grad2_acc /= p.grad_counter 137 | logging.info("done") 138 | 139 | def _run_epoch(self, data_loader: DataLoader, model: ProbeNetwork, loss_fn, 140 | optimizer: Optimizer, epoch: int, train: bool = True, 141 | add_compression_loss: bool = False, skip_layers=0, beta=1.0e-7): 142 | metrics = AverageMeter() 143 | device = get_device(model) 144 | 145 | for i, (input, target) in enumerate(tqdm(data_loader, leave=False, desc="Computing Fisher")): 146 | input = input.to(device) 147 | target = target.to(device) 148 | output = model(input, start_from=skip_layers) 149 | 150 | loss = loss_fn(output, target) 151 | lz = beta * variational.get_compression_loss(model) if add_compression_loss else torch.zeros_like(loss) 152 | loss += lz 153 | 154 | error = get_error(output, target) 155 | 156 | metrics.update(n=input.size(0), loss=loss.item(), lz=lz.item(), error=error) 157 | if train: 158 | optimizer.zero_grad() 159 | loss.backward() 160 | optimizer.step() 161 | # Report the running averages for this epoch 162 | print( 163 | "{}: [{epoch}] ".format('Epoch' if train else '', epoch=epoch) + 164 | # Note: data/batch loading times are not tracked by this loop, so they are not reported 165 | "Loss {:.3f} Lz: {:.3f} ".format(metrics.avg["loss"], metrics.avg["lz"]) + 166 | "Error: {:.2f}".format(metrics.avg["error"]) 167 | ) 168 | return metrics.avg 169 | 170 | def variational_fisher(self, dataset: Dataset, epochs=1, beta=1e-7): 171 | logging.info("Training variational fisher...") 172 | parameters = [] 173 | for layer in self.model.layers[self.skip_layers:-1]: 174 | if isinstance(layer, nn.Module): # Skip lambda functions 175 | variational.make_variational(layer) 176 | parameters += variational.get_variational_vars(layer) 177 | bn_params = [] 178 | # Allows batchnorm parameters to change 179 | for m in self.model.modules(): 180 | if isinstance(m, nn.BatchNorm2d): 181 | bn_params += list(m.parameters()) 182 | # Avoids computing the gradients wrt the weights to save time and memory 183 | for p in self.model.parameters(): 184 | if p not in set(parameters) and p not in set(self.model.classifier.parameters()): 185 | p.old_requires_grad = p.requires_grad 186 | p.requires_grad = False 187 | 188 | optimizer = torch.optim.Adam([ 189 | {'params': parameters}, 190 | {'params': bn_params, 'lr': 5e-4}, 191 | {'params':
self.model.classifier.parameters(), 'lr': 5e-4}], 192 | lr=1e-2, betas=(.9, 0.999)) 193 | if self.skip_layers > 0: 194 | dataset = torch.utils.data.TensorDataset(self.model.layers[self.skip_layers].input_features, 195 | self.model.layers[-1].targets) 196 | train_loader = _get_loader(dataset, **self.loader_opts) 197 | 198 | for epoch in range(epochs): 199 | self._run_epoch(train_loader, self.model, self.loss_fn, optimizer, epoch, beta=beta, 200 | add_compression_loss=True, train=True) 201 | 202 | # Resets original value of requires_grad 203 | for p in self.model.parameters(): 204 | if hasattr(p, 'old_requires_grad'): 205 | p.requires_grad = p.old_requires_grad 206 | del p.old_requires_grad 207 | 208 | def compute_fisher(self, dataset: Dataset): 209 | """ 210 | Computes the Fisher Information of the weights of the model wrt the model output on the dataset and stores it. 211 | 212 | The Fisher Information Matrix is defined as: 213 | F = E_{x ~ dataset} E_{y ~ p_w(y|x)} [\nabla_w log p_w(y|x) \nabla_w log p_w(y|x)^t] 214 | where p_w(y|x) is the output probability vector of the network and w are the weights of the network. 215 | Notice that the label y is sampled from the model output distribution and not from the dataset. 216 | 217 | This code only approximates the diagonal of F. The result is stored in the model layers and can be extracted 218 | using the `extract_embedding` method. Different approximation methods of the Fisher information matrix are available, 219 | and can be selected in the __init__. 220 | 221 | :param dataset: dataset with the task to compute the Fisher on 222 | """ 223 | if self.method == 'variational': 224 | fisher_fn = self.variational_fisher 225 | elif self.method == 'montecarlo': 226 | fisher_fn = self.montecarlo_fisher 227 | else: 228 | raise ValueError(f"Invalid Fisher method {self.method}") 229 | fisher_fn(dataset, **self.method_opts) 230 | 231 | def _cache_features(self, dataset: Dataset, indexes=(-1,), max_samples=None, loader_opts: dict = None): 232 | logging.info("Caching features...") 233 | if loader_opts is None: 234 | loader_opts = {} 235 | data_loader = DataLoader(dataset, shuffle=False, batch_size=loader_opts.get('batch_size', 64), 236 | num_workers=loader_opts.get('num_workers', 6), drop_last=False) 237 | 238 | device = next(self.model.parameters()).device 239 | 240 | def _hook(layer, inputs): 241 | if not hasattr(layer, 'input_features'): 242 | layer.input_features = [] 243 | layer.input_features.append(inputs[0].data.cpu().clone()) 244 | 245 | hooks = [self.model.layers[index].register_forward_pre_hook(_hook) 246 | for index in indexes] 247 | if max_samples is not None: 248 | n_batches = min( 249 | math.floor(max_samples / data_loader.batch_size) - 1, len(data_loader)) 250 | else: 251 | n_batches = len(data_loader) 252 | targets = [] 253 | 254 | for i, (input, target) in tqdm(enumerate(itertools.islice(data_loader, 0, n_batches)), total=n_batches, 255 | leave=False, 256 | desc="Caching features"): 257 | targets.append(target.clone()) 258 | self.model(input.to(device)) 259 | for hook in hooks: 260 | hook.remove() 261 | for index in indexes: 262 | self.model.layers[index].input_features = torch.cat(self.model.layers[index].input_features) 263 | self.model.layers[-1].targets = torch.cat(targets) 264 | 265 | def _fit_classifier(self, optimizer='adam', learning_rate=0.0004, weight_decay=0.0001, 266 | epochs=10): 267 | """Fits the last layer of the network using the cached features.""" 268 | logging.info("Fitting final classifier...") 269 | if not
hasattr(self.model.classifier, 'input_features'): 270 | raise ValueError("You need to run `_cache_features` on the model before running `_fit_classifier`") 271 | targets = self.model.classifier.targets.to(self.device) 272 | features = self.model.classifier.input_features.to(self.device) 273 | 274 | dataset = torch.utils.data.TensorDataset(features, targets) 275 | data_loader = _get_loader(dataset, **self.loader_opts) 276 | 277 | if optimizer == 'adam': 278 | optimizer = torch.optim.Adam(self.model.fc.parameters(), lr=learning_rate, weight_decay=weight_decay) 279 | elif optimizer == 'sgd': 280 | optimizer = torch.optim.SGD(self.model.fc.parameters(), lr=learning_rate, weight_decay=weight_decay) 281 | else: 282 | raise ValueError(f'Unsupported optimizer {optimizer}') 283 | 284 | loss_fn = nn.CrossEntropyLoss() 285 | for epoch in tqdm(range(epochs), desc="Fitting classifier", leave=False): 286 | metrics = AverageMeter() 287 | for data, target in data_loader: 288 | optimizer.zero_grad() 289 | output = self.model.classifier(data) 290 | loss = loss_fn(output, target) 291 | error = get_error(output, target) 292 | loss.backward() 293 | optimizer.step() 294 | metrics.update(n=data.size(0), loss=loss.item(), error=error) 295 | logging.info(f"[epoch {epoch}]: " + "\t".join(f"{k}: {v}" for k, v in metrics.avg.items())) 296 | 297 | def extract_embedding(self, model: ProbeNetwork): 298 | """ 299 | Reads the values stored by `compute_fisher` and returns them in a common format that describes the diagonal of the 300 | Fisher Information Matrix for each layer. 301 | 302 | :param model: probe network on which `compute_fisher` was run 303 | :return: an Embedding containing the concatenated per-layer diagonal Fisher (`hessian`) and `scale` arrays 304 | """ 305 | hess, scale = [], [] 306 | for name, module in model.named_modules(): 307 | if module is model.classifier: 308 | continue 309 | # The variational Fisher approximation estimates the variance of noise that can be added to the weights 310 | # without increasing the error more than a threshold. The inverse of this is proportional to an 311 | # approximation of the hessian in the local minimum. 312 | if hasattr(module, 'logvar0') and hasattr(module, 'loglambda2'): 313 | logvar = module.logvar0.view(-1).detach().cpu().numpy() 314 | hess.append(np.exp(-logvar)) 315 | loglambda2 = module.loglambda2.detach().cpu().numpy() 316 | scale.append(np.exp(-loglambda2).repeat(logvar.size)) 317 | # The other Fisher approximation methods directly approximate the hessian at the minimum 318 | elif hasattr(module, 'weight') and hasattr(module.weight, 'grad2_acc'): 319 | grad2 = module.weight.grad2_acc.cpu().detach().numpy() 320 | filterwise_hess = grad2.reshape(grad2.shape[0], -1).mean(axis=1) 321 | hess.append(filterwise_hess) 322 | scale.append(np.ones_like(filterwise_hess)) 323 | return Embedding(hessian=np.concatenate(hess), scale=np.concatenate(scale), meta=None) 324 | 325 | 326 | def _get_loader(trainset, testset=None, batch_size=64, num_workers=6, num_samples=10000, drop_last=True): 327 | if getattr(trainset, 'is_multi_label', False): 328 | raise ValueError("Multi-label datasets not supported") 329 | # TODO: Find a way to standardize this 330 | if hasattr(trainset, 'labels'): 331 | labels = trainset.labels 332 | elif hasattr(trainset, 'targets'): 333 | labels = trainset.targets 334 | else: 335 | labels = list(trainset.tensors[1].cpu().numpy()) 336 | num_classes = int(getattr(trainset, 'num_classes', max(labels) + 1)) 337 | class_count = np.eye(num_classes)[labels].sum(axis=0) 338 | weights = 1.
/ class_count[labels] / num_classes 339 | weights /= weights.sum() 340 | 341 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=num_samples) 342 | # No need for multi-threaded loading if everything is already in memory, 343 | # and multi-threaded loading would raise an error if the TensorDataset is on CUDA 344 | num_workers = num_workers if not isinstance(trainset, torch.utils.data.TensorDataset) else 0 345 | trainloader = torch.utils.data.DataLoader(trainset, sampler=sampler, batch_size=batch_size, 346 | num_workers=num_workers, drop_last=drop_last) 347 | 348 | if testset is None: 349 | return trainloader 350 | else: 351 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, pin_memory=True, shuffle=False, 352 | num_workers=num_workers) 353 | return trainloader, testloader 354 | -------------------------------------------------------------------------------- /task_similarity.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"). You 6 | # may not use this file except in compliance with the License. A copy of 7 | # the License is located at 8 | # 9 | # http://aws.amazon.com/apache2.0/ 10 | # 11 | # or in the "license" file accompanying this file. This file is 12 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 13 | # ANY KIND, either express or implied. See the License for the specific 14 | # language governing permissions and limitations under the License. 15 | 16 | import itertools 17 | import scipy.spatial.distance as distance 18 | import numpy as np 19 | import copy 20 | import pickle 21 | 22 | _DISTANCES = {} 23 | 24 | 25 | # TODO: Remove methods that do not perform well 26 | 27 | def _register_distance(distance_fn): 28 | _DISTANCES[distance_fn.__name__] = distance_fn 29 | return distance_fn 30 | 31 | 32 | def is_excluded(k): 33 | exclude = ['fc', 'linear'] 34 | return any([e in k for e in exclude]) 35 | 36 | 37 | def load_embedding(filename): 38 | with open(filename, 'rb') as f: 39 | e = pickle.load(f) 40 | return e 41 | 42 | 43 | def get_trivial_embedding_from(e): 44 | trivial_embedding = copy.deepcopy(e) 45 | for l in trivial_embedding['layers']: 46 | a = np.array(l['filter_logvar']) 47 | a[:] = l['filter_lambda2'] 48 | l['filter_logvar'] = list(a) 49 | return trivial_embedding 50 | 51 | 52 | def binary_entropy(p): 53 | from scipy.special import xlogy 54 | return - (xlogy(p, p) + xlogy(1. - p, 1. - p)) 55 | 56 | 57 | def get_layerwise_variance(e, normalized=False): 58 | var = [np.exp(l['filter_logvar']) for l in e['layers']] 59 | if normalized: 60 | var = [v / np.linalg.norm(v) for v in var] 61 | return var 62 | 63 | 64 | def get_variance(e, normalized=False): 65 | var = 1. / np.array(e.hessian) 66 | if normalized: 67 | lambda2 = 1.
/ np.array(e.scale) 68 | var = var / lambda2 69 | return var 70 | 71 | 72 | def get_variances(*embeddings, normalized=False): 73 | return [get_variance(e, normalized=normalized) for e in embeddings] 74 | 75 | 76 | def get_hessian(e, normalized=False): 77 | hess = np.array(e.hessian) 78 | if normalized: 79 | scale = np.array(e.scale) 80 | hess = hess / scale 81 | return hess 82 | 83 | 84 | def get_hessians(*embeddings, normalized=False): 85 | return [get_hessian(e, normalized=normalized) for e in embeddings] 86 | 87 | 88 | def get_scaled_hessian(e0, e1): 89 | h0, h1 = get_hessians(e0, e1, normalized=False) 90 | return h0 / (h0 + h1 + 1e-8), h1 / (h0 + h1 + 1e-8) 91 | 92 | 93 | def get_full_kl(e0, e1): 94 | var0, var1 = get_variance(e0), get_variance(e1) 95 | kl0 = .5 * (var0 / var1 - 1 + np.log(var1) - np.log(var0)) 96 | kl1 = .5 * (var1 / var0 - 1 + np.log(var0) - np.log(var1)) 97 | return kl0, kl1 98 | 99 | 100 | def layerwise_kl(e0, e1): 101 | layers0, layers1 = get_layerwise_variance(e0), get_layerwise_variance(e1) 102 | kl0 = [] 103 | for var0, var1 in zip(layers0, layers1): 104 | kl0.append(np.sum(.5 * (var0 / var1 - 1 + np.log(var1) - np.log(var0)))) 105 | return kl0 106 | 107 | 108 | def layerwise_cosine(e0, e1): 109 | layers0, layers1 = get_layerwise_variance(e0, normalized=True), get_layerwise_variance(e1, normalized=True) 110 | res = [] 111 | for var0, var1 in zip(layers0, layers1): 112 | res.append(distance.cosine(var0, var1)) 113 | return res 114 | 115 | 116 | @_register_distance 117 | def kl(e0, e1): 118 | var0, var1 = get_variance(e0), get_variance(e1) 119 | kl0 = .5 * (var0 / var1 - 1 + np.log(var1) - np.log(var0)) 120 | kl1 = .5 * (var1 / var0 - 1 + np.log(var0) - np.log(var1)) 121 | return np.maximum(kl0, kl1).sum() 122 | 123 | 124 | @_register_distance 125 | def asymmetric_kl(e0, e1): 126 | var0, var1 = get_variance(e0), get_variance(e1) 127 | kl0 = .5 * (var0 / var1 - 1 + np.log(var1) - np.log(var0)) 128 | kl1 = .5 * (var1 / var0 - 1 + np.log(var0) - np.log(var1)) 129 | return kl0.sum() 130 | 131 | 132 | @_register_distance 133 | def jsd(e0, e1): 134 | var0, var1 = get_variance(e0), get_variance(e1) 135 | var = .5 * (var0 + var1) 136 | kl0 = .5 * (var0 / var - 1 + np.log(var) - np.log(var0)) 137 | kl1 = .5 * (var1 / var - 1 + np.log(var) - np.log(var1)) 138 | return (.5 * (kl0 + kl1)).mean() 139 | 140 | 141 | @_register_distance 142 | def cosine(e0, e1): 143 | h1, h2 = get_scaled_hessian(e0, e1) 144 | return distance.cosine(h1, h2) 145 | 146 | 147 | @_register_distance 148 | def normalized_cosine(e0, e1): 149 | h1, h2 = get_variances(e0, e1, normalized=True) 150 | return distance.cosine(h1, h2) 151 | 152 | 153 | @_register_distance 154 | def correlation(e0, e1): 155 | v1, v2 = get_variances(e0, e1, normalized=False) 156 | return distance.correlation(v1, v2) 157 | 158 | 159 | @_register_distance 160 | def entropy(e0, e1): 161 | h1, h2 = get_scaled_hessian(e0, e1) 162 | return np.log(2) - binary_entropy(h1).mean() 163 | 164 | 165 | def get_normalized_embeddings(embeddings, normalization=None): 166 | F = [1. 
/ get_variance(e, normalized=False) if e is not None else None for e in embeddings] 167 | zero_embedding = np.zeros_like([x for x in F if x is not None][0]) 168 | F = np.array([x if x is not None else zero_embedding for x in F]) 169 | # FIXME: compute variance using only valid embeddings 170 | if normalization is None: 171 | normalization = np.sqrt((F ** 2).mean(axis=0, keepdims=True)) 172 | F /= normalization 173 | return F, normalization 174 | 175 | 176 | def pdist(embeddings, distance='cosine'): 177 | distance_fn = _DISTANCES[distance] 178 | n = len(embeddings) 179 | distance_matrix = np.zeros([n, n]) 180 | if distance != 'asymmetric_kl': 181 | for (i, e1), (j, e2) in itertools.combinations(enumerate(embeddings), 2): 182 | distance_matrix[i, j] = distance_fn(e1, e2) 183 | distance_matrix[j, i] = distance_matrix[i, j] 184 | else: 185 | for (i, e1) in enumerate(embeddings): 186 | for (j, e2) in enumerate(embeddings): 187 | distance_matrix[i, j] = distance_fn(e1, e2) 188 | return distance_matrix 189 | 190 | 191 | def cdist(from_embeddings, to_embeddings, distance='cosine'): 192 | distance_fn = _DISTANCES[distance] 193 | distance_matrix = np.zeros([len(from_embeddings), len(to_embeddings)]) 194 | for (i, e1) in enumerate(from_embeddings): 195 | for (j, e2) in enumerate(to_embeddings): 196 | if e1 is None or e2 is None: 197 | continue 198 | distance_matrix[i, j] = distance_fn(e1, e2) 199 | return distance_matrix 200 | 201 | 202 | def plot_distance_matrix(embeddings, labels=None, distance='cosine'): 203 | import seaborn as sns 204 | from scipy.cluster.hierarchy import linkage 205 | from scipy.spatial.distance import squareform 206 | import pandas as pd 207 | import matplotlib.pyplot as plt 208 | distance_matrix = pdist(embeddings, distance=distance) 209 | cond_distance_matrix = squareform(distance_matrix, checks=False) 210 | linkage_matrix = linkage(cond_distance_matrix, method='complete', optimal_ordering=True) 211 | if labels is not None: 212 | distance_matrix = pd.DataFrame(distance_matrix, index=labels, columns=labels) 213 | sns.clustermap(distance_matrix, row_linkage=linkage_matrix, col_linkage=linkage_matrix, cmap='viridis_r') 214 | plt.show() 215 | 216 | 217 | 218 | 219 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
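# Example (a small sketch of the AverageMeter defined below, which keeps running
# averages of arbitrary named metrics, weighted by the batch size n):
#     meter = AverageMeter()
#     meter.update(n=batch_size, loss=loss.item(), error=error)
#     meter.avg['loss']  # running average of the loss so far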
13 | 14 | from collections import defaultdict 15 | import torch 16 | import numpy as np 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value""" 21 | 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = defaultdict(int) 27 | self.avg = defaultdict(float) 28 | self.sum = defaultdict(int) 29 | self.count = defaultdict(int) 30 | 31 | def update(self, n=1, **val): 32 | for k in val: 33 | self.val[k] = val[k] 34 | self.sum[k] += val[k] * n 35 | self.count[k] += n 36 | self.avg[k] = self.sum[k] / self.count[k] 37 | 38 | 39 | def set_batchnorm_mode(model, train=True): 40 | """Allows setting the batch_norm layers to train or eval mode, independently of the mode of the model.""" 41 | def _set_batchnorm_mode(module): 42 | if isinstance(module, torch.nn.BatchNorm1d) or isinstance(module, torch.nn.BatchNorm2d): 43 | if train: 44 | module.train() 45 | else: 46 | module.eval() 47 | 48 | model.apply(_set_batchnorm_mode) 49 | 50 | 51 | def get_error(output, target): 52 | pred = output.argmax(dim=1) 53 | correct = pred.eq(target).float().sum() 54 | return float((1. - correct / output.size(0)) * 100.) 55 | 56 | 57 | def adjust_learning_rate(optimizer, epoch, optimizer_cfg): 58 | lr = optimizer_cfg.lr * (0.1 ** np.less(optimizer_cfg.schedule, epoch).sum()) 59 | for param_group in optimizer.param_groups: 60 | param_group['lr'] = lr 61 | 62 | 63 | def get_device(model: torch.nn.Module): 64 | return next(model.parameters()).device 65 | -------------------------------------------------------------------------------- /variational.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from torch.nn.parameter import Parameter 18 | 19 | import types 20 | 21 | 22 | def get_variational_vars(model): 23 | """Returns all variables involved in optimizing the hessian estimation.""" 24 | result = [] 25 | if hasattr(model, 'logvar0'): 26 | result.append(model.logvar0) 27 | result.append(model.loglambda2) 28 | for l in model.children(): 29 | result += get_variational_vars(l) 30 | return result 31 | 32 | 33 | def get_compression_loss(model): 34 | """Get the model loss function for hessian estimation. 35 | 36 | Compute KL divergence assuming a normal posterior and a diagonal normal prior p(w) ~ N(0, lambda**2 * I) 37 | (where lambda is selected independently for each layer and shared by all filters in the same layer). 38 | Recall from the paper that the optimal posterior q(w|D) that minimizes the training loss plus the compression loss 39 | is approximately given by q(w|D) ~ N(w, F**-1), where F is the Fisher information matrix.
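    Concretely, writing sigma**2 = logvar.exp() for the per-filter posterior variances,
    lambda**2 = loglambda2.exp() for the per-layer prior variance, and k for the total number
    of weights, the code below evaluates the closed-form Gaussian KL divergence (up to the
    global 1/2 factor, which is absorbed into the beta coefficient multiplying this loss):

        Lz = sum(w**2 / lambda**2) + sum(sigma**2 / lambda**2) + sum(log(lambda**2)) - sum(log(sigma**2)) - k

    where each sum runs over all weights of all variational layers.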
40 | """ 41 | modules = [x for x in model.modules() if hasattr(x, 'logvar0')] 42 | k = sum([x.weight.numel() for x in modules]) 43 | 44 | w_norm2 = sum([x.weight.pow(2).sum() / x.loglambda2.exp() for x in modules]) 45 | logvar = sum([x.logvar.sum() for x in modules]) 46 | trace = sum([x.logvar.exp().sum() / x.loglambda2.exp() for x in modules]) 47 | lambda2_cost = sum([x.loglambda2 * x.weight.numel() for x in modules]) 48 | 49 | # Standard formula for KL divergence of two normal distributions 50 | # https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence 51 | Lz = kl_divergence = w_norm2 + trace + lambda2_cost - logvar - k 52 | return Lz 53 | 54 | 55 | def variational_forward(module, input): 56 | """Modified forward pass that adds noise to the output.""" 57 | 58 | # Recall that module.logvar0 is created by make_variational() 59 | # (specifically, by add_logvar()) 60 | module.logvar = module.logvar0.expand_as(module.weight) 61 | 62 | var = module.logvar.exp() 63 | 64 | if isinstance(module, torch.nn.modules.conv.Conv2d): 65 | output = F.conv2d(input, module.weight, module.bias, module.stride, 66 | module.padding, module.dilation, module.groups) 67 | # From Variational Dropout and the Local reparametrization trick 68 | # (Kingma et al., 2015) 69 | output_var = F.conv2d(input ** 2 + 1e-2, var, None, module.stride, 70 | module.padding, module.dilation, module.groups) 71 | elif isinstance(module, torch.nn.modules.linear.Linear): 72 | output = F.linear(input, module.weight, module.bias) 73 | output_var = F.linear(input ** 2 + 1e-2, var, None) 74 | else: 75 | raise NotImplementedError("Module {} not implemented.".format(type(module))) 76 | 77 | eps = torch.empty_like(output).normal_() 78 | # Local reparemetrization trick 79 | return output + torch.sqrt(output_var) * eps 80 | 81 | 82 | def _reset_logvar(module, variance_scaling=0.05): 83 | if hasattr(module, 'logvar0'): 84 | w = module.weight.data 85 | # Initial ballpark estimate for optimal variance is the variance 86 | # of the weights in the kernel 87 | var = w.view(w.size(0), -1).var(dim=1).view(-1, *([1] * (w.ndimension() - 1))) # .expand_as(w) 88 | # Further scale down the variance by some factor 89 | module.logvar0.data[:] = (var * variance_scaling + 1e-8).log() 90 | # Initial guess for lambda is the l2 norm of the weights 91 | module.loglambda2.data = (w.pow(2).mean() + 1e-8).log() 92 | 93 | 94 | def _add_logvar(module): 95 | """Adds a parameter (logvar0) to store the noise variance for the weights. 96 | 97 | Also adds a scalar parameter loglambda2 to store the scaling coefficient 98 | for the layer. 99 | 100 | The variance is assumed to be the same for all weights in the same filter. 101 | The common value is stored in logvar0, which is expanded to the same 102 | dimension as the weight matrix in logvar. 
103 | """ 104 | if not hasattr(module, 'weight'): 105 | return 106 | if module.weight.data.ndimension() < 2: 107 | return 108 | if not hasattr(module, 'logvar0'): 109 | w = module.weight.data 110 | # w is of shape NUM_OUT x NUM_IN x K_h X K_w 111 | var = w.view(w.size(0), -1).var(dim=1).view(-1, *([1] * (w.ndimension() - 1))) 112 | # var is of shape NUM_OUT x 1 x 1 x 1 113 | # (so that it can be expanded to the same size as w by torch.expand_as()) 114 | # The content does not matter since we will reset it later anyway 115 | module.logvar0 = Parameter(var.log()) 116 | # log(lambda**2) is a scalar shared by all weights in the layer 117 | module.loglambda2 = Parameter(w.pow(2).mean().log()) 118 | module.logvar = module.logvar0.expand_as(module.weight) 119 | _reset_logvar(module) 120 | 121 | 122 | def make_variational(model): 123 | """Replaces the forward pass of the model layers to add noise.""" 124 | model.apply(_add_logvar) 125 | for m in model.modules(): 126 | if hasattr(m, 'logvar0'): 127 | m.forward = types.MethodType(variational_forward, m) 128 | --------------------------------------------------------------------------------