├── .gitignore ├── LICENSE ├── NOTICE ├── README.md ├── embedding_prop.jpeg ├── embedding_propagation ├── __init__.py ├── batch_embedding_propagation.py └── embedding_propagation.py ├── exp_configs ├── __init__.py ├── finetune_exps.py ├── pretrain_exps.py └── ssl_exps.py ├── requirements.txt ├── setup.py ├── src ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cub.py │ ├── episodic_cub.py │ ├── episodic_dataset.py │ ├── episodic_miniimagenet.py │ ├── episodic_tiered_imagenet.py │ ├── miniimagenet.py │ └── tiered_imagenet.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── conv4.py │ │ ├── resnet12.py │ │ └── wrn.py │ ├── base_ssl │ │ ├── __init__.py │ │ ├── distances.py │ │ ├── oracle.py │ │ ├── predict_methods │ │ │ ├── __init__.py │ │ │ ├── adaptive.py │ │ │ ├── label_prop.py │ │ │ └── prototypical.py │ │ ├── selection_methods │ │ │ ├── __init__.py │ │ │ └── ssl.py │ │ └── utils.py │ ├── base_wrapper.py │ ├── finetuning.py │ ├── pretraining.py │ └── ssl_wrapper.py ├── modules │ ├── __init__.py │ └── distances.py ├── tools │ ├── __init__.py │ ├── meters.py │ └── plot_episode.py └── utils.py └── trainval.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | results/ 6 | user_configs.py 7 | user_config.py 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | *.vscode 109 | *.npy 110 | *.png 111 | *.jpg 112 | *.svg 113 | *.eps 114 | *.pdf 115 | 116 | data/ 117 | user_config.py 118 | 119 | *_eai.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 ServiceNow 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2021 ServiceNow, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | *ServiceNow completed its acquisition of Element AI on January 8, 2021. All references to Element AI in the materials that are part of this project should refer to ServiceNow.* 2 | 3 | 4 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) 5 | 6 | 7 |
# Embedding Propagation 8 | ## Smoother Manifold for Few-Shot Classification [[Paper]](https://arxiv.org/abs/2003.04151) (ECCV 2020)
9 | 10 | 11 | 12 | Embedding propagation can be used to regularize the intermediate features so that generalization performance is improved. 13 | 14 | ![](embedding_prop.jpeg) 15 | 16 | ## Usage 17 | 18 | Add an embedding propagation layer to your network. 19 | 20 | ``` 21 | pip install git+https://github.com/ElementAI/embedding-propagation 22 | ``` 23 | 24 | ```python 25 | import torch 26 | from embedding_propagation import EmbeddingPropagation 27 | 28 | ep = EmbeddingPropagation() 29 | features = torch.randn(32, 32) 30 | embeddings = ep(features) 31 | ``` 32 | 33 | ## Experiments 34 | 35 | Generate the results from the [Paper]. 36 | 37 | ### Install requirements 38 | 39 | `pip install -r requirements.txt` 40 | 41 | This command installs the [Haven library](https://github.com/haven-ai/haven-ai) which helps in managing the experiments. 42 | 43 | ### Download the Datasets 44 | 45 | * [mini-imagenet](https://github.com/renmengye/few-shot-ssl-public#miniimagenet) ([pre-processing](https://github.com/ElementAI/TADAM/tree/master/datasets)) 46 | * [tiered-imagenet](https://github.com/renmengye/few-shot-ssl-public#tieredimagenet) 47 | * [CUB](https://github.com/wyharveychen/CloserLookFewShot/tree/master/filelists/CUB) 48 | 49 | If you have the `pkl` version of miniimagenet, you can still use it by setting the dataset name to "episodic_miniimagenet_pkl", in each of the files in `exp_configs`. 50 | 51 | 52 | 53 | ### Reproduce the results in the paper 54 | 55 | #### 1. Pre-training 56 | 57 | ``` 58 | python3 trainval.py -e pretrain -sb ./logs/pretraining -d 59 | ``` 60 | where `` is the directory where the data is saved. 61 | 62 | #### 2. Fine-tuning 63 | 64 | In `exp_configs/finetune_exps.py`, set `"pretrained_weights_root": ./logs/pretraining/` 65 | 66 | ``` 67 | python3 trainval.py -e finetune -sb ./logs/finetuning -d 68 | ``` 69 | 70 | #### 3. SSL experirments with 100 unlabeled 71 | 72 | In `exp_configs/ssl_exps.py`, set `"pretrained_weights_root": ./logs/finetuning/` 73 | 74 | ``` 75 | python3 trainval.py -e ssl_large -sb ./logs/ssl/ -d 76 | ``` 77 | 78 | #### 4. 
SSL experirments with 20-100% unlabeled 79 | 80 | In `exp_configs/ssl_exps.py`, set `"pretrained_weights_root": ./logs/finetuning/` 81 | 82 | ``` 83 | python3 trainval.py -e ssl_small -sb ./logs/ssl/ -d 84 | ``` 85 | 86 | ### Results 87 | 88 | |dataset|model|1-shot|5-shot| 89 | |-------|-----|------|------| 90 | |episodic_cub|conv4|65.94 ± 0.93|78.80 ± 0.64| 91 | |episodic_cub|resnet12|81.32 ± 0.84|91.02 ± 0.44| 92 | |episodic_cub|wrn|87.48 ± 0.68|93.74 ± 0.35| 93 | |episodic_miniimagenet|conv4|57.41 ± 0.85|72.35 ± 0.62| 94 | |episodic_miniimagenet|resnet12|64.82 ± 0.89|80.59 ± 0.64| 95 | |episodic_miniimagenet|wrn|69.92 ± 0.81|83.64 ± 0.54| 96 | |episodic_tiered-imagenet|conv4|58.63 ± 0.92|72.80 ± 0.78| 97 | |episodic_tiered-imagenet|resnet12|75.90 ± 0.90|86.83 ± 0.58| 98 | |episodic_tiered-imagenet|wrn|78.46 ± 0.90|87.46 ± 0.62| 99 | 100 | Different from the paper, these results were obtained on a run with fixed hyperparameters during fine-tuning: lr=0.001, alpha=0.2 (now default), train_iters=600, classification_weight=0.1 101 | 102 | ### Pre-trained weights 103 | https://zenodo.org/record/5552602#.YV2b-UbMKvU 104 | 105 | ## Citation 106 | ``` 107 | @article{rodriguez2020embedding, 108 | title={Embedding Propagation: Smoother Manifold for Few-Shot Classification}, 109 | author={Pau Rodríguez and Issam Laradji and Alexandre Drouin and Alexandre Lacoste}, 110 | year={2020}, 111 | journal={arXiv preprint arXiv:2003.04151}, 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /embedding_prop.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/embedding_prop.jpeg -------------------------------------------------------------------------------- /embedding_propagation/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding_propagation import * 2 | from .batch_embedding_propagation import * 3 | -------------------------------------------------------------------------------- /embedding_propagation/batch_embedding_propagation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | class BatchEmbeddingPropagation(torch.nn.Module): 7 | def __init__(self, alpha=0.5, rbf_scale=1, norm_prop=False): 8 | super().__init__() 9 | self.alpha = alpha 10 | self.rbf_scale = rbf_scale 11 | self.norm_prop = norm_prop 12 | 13 | def forward(self, x, propagator=None): 14 | return batch_embedding_propagation(x, self.alpha, self.rbf_scale, self.norm_prop, propagator=propagator) 15 | 16 | 17 | class BatchLabelPropagation(torch.nn.Module): 18 | def __init__(self, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True, balanced=False): 19 | super().__init__() 20 | self.alpha = alpha 21 | self.rbf_scale = rbf_scale 22 | self.norm_prop = norm_prop 23 | self.apply_log = apply_log 24 | self.balanced = balanced 25 | 26 | def forward(self, x, labels, nclasses, propagator=None): 27 | """Applies label propagation given a set of embeddings and labels 28 | 29 | Arguments: 30 | x {Tensor} -- Input embeddings 31 | labels {Tensor} -- Input labels from 0 to nclasses + 1. The highest value corresponds to unlabeled samples. 
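(Note: in this batch variant x has shape (b, e, c), i.e. b episodes of e embeddings each, so labels must have shape (b, e); within each episode the value nclasses marks the unlabeled samples.)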
32 | nclasses {int} -- Total number of classes 33 | 34 | Keyword Arguments: 35 | propagator {Tensor} -- A pre-computed propagator (default: {None}) 36 | 37 | Returns: 38 | tuple(Tensor, Tensor) -- Logits and Propagator 39 | """ 40 | return batch_label_propagation(x, labels, nclasses, self.alpha, self.rbf_scale, 41 | self.norm_prop, self.apply_log, propagator=propagator, 42 | balanced=self.balanced) 43 | 44 | def batch_get_similarity_matrix(x, rbf_scale): 45 | b, e, c = x.size() 46 | sq_dist = ((x.view(b, e, 1, c) - x.view(b, 1, e, c))**2).sum(-1) / np.sqrt(c) 47 | mask = sq_dist != 0 48 | sq_dist = sq_dist / sq_dist[mask].std() 49 | weights = torch.exp(-sq_dist * rbf_scale) 50 | mask = torch.eye(weights.size(-1), dtype=torch.bool, device=weights.device) 51 | weights = weights * (~mask).float() 52 | return weights 53 | 54 | 55 | def batch_embedding_propagation(x, alpha, rbf_scale, norm_prop, propagator=None): 56 | if propagator is None: 57 | weights = batch_get_similarity_matrix(x, rbf_scale) 58 | propagator = batch_global_consistency( 59 | weights, alpha=alpha, norm_prop=norm_prop) 60 | return torch.bmm(propagator, x) 61 | 62 | 63 | def batch_global_consistency(weights, alpha=1, norm_prop=False): 64 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in TPN paper but without bug) 65 | Args: 66 | weights: Tensor of shape (n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and 67 | s the scale parameter. 68 | labels: Tensor of shape (n, n_classes) 69 | alpha: Scaler, acts as a smoothing factor 70 | Returns: 71 | Tensor of shape (n, n_classes) representing the logits of each classes 72 | """ 73 | n = weights.shape[-1] 74 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device) 75 | isqrt_diag = 1. 
/ torch.sqrt(1e-4 + torch.sum(weights, dim=-1)) 76 | # checknan(laplacian=isqrt_diag) 77 | S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, None, :] 78 | # checknan(normalizedlaplacian=S) 79 | 80 | propagator = identity[None] - alpha * S 81 | propagator = torch.inverse(propagator) 82 | # checknan(propagator=propagator) 83 | if norm_prop: 84 | propagator = F.normalize(propagator, p=1, dim=-1) 85 | return propagator 86 | 87 | 88 | def batch_label_propagation(x, labels, nclasses, alpha, rbf_scale, norm_prop, apply_log, propagator=None, balanced=False, epsilon=1e-6): 89 | labels = F.one_hot(labels, nclasses + 1) 90 | labels = labels[..., :nclasses].float() # the max label is unlabeled 91 | if balanced: 92 | labels = labels / labels.sum(-1, keepdim=True) 93 | if propagator is None: 94 | weights = batch_get_similarity_matrix(x, rbf_scale) 95 | propagator = batch_global_consistency( 96 | weights, alpha=alpha, norm_prop=norm_prop) 97 | y_pred = torch.bmm(propagator, labels) 98 | if apply_log: 99 | y_pred = torch.log(y_pred + epsilon) 100 | 101 | return y_pred 102 | -------------------------------------------------------------------------------- /embedding_propagation/embedding_propagation.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class EmbeddingPropagation(torch.nn.Module): 8 | def __init__(self, alpha=0.5, rbf_scale=1, norm_prop=False): 9 | super().__init__() 10 | self.alpha = alpha 11 | self.rbf_scale = rbf_scale 12 | self.norm_prop = norm_prop 13 | 14 | def forward(self, x, propagator=None): 15 | return embedding_propagation(x, self.alpha, self.rbf_scale, self.norm_prop, propagator=propagator) 16 | 17 | 18 | class LabelPropagation(torch.nn.Module): 19 | def __init__(self, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True, balanced=False): 20 | super().__init__() 21 | self.alpha = alpha 22 | self.rbf_scale = rbf_scale 23 | self.norm_prop = norm_prop 24 | self.apply_log = apply_log 25 | self.balanced = balanced 26 | 27 | def forward(self, x, labels, nclasses, propagator=None): 28 | """Applies label propagation given a set of embeddings and labels 29 | 30 | Arguments: 31 | x {Tensor} -- Input embeddings 32 | labels {Tensor} -- Input labels from 0 to nclasses + 1. The highest value corresponds to unlabeled samples. 
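(For example, in a 5-way episode labels take values in {0, ..., 5}; the value 5 is one-hot encoded and then dropped, so unlabeled rows start as all zeros.)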
33 | nclasses {int} -- Total number of classes 34 | 35 | Keyword Arguments: 36 | propagator {Tensor} -- A pre-computed propagator (default: {None}) 37 | 38 | Returns: 39 | tuple(Tensor, Tensor) -- Logits and Propagator 40 | """ 41 | return label_propagation(x, labels, nclasses, self.alpha, self.rbf_scale, 42 | self.norm_prop, self.apply_log, propagator=propagator, 43 | balanced=self.balanced) 44 | 45 | 46 | def get_similarity_matrix(x, rbf_scale): 47 | b, c = x.size() 48 | sq_dist = ((x.view(b, 1, c) - x.view(1, b, c))**2).sum(-1) / np.sqrt(c) 49 | mask = sq_dist != 0 50 | sq_dist = sq_dist / sq_dist[mask].std() 51 | weights = torch.exp(-sq_dist * rbf_scale) 52 | mask = torch.eye(weights.size(1), dtype=torch.bool, device=weights.device) 53 | weights = weights * (~mask).float() 54 | return weights 55 | 56 | 57 | def embedding_propagation(x, alpha, rbf_scale, norm_prop, propagator=None): 58 | if propagator is None: 59 | weights = get_similarity_matrix(x, rbf_scale) 60 | propagator = global_consistency( 61 | weights, alpha=alpha, norm_prop=norm_prop) 62 | return torch.mm(propagator, x) 63 | 64 | 65 | def label_propagation(x, labels, nclasses, alpha, rbf_scale, norm_prop, apply_log, propagator=None, balanced=False, epsilon=1e-6): 66 | labels = F.one_hot(labels, nclasses + 1) 67 | labels = labels[:, :nclasses].float() # the max label is unlabeled 68 | if balanced: 69 | labels = labels / labels.sum(0, keepdim=True) 70 | if propagator is None: 71 | weights = get_similarity_matrix(x, rbf_scale) 72 | propagator = global_consistency( 73 | weights, alpha=alpha, norm_prop=norm_prop) 74 | y_pred = torch.mm(propagator, labels) 75 | if apply_log: 76 | y_pred = torch.log(y_pred + epsilon) 77 | 78 | return y_pred 79 | 80 | 81 | def global_consistency(weights, alpha=1, norm_prop=False): 82 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in TPN paper but without bug) 83 | 84 | Args: 85 | weights: Tensor of shape (n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and 86 | s the scale parameter. 87 | labels: Tensor of shape (n, n_classes) 88 | alpha: Scaler, acts as a smoothing factor 89 | Returns: 90 | Tensor of shape (n, n_classes) representing the logits of each classes 91 | """ 92 | n = weights.shape[1] 93 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device) 94 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=-1)) 95 | # checknan(laplacian=isqrt_diag) 96 | S = weights * isqrt_diag[None, :] * isqrt_diag[:, None] 97 | # checknan(normalizedlaplacian=S) 98 | propagator = identity - alpha * S 99 | propagator = torch.inverse(propagator[None, ...])[0] 100 | # checknan(propagator=propagator) 101 | if norm_prop: 102 | propagator = F.normalize(propagator, p=1, dim=-1) 103 | return propagator 104 | -------------------------------------------------------------------------------- /exp_configs/__init__.py: -------------------------------------------------------------------------------- 1 | from . import pretrain_exps, ssl_exps 2 | from . import pretrain_exps 3 | from . 
import finetune_exps 4 | 5 | EXP_GROUPS = {} 6 | EXP_GROUPS = pretrain_exps.EXP_GROUPS 7 | EXP_GROUPS.update(ssl_exps.EXP_GROUPS) 8 | EXP_GROUPS.update(finetune_exps.EXP_GROUPS) 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /exp_configs/finetune_exps.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | conv4 = { 4 | "name": "finetuning", 5 | 'backbone':'conv4', 6 | "depth": 4, 7 | "width": 1, 8 | "transform_train": "basic", 9 | "transform_val": "basic", 10 | "transform_test": "basic" 11 | } 12 | 13 | wrn = { 14 | "name": "finetuning", 15 | "backbone": 'wrn', 16 | "depth": 28, 17 | "width": 10, 18 | "transform_train": "wrn_finetune_train", 19 | "transform_val": "wrn_val", 20 | "transform_test": "wrn_val" 21 | } 22 | 23 | resnet12 = { 24 | "name": "finetuning", 25 | "backbone": 'resnet12', 26 | "depth": 12, 27 | "width": 1, 28 | "transform_train": "basic", 29 | "transform_val": "basic", 30 | "transform_test": "basic" 31 | } 32 | 33 | miniimagenet = { 34 | "dataset": "miniimagenet", 35 | "dataset_train": "episodic_miniimagenet", 36 | "dataset_val": "episodic_miniimagenet", 37 | "dataset_test": "episodic_miniimagenet", 38 | "data_root": "mini-imagenet", 39 | "n_classes": 64 40 | } 41 | 42 | tiered_imagenet = { 43 | "dataset": "tiered-imagenet", 44 | "n_classes": 351, 45 | "dataset_train": "episodic_tiered-imagenet", 46 | "dataset_val": "episodic_tiered-imagenet", 47 | "dataset_test": "episodic_tiered-imagenet", 48 | "data_root": "tiered-imagenet", 49 | } 50 | 51 | cub = { 52 | "dataset": "cub", 53 | "n_classes": 100, 54 | "dataset_train": "episodic_cub", 55 | "dataset_val": "episodic_cub", 56 | "dataset_test": "episodic_cub", 57 | "data_root": "CUB_200_2011" 58 | } 59 | 60 | EXP_GROUPS = {"finetune": []} 61 | 62 | for dataset in [miniimagenet, tiered_imagenet, cub]: 63 | for backbone in [conv4, resnet12, wrn]: 64 | for lr in [0.01, 0.001]: 65 | for shot in [1, 5]: 66 | for train_iters in [100, 600]: 67 | for classification_weight in [0, 0.1, 0.5]: 68 | EXP_GROUPS['finetune'] += [{"model": backbone, 69 | 70 | # Hardware 71 | "ngpu": 2, 72 | "random_seed": 42, 73 | 74 | # Optimization 75 | "batch_size": 1, 76 | "target_loss": "val_accuracy", 77 | "lr": lr, 78 | "min_lr_decay": 0.0001, 79 | "weight_decay": 0.0005, 80 | "patience": 10, 81 | "max_epoch": 200, 82 | "train_iters": train_iters, 83 | "val_iters": 600, 84 | "test_iters": 1000, 85 | "tasks_per_batch": 1, 86 | "pretrained_weights_root": "./logs/pretraining", 87 | 88 | # Model 89 | "dropout": 0.1, 90 | "avgpool": True, 91 | 92 | # Data 93 | 'n_classes': dataset["n_classes"], 94 | "collate_fn": "identity", 95 | "transform_train": backbone["transform_train"], 96 | "transform_val": backbone["transform_val"], 97 | "transform_test": backbone["transform_test"], 98 | 99 | "dataset_train": dataset["dataset_train"], 100 | 'dataset_train_root': dataset["data_root"], 101 | "classes_train": 5, 102 | "support_size_train": shot, 103 | "query_size_train": 15, 104 | "unlabeled_size_train": 0, 105 | 106 | "dataset_val": dataset["dataset_val"], 107 | 'dataset_val_root': dataset["data_root"], 108 | "classes_val": 5, 109 | "support_size_val": shot, 110 | "query_size_val": 15, 111 | "unlabeled_size_val": 0, 112 | 113 | "dataset_test": dataset["dataset_test"], 114 | 'dataset_test_root': dataset["data_root"], 115 | "classes_test": 5, 116 | "support_size_test": shot, 117 | "query_size_test": 15, 118 | "unlabeled_size_test": 0, 119 | 120 | # Hparams 121 | 
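# (Illustrative notes, not part of the original config: "embedding_prop"
# appears to toggle the EmbeddingPropagation layer on the backbone
# features, and "distance_type": "labelprop" selects label propagation
# over the episode graph as the few-shot classifier.)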
"embedding_prop" : True, 122 | "few_shot_weight": 1, 123 | "classification_weight": classification_weight, 124 | "rotation_weight": 0, 125 | "active_size": 0, 126 | "distance_type": "labelprop", 127 | "rotation_labels": [0], 128 | }] 129 | -------------------------------------------------------------------------------- /exp_configs/pretrain_exps.py: -------------------------------------------------------------------------------- 1 | from haven import haven_utils as hu 2 | 3 | conv4 = { 4 | "name": "pretraining", 5 | 'backbone':'conv4', 6 | "depth": 4, 7 | "width": 1, 8 | "transform_train": "basic", 9 | "transform_val": "basic", 10 | "transform_test": "basic" 11 | } 12 | 13 | wrn = { 14 | "name": "pretraining", 15 | "backbone": 'wrn', 16 | "depth": 28, 17 | "width": 10, 18 | "transform_train": "wrn_pretrain_train", 19 | "transform_val": "wrn_val", 20 | "transform_test": "wrn_val" 21 | } 22 | 23 | resnet12 = { 24 | "name": "pretraining", 25 | "backbone": 'resnet12', 26 | "depth": 12, 27 | "width": 1, 28 | "transform_train": "basic", 29 | "transform_val": "basic", 30 | "transform_test": "basic" 31 | } 32 | 33 | miniimagenet = { 34 | "dataset": "miniimagenet", 35 | "dataset_train": "rotated_episodic_miniimagenet_pkl", 36 | "dataset_val": "episodic_miniimagenet_pkl", 37 | "dataset_test": "episodic_miniimagenet_pkl", 38 | "n_classes": 64, 39 | "data_root": "mini-imagenet" 40 | } 41 | 42 | tiered_imagenet = { 43 | "dataset": "tiered-imagenet", 44 | "n_classes": 351, 45 | "dataset_train": "rotated_tiered-imagenet", 46 | "dataset_val": "episodic_tiered-imagenet", 47 | "dataset_test": "episodic_tiered-imagenet", 48 | "data_root": "tiered-imagenet", 49 | } 50 | 51 | cub = { 52 | "dataset": "cub", 53 | "n_classes": 100, 54 | "dataset_train": "rotated_cub", 55 | "dataset_val": "episodic_cub", 56 | "dataset_test": "episodic_cub", 57 | "data_root": "CUB_200_2011" 58 | } 59 | 60 | EXP_GROUPS = {"pretrain": []} 61 | 62 | for dataset in [miniimagenet, tiered_imagenet, cub]: 63 | for backbone in [conv4, resnet12, wrn]: 64 | for lr in [0.2, 0.1]: 65 | EXP_GROUPS['pretrain'] += [{"model": backbone, 66 | 67 | # Hardware 68 | "ngpu": 4, 69 | "random_seed": 42, 70 | 71 | # Optimization 72 | "batch_size": 128, 73 | "target_loss": "val_accuracy", 74 | "lr": lr, 75 | "min_lr_decay": 0.0001, 76 | "weight_decay": 0.0005, 77 | "patience": 10, 78 | "max_epoch": 200, 79 | "train_iters": 600, 80 | "val_iters": 600, 81 | "test_iters": 600, 82 | "tasks_per_batch": 1, 83 | 84 | # Model 85 | "dropout": 0.1, 86 | "avgpool": True, 87 | 88 | # Data 89 | 'n_classes': dataset["n_classes"], 90 | "collate_fn": "default", 91 | "transform_train": backbone["transform_train"], 92 | "transform_val": backbone["transform_val"], 93 | "transform_test": backbone["transform_test"], 94 | 95 | "dataset_train": dataset["dataset_train"], 96 | "dataset_train_root": dataset["data_root"], 97 | "classes_train": 5, 98 | "support_size_train": 5, 99 | "query_size_train": 15, 100 | "unlabeled_size_train": 0, 101 | 102 | "dataset_val": dataset["dataset_val"], 103 | "dataset_val_root": dataset["data_root"], 104 | "classes_val": 5, 105 | "support_size_val": 5, 106 | "query_size_val": 15, 107 | "unlabeled_size_val": 0, 108 | 109 | "dataset_test": dataset["dataset_test"], 110 | "dataset_test_root": dataset["data_root"], 111 | "classes_test": 5, 112 | "support_size_test": 5, 113 | "query_size_test": 15, 114 | "unlabeled_size_test": 0, 115 | 116 | 117 | # Hparams 118 | "embedding_prop": True, 119 | "cross_entropy_weight": 1, 120 | "few_shot_weight": 0, 121 | 
"rotation_weight": 1, 122 | "active_size": 0, 123 | "distance_type": "labelprop", 124 | "kernel_bound": "", 125 | "rotation_labels": [0, 1, 2, 3] 126 | }] 127 | -------------------------------------------------------------------------------- /exp_configs/ssl_exps.py: -------------------------------------------------------------------------------- 1 | import os 2 | from haven import haven_utils as hu 3 | 4 | import os 5 | from haven import haven_utils as hu 6 | 7 | conv4 = { 8 | "name": "ssl", 9 | 'backbone':'conv4', 10 | "depth": 4, 11 | "width": 1, 12 | "transform_train": "basic", 13 | "transform_val": "basic", 14 | "transform_test": "basic" 15 | } 16 | 17 | wrn = { 18 | "name": "ssl", 19 | "backbone": 'wrn', 20 | "depth": 28, 21 | "width": 10, 22 | "transform_train": "wrn_finetune_train", 23 | "transform_val": "wrn_val", 24 | "transform_test": "wrn_val" 25 | } 26 | 27 | resnet12 = { 28 | "name": "ssl", 29 | "backbone": 'resnet12', 30 | "depth": 12, 31 | "width": 1, 32 | "transform_train": "basic", 33 | "transform_val": "basic", 34 | "transform_test": "basic" 35 | } 36 | 37 | miniimagenet = { 38 | "dataset": "miniimagenet", 39 | "dataset_train": "episodic_miniimagenet", 40 | "dataset_val": "episodic_miniimagenet", 41 | "dataset_test": "episodic_miniimagenet", 42 | "n_classes": 64, 43 | 'data_root':'mini-imagenet/' 44 | } 45 | 46 | tiered_imagenet = { 47 | "dataset": "tiered-imagenet", 48 | "n_classes": 351, 49 | "dataset_train": "episodic_tiered-imagenet", 50 | "dataset_val": "episodic_tiered-imagenet", 51 | "dataset_test": "episodic_tiered-imagenet", 52 | 'data_root':'tiered-imagenet' 53 | } 54 | 55 | cub = { 56 | "dataset": "cub", 57 | "n_classes": 100, 58 | "dataset_train": "episodic_cub", 59 | "dataset_val": "episodic_cub", 60 | "dataset_test": "episodic_cub", 61 | 'data_root':'CUB_200_2011' 62 | } 63 | 64 | EXP_GROUPS = {} 65 | EXP_GROUPS['ssl_large'] = [] 66 | # 12 exps 67 | for dataset in [miniimagenet, tiered_imagenet]: 68 | for backbone in [resnet12, conv4, wrn]: 69 | for embedding_prop in [True]: 70 | for shot in [1, 5]: 71 | EXP_GROUPS['ssl_large'] += [{ 72 | 'dataset_train_root': dataset["data_root"], 73 | 'dataset_val_root': dataset["data_root"], 74 | 'dataset_test_root': dataset["data_root"], 75 | "model": backbone, 76 | 77 | # Hardware 78 | "ngpu": 1, 79 | "random_seed": 42, 80 | 81 | # Optimization 82 | "batch_size": 1, 83 | "train_iters": 10, 84 | "test_iters": 600, 85 | "tasks_per_batch": 1, 86 | 87 | # Model 88 | "dropout": 0.1, 89 | "avgpool": True, 90 | 91 | # Data 92 | 'n_classes': dataset["n_classes"], 93 | "collate_fn": "identity", 94 | "transform_train": backbone["transform_train"], 95 | "transform_val": backbone["transform_val"], 96 | "transform_test": backbone["transform_test"], 97 | 98 | "dataset_train": dataset["dataset_train"], 99 | "classes_train": 5, 100 | "support_size_train": shot, 101 | "query_size_train": 15, 102 | "unlabeled_size_train": 0, 103 | 104 | "dataset_val": dataset["dataset_val"], 105 | "classes_val": 5, 106 | "support_size_val": shot, 107 | "query_size_val": 15, 108 | "unlabeled_size_val": 0, 109 | 110 | "dataset_test": dataset["dataset_test"], 111 | "classes_test": 5, 112 | "support_size_test": shot, 113 | "query_size_test": 15, 114 | "unlabeled_size_test": 100, 115 | "predict_method": "labelprop", 116 | "finetuned_weights_root": "./logs/finetuning", 117 | 118 | # Hparams 119 | "embedding_prop" : embedding_prop, 120 | }] 121 | 122 | # 24 exps 123 | EXP_GROUPS['ssl_small'] = [] 124 | for dataset in [tiered_imagenet, miniimagenet]: 125 | for 
backbone in [conv4, resnet12, wrn]: 126 | for embedding_prop in [True]: 127 | for shot, ust in zip([1,2,3], 128 | [4,3,2]): 129 | EXP_GROUPS['ssl_small'] += [{ 130 | 'dataset_train_root': dataset["data_root"], 131 | 'dataset_val_root': dataset["data_root"], 132 | 'dataset_test_root': dataset["data_root"], 133 | 134 | "model": backbone, 135 | 136 | # Hardware 137 | "ngpu": 1, 138 | "random_seed": 42, 139 | 140 | # Optimization 141 | "batch_size": 1, 142 | "train_iters": 10, 143 | "val_iters": 600, 144 | "test_iters": 600, 145 | "tasks_per_batch": 1, 146 | 147 | # Model 148 | "dropout": 0.1, 149 | "avgpool": True, 150 | 151 | # Data 152 | 'n_classes': dataset["n_classes"], 153 | "collate_fn": "identity", 154 | "transform_train": backbone["transform_train"], 155 | "transform_val": backbone["transform_val"], 156 | "transform_test": backbone["transform_test"], 157 | 158 | "dataset_train": dataset["dataset_train"], 159 | "classes_train": 5, 160 | "support_size_train": shot, 161 | "query_size_train": 15, 162 | "unlabeled_size_train": 0, 163 | 164 | "dataset_val": dataset["dataset_val"], 165 | "classes_val": 5, 166 | "support_size_val": shot, 167 | "query_size_val": 15, 168 | "unlabeled_size_val": 0, 169 | 170 | "dataset_test": dataset["dataset_test"], 171 | "classes_test": 5, 172 | "support_size_test": shot, 173 | "query_size_test": 15, 174 | "unlabeled_size_test": ust, 175 | "predict_method":"labelprop", 176 | "finetuned_weights_root": "./logs/finetuning", 177 | 178 | # Hparams 179 | "embedding_prop" : embedding_prop, 180 | }] 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | haven-ai>=0.0 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='embedding_propagation', 4 | version='0.6.0', 5 | description='Manifold Regularization', 6 | url='https://github.com/ElementAI/embedding-propagation', 7 | maintainer='Pau Rodríguez', 8 | maintainer_email='issam.laradji@elementai.com', 9 | license='MIT', 10 | packages=['embedding_propagation'], 11 | zip_safe=False, 12 | install_requires=[ 13 | 'tqdm>=0.0', 14 | 'matplotlib>=0.0', 15 | 'numpy>=0.0', 16 | 'pandas>=0.0', 17 | 'Pillow>=0.0', 18 | 'scikit-image>=0.0', 19 | 'scikit-learn>=0.0', 20 | 'sklearn>=0.0', 21 | 'torch>=0.0', 22 | 'torchvision>=0.0' 23 | ]), -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/__init__.py -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # from . 
import trancos, fish_reg 3 | from torchvision import transforms 4 | import torchvision 5 | import cv2, os 6 | from src import utils as ut 7 | import pandas as pd 8 | import numpy as np 9 | 10 | import pandas as pd 11 | import numpy as np 12 | from torchvision.datasets import CIFAR10, CIFAR100 13 | from .episodic_dataset import FewShotSampler 14 | from .episodic_miniimagenet import EpisodicMiniImagenet, EpisodicMiniImagenetPkl 15 | from .miniimagenet import NonEpisodicMiniImagenet, RotatedNonEpisodicMiniImagenet, RotatedNonEpisodicMiniImagenetPkl 16 | from .episodic_tiered_imagenet import EpisodicTieredImagenet 17 | from .tiered_imagenet import RotatedNonEpisodicTieredImagenet 18 | from .cub import RotatedNonEpisodicCUB, NonEpisodicCUB 19 | from .episodic_cub import EpisodicCUB 20 | from .episodic_tiered_imagenet import EpisodicTieredImagenet 21 | from .tiered_imagenet import RotatedNonEpisodicTieredImagenet, NonEpisodicTieredImagenet 22 | 23 | 24 | def get_dataset(dataset_name, 25 | data_root, 26 | split, 27 | transform, 28 | classes, 29 | support_size, 30 | query_size, 31 | unlabeled_size, 32 | n_iters): 33 | 34 | transform_func = get_transformer(transform, split) 35 | if dataset_name == "rotated_miniimagenet": 36 | dataset = RotatedNonEpisodicMiniImagenet(data_root, 37 | split, 38 | transform_func) 39 | 40 | elif dataset_name == "miniimagenet": 41 | dataset = NonEpisodicMiniImagenet(data_root, 42 | split, 43 | transform_func) 44 | elif dataset_name == "episodic_miniimagenet": 45 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size) 46 | dataset = EpisodicMiniImagenet(data_root=data_root, 47 | split=split, 48 | sampler=few_shot_sampler, 49 | size=n_iters, 50 | transforms=transform_func) 51 | elif dataset_name == "rotated_episodic_miniimagenet_pkl": 52 | dataset = RotatedNonEpisodicMiniImagenetPkl(data_root=data_root, 53 | split=split, 54 | transforms=transform_func) 55 | elif dataset_name == "episodic_miniimagenet_pkl": 56 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size) 57 | dataset = EpisodicMiniImagenetPkl(data_root=data_root, 58 | split=split, 59 | sampler=few_shot_sampler, 60 | size=n_iters, 61 | transforms=transform_func) 62 | elif dataset_name == "cub": 63 | dataset = NonEpisodicCUB(data_root, split, transform_func) 64 | elif dataset_name == "rotated_cub": 65 | dataset = RotatedNonEpisodicCUB(data_root, split, transform_func) 66 | elif dataset_name == "episodic_cub": 67 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size) 68 | dataset = EpisodicCUB(data_root=data_root, 69 | split=split, 70 | sampler=few_shot_sampler, 71 | size=n_iters, 72 | transforms=transform_func) 73 | elif dataset_name == "tiered-imagenet": 74 | dataset = NonEpisodicTieredImagenet(data_root, split, transform_func) 75 | elif dataset_name == "rotated_tiered-imagenet": 76 | dataset = RotatedNonEpisodicTieredImagenet(data_root, split, transform_func) 77 | elif dataset_name == "episodic_tiered-imagenet": 78 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size) 79 | dataset = EpisodicTieredImagenet(data_root, 80 | split=split, 81 | sampler=few_shot_sampler, 82 | size=n_iters, 83 | transforms=transform_func) 84 | return dataset 85 | 86 | 87 | # =================================================== 88 | # helpers 89 | def get_transformer(transform, split): 90 | if transform == "data_augmentation": 91 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(), 92 
| torchvision.transforms.Resize((84,84)), 93 | torchvision.transforms.ToTensor()]) 94 | 95 | return transform 96 | 97 | if "{}_{}".format(transform, split) == "cifar_train": 98 | transform = torchvision.transforms.transforms.Compose([ 99 | torchvision.transforms.RandomCrop(32, padding=4), 100 | torchvision.transforms.RandomHorizontalFlip(), 101 | torchvision.transforms.ToTensor(), 102 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 103 | ]) 104 | return transform 105 | 106 | if "{}_{}".format(transform, split) == "cifar_test" or "{}_{}".format(transform, split) == "cifar_val": 107 | transform = torchvision.transforms.Compose([ 108 | torchvision.transforms.ToTensor(), 109 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 110 | ]) 111 | return transform 112 | 113 | if transform == "basic": 114 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(), 115 | torchvision.transforms.Resize((84,84)), 116 | torchvision.transforms.ToTensor()]) 117 | 118 | return transform 119 | 120 | if transform == "wrn_pretrain_train": 121 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(), 122 | torchvision.transforms.RandomResizedCrop((80, 80), scale=(0.08, 1)), 123 | torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 124 | torchvision.transforms.ToTensor() 125 | ]) 126 | return transform 127 | elif transform == "wrn_finetune_train": 128 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(), 129 | torchvision.transforms.Resize((92, 92)), 130 | torchvision.transforms.CenterCrop(80), 131 | torchvision.transforms.RandomHorizontalFlip(), 132 | torchvision.transforms.ToTensor() 133 | ]) 134 | return transform 135 | 136 | elif "wrn" in transform: 137 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(), 138 | torchvision.transforms.Resize((92, 92)), 139 | torchvision.transforms.CenterCrop(80), 140 | torchvision.transforms.ToTensor()]) 141 | return transform 142 | 143 | raise NotImplementedError 144 | 145 | -------------------------------------------------------------------------------- /src/datasets/cub.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torchvision 3 | import torch 4 | from torch.utils.data import Dataset 5 | import json 6 | import os 7 | import numpy as np 8 | from PIL import Image 9 | 10 | class NonEpisodicCUB(Dataset): 11 | name="CUB" 12 | task="cls" 13 | split_paths = {"train":"train", "test":"test", "valid": "val"} 14 | c = 3 15 | h = 84 16 | w = 84 17 | 18 | def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs): 19 | """ Constructor 20 | 21 | Args: 22 | split: data split 23 | few_shot_sampler: FewShotSampler instance 24 | task: dataset task (if more than one) 25 | size: number of tasks to generate (int) 26 | disjoint: whether to create disjoint splits. 
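(Note: of the arguments listed above, this non-episodic class only reads data_root, split and transforms; rotation_labels is stored for use by the rotated subclass.)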
27 | """ 28 | self.data_root = data_root 29 | self.split = {"train":"base", "val":"val", "valid":"val", "test":"novel"}[split] 30 | with open(os.path.join(self.data_root, "few_shot_lists", "%s.json" %self.split), 'r') as infile: 31 | self.metadata = json.load(infile) 32 | self.transforms = transforms 33 | self.rotation_labels = rotation_labels 34 | self.labels = np.array(self.metadata['image_labels']) 35 | label_map = {l: i for i, l in enumerate(sorted(np.unique(self.labels)))} 36 | self.labels = np.array([label_map[l] for l in self.labels]) 37 | self.size = len(self.metadata["image_labels"]) 38 | 39 | def next_run(self): 40 | pass 41 | 42 | def rotate_img(self, img, rot): 43 | if rot == 0: # 0 degrees rotation 44 | return img 45 | elif rot == 90: # 90 degrees rotation 46 | return np.flipud(np.transpose(img, (1, 0, 2))) 47 | elif rot == 180: # 90 degrees rotation 48 | return np.fliplr(np.flipud(img)) 49 | elif rot == 270: # 270 degrees rotation / or -90 50 | return np.transpose(np.flipud(img), (1, 0, 2)) 51 | else: 52 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees') 53 | 54 | def __getitem__(self, item): 55 | image = np.array(Image.open(self.metadata["image_names"][item]).convert("RGB")) 56 | images = self.transforms(image) * 2 - 1 57 | return images, int(self.labels[item]) 58 | 59 | def __len__(self): 60 | return len(self.labels) 61 | 62 | class RotatedNonEpisodicCUB(NonEpisodicCUB): 63 | name="CUB" 64 | task="cls" 65 | split_paths = {"train":"train", "test":"test", "valid": "val"} 66 | c = 3 67 | h = 84 68 | w = 84 69 | 70 | def __init__(self, *args, **kwargs): 71 | """ Constructor 72 | 73 | Args: 74 | split: data split 75 | few_shot_sampler: FewShotSampler instance 76 | task: dataset task (if more than one) 77 | size: number of tasks to generate (int) 78 | disjoint: whether to create disjoint splits. 
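(Note: each __getitem__ call returns a (4, C, H, W) stack of the image rotated by 0/90/180/270 degrees, the class label repeated four times, and the rotation labels.)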
79 | """ 80 | super().__init__(*args, **kwargs) 81 | 82 | def rotate_img(self, img, rot): 83 | if rot == 0: # 0 degrees rotation 84 | return img 85 | elif rot == 90: # 90 degrees rotation 86 | return np.flipud(np.transpose(img, (1, 0, 2))) 87 | elif rot == 180: # 90 degrees rotation 88 | return np.fliplr(np.flipud(img)) 89 | elif rot == 270: # 270 degrees rotation / or -90 90 | return np.transpose(np.flipud(img), (1, 0, 2)) 91 | else: 92 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees') 93 | 94 | def __getitem__(self, item): 95 | image = np.array(Image.open(self.metadata["image_names"][item]).convert("RGB")) 96 | if np.random.randint(2): 97 | image = np.fliplr(image) 98 | image_90 = self.transforms(self.rotate_img(image, 90)) 99 | image_180 = self.transforms(self.rotate_img(image, 180)) 100 | image_270 = self.transforms(self.rotate_img(image, 270)) 101 | images = torch.stack([self.transforms(image), image_90, image_180, image_270]) * 2 - 1 102 | return images, torch.ones(4, dtype=torch.long)*int(self.labels[item]), torch.LongTensor(self.rotation_labels) -------------------------------------------------------------------------------- /src/datasets/episodic_cub.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torchvision 3 | import torch 4 | from torch.utils.data import Dataset 5 | from .episodic_dataset import EpisodicDataset, FewShotSampler 6 | import json 7 | import os 8 | import numpy as np 9 | from PIL import Image 10 | import numpy 11 | 12 | class EpisodicCUB(EpisodicDataset): 13 | h = 84 14 | w = 84 15 | c = 3 16 | name="CUB" 17 | task="cls" 18 | split_paths = {"train":"base", "val":"val", "valid":"val", "test":"novel"} 19 | def __init__(self, data_root, split, sampler, size, transforms): 20 | self.data_root = data_root 21 | self.split = split 22 | with open(os.path.join(self.data_root, "few_shot_lists", "%s.json" %self.split_paths[split]), 'r') as infile: 23 | self.metadata = json.load(infile) 24 | labels = np.array(self.metadata['image_labels']) 25 | label_map = {l: i for i, l in enumerate(sorted(np.unique(labels)))} 26 | labels = np.array([label_map[l] for l in labels]) 27 | super().__init__(labels, sampler, size, transforms) 28 | 29 | def sample_images(self, indices): 30 | return [np.array(Image.open(self.metadata['image_names'][i]).convert("RGB")) for i in indices] 31 | 32 | def __iter__(self): 33 | return super().__iter__() 34 | 35 | if __name__ == '__main__': 36 | from torch.utils.data import DataLoader 37 | from tools.plot_episode import plot_episode 38 | sampler = FewShotSampler(5, 5, 15, 0) 39 | transforms = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(), 40 | torchvision.transforms.Resize((84,84)), 41 | torchvision.transforms.ToTensor(), 42 | ]) 43 | dataset = EpisodicCUB("train", sampler, 10, transforms) 44 | loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x) 45 | for batch in loader: 46 | plot_episode(batch[0], classes_first=False) 47 | 48 | -------------------------------------------------------------------------------- /src/datasets/episodic_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from torch.utils.data import DataLoader 5 | import collections 6 | from copy import deepcopy 7 | 8 | _DataLoader = DataLoader 9 | class EpisodicDataLoader(_DataLoader): 10 | def __iter__(self): 11 | if isinstance(self.dataset, 
EpisodicDataset): 12 | self.dataset.__iter__() 13 | else: 14 | pass 15 | return super().__iter__() 16 | torch.utils.data.DataLoader = EpisodicDataLoader 17 | 18 | class FewShotSampler(): 19 | FewShotTask = collections.namedtuple("FewShotTask", ["nclasses", "support_size", "query_size", "unlabeled_size"]) 20 | def __init__(self, nclasses, support_size, query_size, unlabeled_size): 21 | self.task = self.FewShotTask(nclasses, support_size, query_size, unlabeled_size) 22 | 23 | def sample(self): 24 | return deepcopy(self.task) 25 | 26 | class EpisodicDataset(Dataset): 27 | def __init__(self, labels, sampler, size, transforms): 28 | self.labels = labels 29 | self.sampler = sampler 30 | self.labelset = np.unique(labels) 31 | self.indices = np.arange(len(labels)) 32 | self.transforms = transforms 33 | self.reshuffle() 34 | self.size = size 35 | 36 | def reshuffle(self): 37 | """ 38 | Helper method to randomize tasks again 39 | """ 40 | self.clss_idx = [np.random.permutation(self.indices[self.labels == label]) for label in self.labelset] 41 | self.starts = np.zeros(len(self.clss_idx), dtype=int) 42 | self.lengths = np.array([len(x) for x in self.clss_idx]) 43 | 44 | def gen_few_shot_task(self, nclasses, size): 45 | """ Iterates through the dataset sampling tasks 46 | 47 | Args: 48 | n: FewShotTask.n 49 | sample_size: FewShotTask.k 50 | query_size: FewShotTask.k (default), else query_set_size // FewShotTask.n 51 | 52 | Returns: Sampled task or None in the case the dataset has been exhausted. 53 | 54 | """ 55 | classes = np.random.choice(self.labelset, nclasses, replace=False) 56 | starts = self.starts[classes] 57 | reminders = self.lengths[classes] - starts 58 | if np.min(reminders) < size: 59 | return None 60 | sample_indices = np.array( 61 | [self.clss_idx[classes[i]][starts[i]:(starts[i] + size)] for i in range(len(classes))]) 62 | sample_indices = np.reshape(sample_indices, [nclasses, size]).transpose() 63 | self.starts[classes] += size 64 | return sample_indices.flatten() 65 | 66 | def sample_task_list(self): 67 | """ Generates a list of tasks (until the dataset is exhausted) 68 | 69 | Returns: the list of tasks [(FewShotTask object, task_indices), ...] 
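Each task_indices entry is a flat index array of length nclasses * (support_size + query_size + unlabeled_size). Illustrative example, assuming a FewShotSampler(5, 5, 15, 0):

    task_info, indices = dataset.sample_task_list()[0]
    assert task_info.nclasses == 5
    assert indices.shape == (5 * (5 + 15 + 0),)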
70 | 71 | """ 72 | task_list = [] 73 | task_info = self.sampler.sample() 74 | nclasses, support_size, query_size, unlabeled_size = task_info 75 | unlabeled_size = min(unlabeled_size, self.lengths.min() - support_size - query_size) 76 | task_info = FewShotSampler.FewShotTask(nclasses=nclasses, 77 | support_size=support_size, 78 | query_size=query_size, 79 | unlabeled_size=unlabeled_size) 80 | k = support_size + query_size + unlabeled_size 81 | if np.any(k > self.lengths): 82 | raise RuntimeError("Requested more samples than existing") 83 | few_shot_task = self.gen_few_shot_task(nclasses, k) 84 | 85 | while few_shot_task is not None: 86 | task_list.append((task_info, few_shot_task)) 87 | task_info = self.sampler.sample() 88 | nclasses, support_size, query_size, unlabeled_size = task_info 89 | k = support_size + query_size + unlabeled_size 90 | few_shot_task = self.gen_few_shot_task(nclasses, k) 91 | return task_list 92 | 93 | def sample_images(self, indices): 94 | raise NotImplementedError 95 | 96 | def __getitem__(self, idx): 97 | """ Reads the idx th task (episode) from disk 98 | 99 | Args: 100 | idx: task index 101 | 102 | Returns: task dictionary with (dataset (char), task (char), dim (tuple), episode (Tensor)) 103 | 104 | """ 105 | fs_task_info, indices = self.task_list[idx] 106 | ordered_argindices = np.argsort(indices) 107 | ordered_indices = np.sort(indices) 108 | nclasses, support_size, query_size, unlabeled_size = fs_task_info 109 | k = support_size + query_size + unlabeled_size 110 | _images = self.sample_images(ordered_indices) 111 | images = torch.stack([self.transforms(_images[i]) for i in np.argsort(ordered_argindices)]) 112 | total, c, h, w = images.size() 113 | assert(total == (k * nclasses)) 114 | images = images.view(k, nclasses, c, h, w) 115 | del(_images) 116 | images = images * 2 - 1 117 | targets = np.zeros([nclasses * k], dtype=int) 118 | targets[ordered_argindices] = self.labels[ordered_indices, ...].ravel() 119 | sample = {"dataset": self.name, 120 | "channels": c, 121 | "height": h, 122 | "width": w, 123 | "nclasses": nclasses, 124 | "support_size": support_size, 125 | "query_size": query_size, 126 | "unlabeled_size": unlabeled_size, 127 | "targets": torch.from_numpy(targets), 128 | "support_set": images[:support_size, ...], 129 | "query_set": images[support_size:(support_size + 130 | query_size), ...], 131 | "unlabeled_set": None if unlabeled_size == 0 else images[(support_size + query_size):, ...]} 132 | return sample 133 | 134 | 135 | def __iter__(self): 136 | # print("Prefetching new epoch episodes") 137 | self.task_list = [] 138 | while len(self.task_list) < self.size: 139 | self.reshuffle() 140 | self.task_list += self.sample_task_list() 141 | # print("done prefetching.") 142 | return [] 143 | 144 | def __len__(self): 145 | return self.size -------------------------------------------------------------------------------- /src/datasets/episodic_miniimagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import os 5 | from src.datasets.episodic_dataset import EpisodicDataset, FewShotSampler 6 | import pickle as pkl 7 | 8 | # Inherit order is important, FewShotDataset constructor is prioritary 9 | class EpisodicMiniImagenet(EpisodicDataset): 10 | tasks_type = "clss" 11 | name = "miniimagenet" 12 | episodic=True 13 | split_paths = {"train":"train", "valid":"val", "val":"val", "test": "test"} 14 | # c = 3 15 | # h = 84 16 | # w = 84 17 | 18 | def __init__(self, 
data_root, split, sampler, size, transforms):
19 |         """ Constructor
20 | 
21 |         Args:
22 |             data_root: path to the mini-imagenet .npz files
23 |             split: data split ("train", "val" or "test")
24 |             sampler: FewShotSampler instance used to draw episodes
25 |             size: number of tasks (episodes) to generate (int)
26 |             transforms: torchvision transforms applied to each image
27 |         """
28 |         self.data_root = os.path.join(data_root, "mini-imagenet-%s.npz")
29 |         self.split = split
30 |         data = np.load(self.data_root % self.split_paths[split])
31 |         self.features = data["features"]
32 |         labels = data["targets"]
33 |         del(data)
34 |         super().__init__(labels, sampler, size, transforms)
35 | 
36 |     def sample_images(self, indices):
37 |         return self.features[indices]
38 | 
39 |     def __iter__(self):
40 |         return super().__iter__()
41 | 
42 | # Inheritance order is important: the FewShotDataset constructor takes priority
43 | class EpisodicMiniImagenetPkl(EpisodicDataset):
44 |     tasks_type = "clss"
45 |     name = "miniimagenet"
46 |     episodic = True
47 |     split_paths = {"train": "train", "valid": "val", "val": "val", "test": "test"}
48 |     # c = 3
49 |     # h = 84
50 |     # w = 84
51 | 
52 |     def __init__(self, data_root, split, sampler, size, transforms):
53 |         """ Constructor
54 | 
55 |         Args:
56 |             data_root: path to the mini-imagenet-cache .pkl files
57 |             split: data split ("train", "val" or "test")
58 |             sampler: FewShotSampler instance used to draw episodes
59 |             size: number of tasks (episodes) to generate (int)
60 |             transforms: torchvision transforms applied to each image
61 |         """
62 |         self.data_root = os.path.join(data_root, "mini-imagenet-cache-%s.pkl")
63 |         self.split = split
64 |         with open(self.data_root % self.split_paths[split], 'rb') as infile:
65 |             data = pkl.load(infile)
66 |         self.features = data["image_data"]
67 |         label_names = data["class_dict"].keys()
68 |         labels = np.zeros((self.features.shape[0],), dtype=int)
69 |         for i, name in enumerate(sorted(label_names)):
70 |             labels[np.array(data['class_dict'][name])] = i
71 |         del(data)
72 |         super().__init__(labels, sampler, size, transforms)
73 | 
74 |     def sample_images(self, indices):
75 |         return self.features[indices]
76 | 
77 |     def __iter__(self):
78 |         return super().__iter__()
79 | 
80 | if __name__ == '__main__':
81 |     from torch.utils.data import DataLoader
82 |     from src.tools.plot_episode import plot_episode
83 |     import time
84 |     sampler = FewShotSampler(5, 5, 15, 0)
85 |     transforms = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
86 |                                                  torchvision.transforms.ToTensor(),
87 |                                                  ])
88 |     dataset = EpisodicMiniImagenetPkl('./miniimagenet', 'train', sampler, 1000, transforms)
89 |     loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x)
90 |     for batch in loader:
91 |         print(np.unique(batch[0]["targets"].view(20, 5).numpy()))
92 |         # plot_episode(batch[0], classes_first=False)
93 |         # time.sleep(1)
94 | 
95 | 
--------------------------------------------------------------------------------
/src/datasets/episodic_tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torchvision
3 | import torch
4 | from torch.utils.data import Dataset
5 | from .episodic_dataset import EpisodicDataset, FewShotSampler
6 | import json
7 | import os
8 | import numpy as np
9 | import cv2
10 | import pickle as pkl
11 | 
12 | 
13 | # Inheritance order is important: the FewShotDataset constructor takes priority
14 | class EpisodicTieredImagenet(EpisodicDataset):
15 |     tasks_type = "clss"
16 |     name = "tiered-imagenet"
17 |     split_paths = {"train": "train", "test": "test", "valid": "val"}
18 |     c = 3
19 |     h = 84
20 |     w = 84
21 |     def __init__(self, data_root, split, sampler, size, 
transforms):
22 |         self.data_root = data_root
23 |         self.split = split
24 |         img_path = os.path.join(self.data_root, "%s_images_png.pkl" % (split))
25 |         label_path = os.path.join(self.data_root, "%s_labels.pkl" % (split))
26 |         with open(img_path, 'rb') as infile:
27 |             self.features = pkl.load(infile, encoding="bytes")
28 |         with open(label_path, 'rb') as infile:
29 |             labels = pkl.load(infile, encoding="bytes")[b'label_specific']
30 |         super().__init__(labels, sampler, size, transforms)
31 | 
32 |     def sample_images(self, indices):
33 |         return [cv2.imdecode(self.features[i], cv2.IMREAD_COLOR)[:, :, ::-1] for i in indices]
34 | 
35 |     def __iter__(self):
36 |         return super().__iter__()
37 | 
38 | if __name__ == '__main__':
39 |     from torch.utils.data import DataLoader
40 |     from src.tools.plot_episode import plot_episode
41 |     sampler = FewShotSampler(5, 5, 15, 0)
42 |     transforms = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
43 |                                                  torchvision.transforms.Resize((84, 84)),
44 |                                                  torchvision.transforms.ToTensor(),
45 |                                                  ])
46 |     # data_root was missing from this call; "./tiered-imagenet" is a placeholder path
47 |     dataset = EpisodicTieredImagenet("./tiered-imagenet", "train", sampler, 10, transforms)
48 |     loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x)
49 |     for batch in loader:
50 |         plot_episode(batch[0], classes_first=False)
--------------------------------------------------------------------------------
/src/datasets/miniimagenet.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import numpy as np
4 | import os
5 | import pickle as pkl
6 | 
7 | class NonEpisodicMiniImagenet(Dataset):
8 |     tasks_type = "clss"
9 |     name = "miniimagenet"
10 |     split_paths = {"train": "train", "val": "val", "valid": "val", "test": "test"}
11 |     episodic = False
12 |     c = 3
13 |     h = 84
14 |     w = 84
15 | 
16 |     def __init__(self, data_root, split, transforms, **kwargs):
17 |         """ Constructor
18 | 
19 |         Args:
20 |             data_root: path to the mini-imagenet .npz files
21 |             split: data split ("train", "val" or "test")
22 |             transforms: torchvision transforms applied to each image
23 |             **kwargs: extra keyword arguments, ignored (kept for a
24 |                 uniform dataset interface)
25 |         """
26 |         self.data_root = os.path.join(data_root, "mini-imagenet-%s.npz")
27 |         data = np.load(self.data_root % self.split_paths[split])
28 |         self.features = data["features"]
29 |         self.labels = data["targets"]
30 |         self.transforms = transforms
31 | 
32 |     def next_run(self):
33 |         pass
34 | 
35 |     def __getitem__(self, item):
36 |         image = self.transforms(self.features[item])
37 |         image = image * 2 - 1
38 |         return image, self.labels[item]
39 | 
40 |     def __len__(self):
41 |         return len(self.features)
42 | 
43 | class RotatedNonEpisodicMiniImagenet(Dataset):
44 |     tasks_type = "clss"
45 |     name = "miniimagenet"
46 |     split_paths = {"train": "train", "val": "val", "valid": "val", "test": "test"}
47 |     episodic = False
48 |     c = 3
49 |     h = 84
50 |     w = 84
51 | 
52 |     def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs):
53 |         """ Constructor
54 | 
55 |         Args:
56 |             data_root: path to the mini-imagenet .npz files
57 |             split: data split ("train", "val" or "test")
58 |             transforms: torchvision transforms applied to each image
59 |             rotation_labels: rotation classes to generate
60 |                 (0, 1, 2, 3 for 0, 90, 180 and 270 degrees)
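
            Note: __getitem__ returns a (images, class_labels, rotation_labels)
            triplet, where the images tensor stacks the rotated copies along
            dimension 0.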
61 |         """
62 |         self.data_root = os.path.join(data_root, "mini-imagenet-%s.npz")
63 |         data = np.load(self.data_root % self.split_paths[split])
64 |         self.features = data["features"]
65 |         self.labels = data["targets"]
66 |         self.transforms = transforms
67 |         self.size = len(self.features)
68 |         self.rotation_labels = rotation_labels
69 | 
70 |     def next_run(self):
71 |         pass
72 | 
73 |     def rotate_img(self, img, rot):
74 |         if rot == 0:  # 0 degrees rotation
75 |             return img
76 |         elif rot == 90:  # 90 degrees rotation
77 |             return np.flipud(np.transpose(img, (1, 0, 2)))
78 |         elif rot == 180:  # 180 degrees rotation
79 |             return np.fliplr(np.flipud(img))
80 |         elif rot == 270:  # 270 degrees rotation / or -90
81 |             return np.transpose(np.flipud(img), (1, 0, 2))
82 |         else:
83 |             raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
84 | 
85 |     def __getitem__(self, item):
86 |         image = self.features[item]
87 |         if np.random.randint(2):
88 |             image = np.fliplr(image).copy()
89 |         cat = [self.transforms(image)]
90 |         if len(self.rotation_labels) > 1:
91 |             image_90 = self.transforms(self.rotate_img(image, 90))
92 |             image_180 = self.transforms(self.rotate_img(image, 180))
93 |             image_270 = self.transforms(self.rotate_img(image, 270))
94 |             cat.extend([image_90, image_180, image_270])
95 |         images = torch.stack(cat) * 2 - 1
96 |         return images, torch.ones(len(self.rotation_labels), dtype=torch.long) * int(self.labels[item]), torch.LongTensor(self.rotation_labels)
97 | 
98 |     def __len__(self):
99 |         return self.size
100 | 
101 | class RotatedNonEpisodicMiniImagenetPkl(Dataset):
102 |     tasks_type = "clss"
103 |     name = "miniimagenet"
104 |     split_paths = {"train": "train", "val": "val", "valid": "val", "test": "test"}
105 |     episodic = False
106 |     c = 3
107 |     h = 84
108 |     w = 84
109 | 
110 |     def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs):
111 |         """ Constructor
112 | 
113 |         Args:
114 |             data_root: path to the mini-imagenet-cache .pkl files
115 |             split: data split ("train", "val" or "test")
116 |             transforms: torchvision transforms applied to each image
117 |             rotation_labels: rotation classes to generate
118 |                 (0, 1, 2, 3 for 0, 90, 180 and 270 degrees)
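
            Note: identical to RotatedNonEpisodicMiniImagenet, except that it
            reads the mini-imagenet-cache-*.pkl files instead of the .npz ones.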
119 |         """
120 |         self.data_root = os.path.join(data_root, "mini-imagenet-cache-%s.pkl")
121 |         with open(self.data_root % self.split_paths[split], 'rb') as infile:
122 |             data = pkl.load(infile)
123 |         self.features = data["image_data"]
124 |         label_names = data["class_dict"].keys()
125 |         self.labels = np.zeros((self.features.shape[0],), dtype=int)
126 |         for i, name in enumerate(sorted(label_names)):
127 |             self.labels[np.array(data['class_dict'][name])] = i
128 |         del(data)
129 |         self.transforms = transforms
130 |         self.size = len(self.features)
131 |         self.rotation_labels = rotation_labels
132 | 
133 |     def next_run(self):
134 |         pass
135 | 
136 |     def rotate_img(self, img, rot):
137 |         if rot == 0:  # 0 degrees rotation
138 |             return img
139 |         elif rot == 90:  # 90 degrees rotation
140 |             return np.flipud(np.transpose(img, (1, 0, 2)))
141 |         elif rot == 180:  # 180 degrees rotation
142 |             return np.fliplr(np.flipud(img))
143 |         elif rot == 270:  # 270 degrees rotation / or -90
144 |             return np.transpose(np.flipud(img), (1, 0, 2))
145 |         else:
146 |             raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
147 | 
148 |     def __getitem__(self, item):
149 |         image = self.features[item]
150 |         if np.random.randint(2):
151 |             image = np.fliplr(image).copy()
152 |         cat = [self.transforms(image)]
153 |         if len(self.rotation_labels) > 1:
154 |             image_90 = self.transforms(self.rotate_img(image, 90))
155 |             image_180 = self.transforms(self.rotate_img(image, 180))
156 |             image_270 = self.transforms(self.rotate_img(image, 270))
157 |             cat.extend([image_90, image_180, image_270])
158 |         images = torch.stack(cat) * 2 - 1
159 |         return images, torch.ones(len(self.rotation_labels), dtype=torch.long) * int(self.labels[item]), torch.LongTensor(self.rotation_labels)
160 | 
161 |     def __len__(self):
162 |         return self.size
163 | 
--------------------------------------------------------------------------------
/src/datasets/tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import PIL
4 | import pickle as pkl
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | import torch
8 | 
9 | class NonEpisodicTieredImagenet(Dataset):
10 |     tasks_type = "clss"
11 |     name = "tiered-imagenet"
12 |     split_paths = {"train": "train", "test": "test", "valid": "val"}
13 |     c = 3
14 |     h = 84
15 |     w = 84
16 | 
17 |     def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs):
18 |         """ Constructor
19 | 
20 |         Args:
21 |             data_root: path to the tiered-imagenet pickle files
22 |             split: data split ("train", "valid" or "test")
23 |             transforms: torchvision transforms applied to each image
24 |             rotation_labels: rotation classes to generate
25 |                 (0, 1, 2, 3 for 0, 90, 180 and 270 degrees)
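
            Note: images are stored as PNG-encoded byte buffers and are decoded
            with cv2.imdecode inside __getitem__.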
26 |         """
27 |         split = self.split_paths[split]
28 |         self.data_root = data_root
29 |         img_path = os.path.join(self.data_root, "%s_images_png.pkl" % (split))
30 |         label_path = os.path.join(self.data_root, "%s_labels.pkl" % (split))
31 |         self.transforms = transforms
32 |         self.rotation_labels = rotation_labels
33 |         with open(img_path, 'rb') as infile:
34 |             self.features = pkl.load(infile, encoding="bytes")
35 |         with open(label_path, 'rb') as infile:
36 |             self.labels = pkl.load(infile, encoding="bytes")[b'label_specific']
37 | 
38 |     def next_run(self):
39 |         pass
40 | 
41 |     def __getitem__(self, item):
42 |         image = cv2.imdecode(self.features[item], cv2.IMREAD_COLOR)[..., ::-1]  # BGR -> RGB
43 |         images = self.transforms(image) * 2 - 1
44 |         return images, int(self.labels[item])
45 | 
46 |     def __len__(self):
47 |         return len(self.labels)
48 | 
49 | class RotatedNonEpisodicTieredImagenet(NonEpisodicTieredImagenet):
50 |     tasks_type = "clss"
51 |     name = "tiered-imagenet"
52 |     split_paths = {"train": "train", "test": "test", "valid": "val"}
53 |     c = 3
54 |     h = 84
55 |     w = 84
56 | 
57 |     def __init__(self, *args, **kwargs):
58 |         """ Constructor
59 | 
60 |         Args:
61 |             *args: positional arguments forwarded to NonEpisodicTieredImagenet
62 |                 (data_root, split, transforms)
63 |             **kwargs: keyword arguments forwarded to NonEpisodicTieredImagenet
64 |                 (e.g. rotation_labels)
65 |         """
66 |         super().__init__(*args, **kwargs)
67 | 
68 |     def rotate_img(self, img, rot):
69 |         if rot == 0:  # 0 degrees rotation
70 |             return img
71 |         elif rot == 90:  # 90 degrees rotation
72 |             return np.flipud(np.transpose(img, (1, 0, 2)))
73 |         elif rot == 180:  # 180 degrees rotation
74 |             return np.fliplr(np.flipud(img))
75 |         elif rot == 270:  # 270 degrees rotation / or -90
76 |             return np.transpose(np.flipud(img), (1, 0, 2))
77 |         else:
78 |             raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
79 | 
80 |     def __getitem__(self, item):
81 |         image = cv2.imdecode(self.features[item], cv2.IMREAD_COLOR)[..., ::-1]
82 |         if np.random.randint(2):
83 |             image = np.fliplr(image).copy()  # .copy() as in the mini-imagenet variant, so the flipped array is contiguous
84 |         image_90 = self.transforms(self.rotate_img(image, 90))
85 |         image_180 = self.transforms(self.rotate_img(image, 180))
86 |         image_270 = self.transforms(self.rotate_img(image, 270))
87 |         images = torch.stack([self.transforms(image), image_90, image_180, image_270]) * 2 - 1
88 |         return images, torch.ones(4, dtype=torch.long) * int(self.labels[item]), torch.LongTensor(self.rotation_labels)
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | import argparse
4 | import pandas as pd
5 | import pickle, os
6 | import numpy as np
7 | from . 
import pretraining, finetuning, ssl_wrapper
8 | from haven import haven_utils as hu  # assumed: haven-ai utility that provides load_pkl (used below)
9 | def get_model(model_name, backbone, n_classes, exp_dict, pretrained_weights_dir=None, savedir_base=None):
10 |     if model_name == "pretraining":
11 |         model = pretraining.PretrainWrapper(backbone, n_classes, exp_dict)
12 | 
13 |     elif model_name == "finetuning":
14 |         model = finetuning.FinetuneWrapper(backbone, n_classes, exp_dict)
15 | 
16 |     elif model_name == "ssl":
17 |         model = ssl_wrapper.SSLWrapper(backbone, n_classes, exp_dict, savedir_base=savedir_base)
18 | 
19 |     else:
20 |         raise ValueError('model does not exist...')
21 | 
22 |     # load pretrained model
23 |     if pretrained_weights_dir:
24 |         s_path = os.path.join(os.path.dirname(pretrained_weights_dir), 'score_list_best.pkl')
25 |         if not os.path.exists(s_path):
26 |             s_path = os.path.join(exp_dict['checkpoint_exp_id'],
27 |                                   'score_list.pkl')
28 |         print('Loaded checkpoint from exp_id: %s' %
29 |               os.path.split(os.path.dirname(pretrained_weights_dir))[-1]
30 |               )
31 |         print('Fine-tuned accuracy: %.3f' % hu.load_pkl(s_path)[-1]['test_accuracy'])
32 |         model.model.load_state_dict(torch.load(pretrained_weights_dir)['model'])
33 | 
34 |     return model
35 | 
36 | # ===============================================
37 | # Trainers
38 | def train_on_loader(model, train_loader):
39 |     model.train()
40 | 
41 |     n_batches = len(train_loader)
42 |     train_monitor = TrainMonitor()
43 |     for e in range(1):
44 |         for i, batch in enumerate(train_loader):
45 |             score_dict = model.train_on_batch(batch)
46 | 
47 |             train_monitor.add(score_dict)
48 |             if i % 10 == 0:
49 |                 msg = "%d/%d %s" % (i, n_batches, train_monitor.get_avg_score())
50 | 
51 |                 print(msg)
52 | 
53 |     return train_monitor.get_avg_score()
54 | 
55 | def val_on_loader(model, val_loader, val_monitor):
56 |     model.eval()
57 | 
58 |     n_batches = len(val_loader)
59 | 
60 |     for i, batch in enumerate(val_loader):
61 |         score = model.val_on_batch(batch)
62 | 
63 |         val_monitor.add(score)
64 |         if i % 10 == 0:
65 |             msg = "%d/%d %s" % (i, n_batches, val_monitor.get_avg_score())
66 | 
67 |             print(msg)
68 | 
69 | 
70 |     return val_monitor.get_avg_score()
71 | 
72 | 
73 | @torch.no_grad()
74 | def vis_on_loader(model, vis_loader, savedir):
75 |     model.eval()
76 | 
77 |     n_batches = len(vis_loader)
78 |     split = vis_loader.dataset.split
79 |     for i, batch in enumerate(vis_loader):
80 |         print("%d - visualizing %s image - savedir:%s" % (i, batch["meta"]["split"][0], savedir.split("/")[-2]))
81 |         model.vis_on_batch(batch, savedir=savedir)
82 | 
83 | 
84 | def test_on_loader(model, test_loader):
85 |     model.eval()
86 |     ae = 0.
87 |     n_samples = 0. 
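    # accumulate the absolute error over all batches; the running value shown
    # below is the mean absolute error, MAE = sum(|counts - pred_counts|) / n_samples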
88 | 
89 |     n_batches = len(test_loader)
90 |     pbar = tqdm.tqdm(total=n_batches)
91 |     for i, batch in enumerate(test_loader):
92 |         pred_count = model.predict(batch, method="counts")
93 | 
94 |         ae += abs(batch["counts"].cpu().numpy().ravel() - pred_count.ravel()).sum()
95 |         n_samples += batch["counts"].shape[0]
96 | 
97 |         pbar.set_description("TEST mae: %.4f" % (ae / n_samples))
98 |         pbar.update(1)
99 | 
100 |     pbar.close()
101 |     score = ae / n_samples
102 |     print({"test_score": score, "test_mae": score})
103 | 
104 |     return {"test_score": score, "test_mae": score}
105 | 
106 | 
107 | class TrainMonitor:
108 |     def __init__(self):
109 |         self.score_dict_sum = {}
110 |         self.n = 0
111 | 
112 |     def add(self, score_dict):
113 |         for k, v in score_dict.items():
114 |             if k not in self.score_dict_sum:
115 |                 self.score_dict_sum[k] = v
116 |             else:
117 |                 self.score_dict_sum[k] += v
118 |         self.n += 1  # count calls (batches), not keys, so multi-metric dicts average correctly
119 | 
120 |     def get_avg_score(self):
121 |         return {k: v / max(self.n, 1) for k, v in self.score_dict_sum.items()}
122 | 
--------------------------------------------------------------------------------
/src/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from . import resnet12, conv4, wrn
3 | 
4 | def get_backbone(backbone_name, exp_dict):
5 |     if backbone_name == "resnet12":
6 |         backbone = resnet12.Resnet12(width=1, dropout=exp_dict["dropout"])
7 |     elif backbone_name == "conv4":
8 |         backbone = conv4.Conv4(exp_dict)
9 |     elif backbone_name == "wrn":
10 |         backbone = wrn.WideResNet(depth=exp_dict["model"]["depth"], width=exp_dict["model"]["width"], exp_dict=exp_dict)
11 |     else:
12 |         raise ValueError("backbone '%s' does not exist" % backbone_name)
13 |     return backbone
--------------------------------------------------------------------------------
/src/models/backbones/conv4.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | 
6 | class Conv4(torch.nn.Module):
7 |     def __init__(self, exp_dict):
8 |         super().__init__()
9 |         self.conv0 = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)
10 |         self.bn0 = torch.nn.BatchNorm2d(64)
11 |         self.conv1 = torch.nn.Conv2d(64, 64, 3, 1, 1, bias=False)
12 |         self.bn1 = torch.nn.BatchNorm2d(64)
13 |         self.conv2 = torch.nn.Conv2d(64, 64, 3, 1, 1, bias=False)
14 |         self.bn2 = torch.nn.BatchNorm2d(64)
15 |         self.conv3 = torch.nn.Conv2d(64, 64, 3, 1, 1, bias=False)
16 |         self.bn3 = torch.nn.BatchNorm2d(64)
17 |         self.exp_dict = exp_dict
18 |         if self.exp_dict["avgpool"]:
19 |             self.output_size = 64
20 |         else:
21 |             self.output_size = 1600
22 | 
23 |     def add_classifier(self, no, name="classifier", modalities=None):
24 |         setattr(self, name, torch.nn.Linear(self.output_size, no))
25 | 
26 |     def forward(self, x, *args, **kwargs):
27 |         *dim, c, h, w = x.size()
28 |         x = x.view(-1, c, h, w)
29 |         x = self.conv0(x)  # 84
30 |         x = F.relu(self.bn0(x), True)
31 |         x = F.max_pool2d(x, 2, 2, 0)  # 84 -> 42
32 |         x = self.conv1(x)
33 |         x = F.relu(self.bn1(x), True)
34 |         x = F.max_pool2d(x, 2, 2, 0)  # 42 -> 21
35 |         x = self.conv2(x)
36 |         x = F.relu(self.bn2(x), True)
37 |         x = F.max_pool2d(x, 2, 2, 0)  # 21 -> 10
38 |         x = self.conv3(x)
39 |         x = F.relu(self.bn3(x), True)
40 |         x = F.max_pool2d(x, 2, 2, 0)  # 10 -> 5
41 |         if self.exp_dict["avgpool"]:
42 |             x = x.mean(3, keepdim=True).mean(2, keepdim=True)
43 |         return x.view(*dim, self.output_size)
--------------------------------------------------------------------------------
/src/models/backbones/resnet12.py: 
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | 
6 | class Block(torch.nn.Module):
7 |     def __init__(self, ni, no, stride, dropout=0, groups=1):
8 |         super().__init__()
9 |         self.dropout = torch.nn.Dropout2d(dropout) if dropout > 0 else lambda x: x
10 |         self.conv0 = torch.nn.Conv2d(ni, no, 3, stride, padding=1, bias=False)
11 |         self.bn0 = torch.nn.BatchNorm2d(no)
12 |         self.conv1 = torch.nn.Conv2d(no, no, 3, 1, padding=1, bias=False)
13 |         self.bn1 = torch.nn.BatchNorm2d(no)
14 |         self.conv2 = torch.nn.Conv2d(no, no, 3, 1, padding=1, bias=False)
15 |         self.bn2 = torch.nn.BatchNorm2d(no)
16 |         if stride == 2 or ni != no:
17 |             # 1x1 projection for the shortcut; the stride must match the main
18 |             # path (it was hard-coded to 1, which breaks when stride == 2)
19 |             self.shortcut = torch.nn.Conv2d(ni, no, 1, stride=stride, padding=0)
20 |         else:
21 |             self.shortcut = torch.nn.Identity()  # defensive: forward() always calls self.shortcut
22 | 
23 |     def get_parameters(self):
24 |         return self.parameters()
25 | 
26 |     def forward(self, x, is_support=True):
27 |         y = F.relu(self.bn0(self.conv0(x)), True)
28 |         y = self.dropout(y)
29 |         y = F.relu(self.bn1(self.conv1(y)), True)
30 |         y = self.dropout(y)
31 |         y = self.bn2(self.conv2(y))
32 |         return F.relu(y + self.shortcut(x), True)
33 | 
34 | 
35 | class Resnet12(torch.nn.Module):
36 |     def __init__(self, width, dropout):
37 |         super().__init__()
38 |         self.output_size = 512
39 |         assert(width == 1)  # remove this assert to experiment with other widths of this model
40 |         self.widths = [x * int(width) for x in [64, 128, 256]]
41 |         self.widths.append(self.output_size * width)
42 |         self.bn_out = torch.nn.BatchNorm1d(self.output_size)
43 | 
44 |         start_width = 3
45 |         for i in range(len(self.widths)):
46 |             setattr(self, "group_%d" % i, Block(start_width, self.widths[i], 1, dropout))
47 |             start_width = self.widths[i]
48 | 
49 |     def add_classifier(self, nclasses, name="classifier", modalities=None):
50 |         setattr(self, name, torch.nn.Linear(self.output_size, nclasses))
51 | 
52 |     def up_to_embedding(self, x, is_support):
53 |         """ Applies the four residual groups
54 |         Args:
55 |             x: input images
56 |             is_support: whether the input is the support set (for non-transductive)
57 |         """
58 |         for i in range(len(self.widths)):
59 |             x = getattr(self, "group_%d" % i)(x, is_support)
60 |             x = F.max_pool2d(x, 3, 2, 1)
61 |         return x
62 | 
63 |     def forward(self, x, is_support):
64 |         """Main PyTorch forward function
65 | 
66 |         Returns: image embeddings (after the output batch-norm and ReLU)
67 | 
68 |         Args:
69 |             x: input images
70 |             is_support: whether the input is the sample set
71 |         """
72 |         *args, c, h, w = x.size()
73 |         x = x.view(-1, c, h, w)
74 |         x = self.up_to_embedding(x, is_support)
75 |         return F.relu(self.bn_out(x.mean(3).mean(2)), True)
--------------------------------------------------------------------------------
/src/models/backbones/wrn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class Block(torch.nn.Module):
8 |     def __init__(self, ni, no, stride, dropout=0):
9 |         super().__init__()
10 |         self.conv0 = torch.nn.Conv2d(ni, no, 3, stride=stride, padding=1, bias=False)
11 |         self.bn0 = torch.nn.BatchNorm2d(no)
12 |         torch.nn.init.kaiming_normal_(self.conv0.weight.data)
13 |         self.bn1 = torch.nn.BatchNorm2d(no)
14 |         if dropout == 0:
15 |             self.dropout = lambda x: x
16 |         else:
17 |             self.dropout = torch.nn.Dropout2d(dropout)
18 |         self.conv1 = torch.nn.Conv2d(no, no, 3, stride=1, padding=1, bias=False)
19 |         torch.nn.init.kaiming_normal_(self.conv1.weight.data) 
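        # project the residual input with a (strided) 1x1 convolution whenever
        # the number of channels changes, so it can be added to the block output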
20 |         self.reduce = ni != no
21 |         if self.reduce:
22 |             self.conv_reduce = torch.nn.Conv2d(ni, no, 1, stride=stride, bias=False)
23 |             torch.nn.init.kaiming_normal_(self.conv_reduce.weight.data)
24 | 
25 |     def forward(self, x):
26 |         y = self.conv0(x)
27 |         y = F.relu(self.bn0(y), inplace=True)
28 |         y = self.dropout(y)
29 |         y = self.conv1(y)
30 |         y = self.bn1(y)
31 |         if self.reduce:
32 |             return F.relu(y + self.conv_reduce(x), True)
33 |         else:
34 |             return F.relu(y + x, True)
35 | 
36 | 
37 | class Group(torch.nn.Module):
38 |     def __init__(self, ni, no, n, stride, dropout=0):
39 |         super().__init__()
40 |         self.n = n
41 |         for i in range(n):
42 |             self.__setattr__("block_%d" % i, Block(ni if i == 0 else no, no, stride if i == 0 else 1, dropout=dropout))
43 | 
44 |     def forward(self, x):
45 |         for i in range(self.n):
46 |             x = self.__getattr__("block_%d" % i)(x)
47 |         return x
48 | 
49 | 
50 | class WideResNet(torch.nn.Module):
51 |     def __init__(self, depth, width, exp_dict):
52 |         super(WideResNet, self).__init__()
53 |         assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
54 |         self.n = (depth - 4) // 6
55 |         self.output_size = 640
56 |         self.widths = torch.Tensor([16, 32, 64]).mul(width).int().numpy().tolist()
57 |         self.conv0 = torch.nn.Conv2d(3, self.widths[0] // 2, 3, padding=1, bias=False)
58 |         self.bn_0 = torch.nn.BatchNorm2d(self.widths[0] // 2)
59 |         self.dropout_prob = exp_dict["dropout"]
60 |         self.group_0 = Group(self.widths[0] // 2, self.widths[0], self.n, 2, dropout=self.dropout_prob)
61 |         self.group_1 = Group(self.widths[0], self.widths[1], self.n, 2, dropout=self.dropout_prob)
62 |         self.group_2 = Group(self.widths[1], self.widths[2], self.n, 2, dropout=self.dropout_prob)
63 |         self.bn_out = torch.nn.BatchNorm1d(self.output_size)
64 | 
65 |     def get_base_parameters(self):
66 |         parameters = []
67 |         parameters += list(self.conv0.parameters())
68 |         parameters += list(self.group_0.parameters())
69 |         parameters += list(self.group_1.parameters())
70 |         parameters += list(self.group_2.parameters())
71 |         parameters += list(self.bn_0.parameters())
72 |         parameters += list(self.bn_out.parameters())  # was self.bn / self.embedding, attributes that do not exist
73 |         return parameters
74 | 
75 |     def get_classifier_parameters(self):
76 |         return self.classifier.parameters()
77 | 
78 |     def add_classifier(self, nclasses, name="classifier", modalities=None):
79 |         setattr(self, name, torch.nn.Linear(self.output_size, nclasses))
80 | 
81 |     def forward(self, x, **kwargs):
82 |         o = F.relu(self.bn_0(self.conv0(x)), True)
83 |         o = self.group_0(o)
84 |         o = self.group_1(o)
85 |         o = self.group_2(o)
86 |         o = o.mean(3).mean(2)
87 |         o = F.relu(self.bn_out(o.view(o.size(0), -1)))
88 |         return o
--------------------------------------------------------------------------------
/src/models/base_ssl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/models/base_ssl/__init__.py
--------------------------------------------------------------------------------
/src/models/base_ssl/distances.py:
--------------------------------------------------------------------------------
1 | 
2 | import matplotlib.pyplot as plt
3 | # import seaborn as sns
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import warnings
8 | from functools import partial
9 | # from tools.meters import BasicMeter
10 | # from train import TensorLogger
11 | import sys
12 | 
13 | def checknan(**kwargs):
14 |     for item, value in kwargs.items():
15 |         v = 
value.data.sum()
16 |         if torch.isnan(v) or torch.isinf(torch.abs(v)):
17 |             return "%s is NaN or inf" % item
18 | 
24 | 
25 | def matching_nets(support_set, query_set, *sets, distance_type="euclidean", **kwargs):
26 |     """ Computes the logits using the method described in [1]
27 | 
28 |     [1] Vinyals, Oriol, et al. "Matching networks for one shot learning." NeurIPS. 2016.
29 | 
30 |     Args:
31 |         support_set: tuple of (Tensor, labels); the Tensor has shape (batch, n_classes, n_samples_per_class, z_dim)
32 |             and contains the representation z of each image.
33 |         query_set: tuple of (Tensor, labels); the Tensor has shape (batch, n_classes, n_queries_per_class, z_dim)
34 |             and contains the representation z of each image.
35 |         sets: additional sets, unused.
36 |         distance_type: "euclidean" to use the euclidean distance, anything else
37 |             for the cosine similarity.
38 | 
39 |     Returns:
40 |         Class logits (warning: this function returns log probabilities, so the NLL loss is recommended)
41 |     """
42 |     euclidean = distance_type == "euclidean"
43 |     _support_set, _support_labels = support_set
44 |     _query_set, _query_labels = query_set
45 |     b, n, sk, z = _support_set.size()
46 |     b, n, qk, z = _query_set.size()
47 |     if isinstance(_support_labels, bool):
48 |         labels = _make_aligned_labels(_support_set)
49 |     else:
50 |         labels = _support_labels
51 |     if euclidean:
52 |         _support_set = _support_set.view(b, 1, n * sk, z)
53 |         _query_set = _query_set.view(b, n * qk, 1, z)
54 |         att = - ((_support_set - _query_set) ** 2).sum(3) / np.sqrt(z)
55 |     else:
56 |         _support_set = F.normalize(_support_set, dim=3)
57 |         _query_set = F.normalize(_query_set, dim=3)
58 |         _support_set = _support_set.view(b, n * sk, z).transpose(2, 1)
59 |         _query_set = _query_set.view(b, n * qk, z)
60 |         att = torch.matmul(_query_set, _support_set)
61 |     att = F.softmax(att, dim=2).view(b, n * qk, 1, n * sk)
62 |     labels = labels.view(b, 1, n * sk, n)
63 |     return torch.log(torch.matmul(att, labels).view(b * n * qk, n))
64 | 
65 | 
66 | def prototype_distance(support_set, query_set, *args, **kwargs):
67 |     """Computes the distance from each element of the query set to the class prototypes of the support set.
68 | 
69 |     Args:
70 |         support_set: tuple of (Tensor, is_labeled=True). The Tensor has shape
71 |             (batch, n_classes, n_samples_per_class, z_dim) and contains the representation z of each image.
72 |         query_set: tuple of (Tensor, is_labeled=False). The Tensor has shape
73 |             (batch, n_classes, n_queries_per_class, z_dim) and contains the representation z of each image.
74 | 
75 |     Returns:
76 |         Tensor of shape (batch, n_total_query, n_classes) containing the similarity between each pair of query,
77 |             prototypes, for each task. 
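
    Example (illustrative shapes only):
        >>> support = torch.randn(1, 5, 5, 64)   # (batch, n_classes, n_shots, z_dim)
        >>> query = torch.randn(1, 5, 15, 64)
        >>> prototype_distance((support, True), (query, False)).shape
        torch.Size([1, 75, 5])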
78 |     """
79 |     _support_set, _support_labels = support_set
80 |     _query_set, _query_labels = query_set
81 |     b, n, query_size, c = _query_set.size()
82 |     _support_set = _support_set.mean(2).view(b, 1, n, c)
83 |     _query_set = _query_set.view(b, n * query_size, 1, c)
84 |     d = _query_set - _support_set
85 |     return -torch.sum(d ** 2, 3) / np.sqrt(c)
86 | 
87 | 
88 | def gauss_distance(sample_set, query_set, unlabeled_set=None):
89 |     """ (experimental) function to try different approaches to model prototypes as gaussians
90 |     Args:
91 |         sample_set: features extracted from the sample set
92 |         query_set: features extracted from the query set
93 |         unlabeled_set: features extracted from the unlabeled set (currently unused)
94 | 
95 |     """
96 |     b, n, k, c = sample_set.size()
97 |     sample_set_std = sample_set.std(2).view(b, 1, n, c)
98 |     sample_set_mean = sample_set.mean(2).view(b, 1, n, c)
99 |     query_set = query_set.view(b, n * k, 1, c)
100 |     d = (query_set - sample_set_mean) / sample_set_std
101 |     return -torch.sum(d ** 2, 3) / np.sqrt(c)
102 | 
103 | 
104 | def _make_aligned_labels(inputs):
105 |     """Uses the shape of inputs to infer batch_size, n_classes, and n_sample_per_class. From this, we build the one-hot
106 |     encoding label tensor aligned with inputs. This is used to keep the label information when tensors are flattened
107 |     across the n_class and n_sample_per_class dimensions.
108 | 
109 |     Args:
110 |         inputs: tensor of shape (batch, n_classes, n_sample_per_class, z_dim) containing encoded examples for a task.
111 |     Returns:
112 |         tensor of shape (batch, n_classes, n_sample_pc, n_classes) containing the one-hot encoding label of each example
113 |     """
114 |     batch, n_classes, n_sample_pc, z_dim = inputs.shape
115 |     identity = torch.eye(n_classes, dtype=inputs.dtype, device=inputs.device)
116 |     return identity[None, :, None, :].expand(batch, -1, n_sample_pc, -1).contiguous()
117 | 
118 | 
119 | def generalized_pw_sq_dist(data, d_type="euclidean"):
120 |     batch, n_classes, n_samples, z_dim = data.size()
121 |     data = data.view(batch, -1, z_dim)
122 |     if d_type == "euclidean":
123 |         return torch.sum((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim)
124 |     elif d_type == "l1":
125 |         return torch.mean(torch.abs(data[:, :, None, :] - data[:, None, :, :]), dim=3)
126 |     elif d_type == "stable_euclidean":
127 |         return torch.sqrt(1e-6 + torch.mean((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim))
128 |     elif d_type == "cosine":
129 |         data = F.normalize(data, dim=2)
130 |         return torch.bmm(data, data.transpose(2, 1))  # note: this branch returns a similarity, not a distance
131 |     else:
132 |         raise ValueError("Distance type not recognized")
133 | 
134 | 
135 | def pw_sq_dist(sample_set, query_set, unlabeled_set=None, label_offset=0):
136 |     """Computes the squared pairwise distances between all examples of the sample and query sets.
137 | 
138 |     Args:
139 |         sample_set: Tensor of shape (batch, n_classes, n_sample_per_class, z_dim) containing the representation z of
140 |             each image.
141 |         query_set: Tensor of shape (batch, n_classes, n_query_per_class, z_dim) containing the representation z of
142 |             each image.
143 | 
144 |     Returns:
145 |         dist: Tensor of shape (batch, n_total, n_total) containing the squared distance between each pair.
146 |         labels: Tensor of shape (batch, n_total, n_classes) containing the one-hot vector for the sample set and zeros
147 |             for the query set.
148 |     """
149 |     batch, n_classes, _, z_dim = sample_set.shape
150 |     sample_labels = _make_aligned_labels(sample_set)
151 |     query_labels = _make_aligned_labels(query_set) * 0. 
+ label_offset # XXX: Set labels to a constant 152 | # TODO: it's a bit sketchy that this function is used to extract the labels. Should be done externally. 153 | labels = torch.cat((sample_labels, query_labels), dim=2).view(batch, -1, n_classes) 154 | samples = torch.cat((sample_set, query_set), dim=2).view(batch, -1, z_dim) 155 | return torch.sum((samples[:, :, None, :] - samples[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim), labels 156 | 157 | 158 | def _ravel_index(index, shape): 159 | shape0 = np.prod(shape[:-1]) 160 | shape1 = shape[-1] 161 | new_shape = list(shape[:-1]) + [1] 162 | offsets = (torch.arange(shape0, dtype=index.dtype, device=index.device) * shape1).view(new_shape) 163 | return (index + offsets).view(-1) 164 | 165 | 166 | def _learned_scale_adjustment(scale, mlp, sq_dist, batch, n_sample_set, n_query_set, n_classes): 167 | """ 168 | Learn a data-dependent adjustment of the scale factor used in the distance computation 169 | 170 | scale: float 171 | The original scale factor 172 | mlp: nn.Module 173 | The model used to learn the scale factor adjustment 174 | sq_dist: torch.Tensor, shape=(bs, n_total, n_total) 175 | A matrix of squared distances between all the examples 176 | n_sample_set: int 177 | Number of examples in the sample set 178 | n_query_set: int 179 | Number of examples in the query set 180 | n_classes: int 181 | Number of classes 182 | 183 | """ 184 | global_mean = sq_dist.view(batch, -1).mean(1, keepdim=True).repeat(batch, sq_dist.size(1)) 185 | global_std = sq_dist.view(batch, -1).std(1, keepdim=True).repeat(batch, sq_dist.size(1)) 186 | avg_mean = sq_dist.mean(2).mean(1, keepdim=True).repeat(batch, sq_dist.size(1)) 187 | avg_std = sq_dist.std(2).mean(1, keepdim=True).repeat(batch, sq_dist.size(1)) 188 | sample_distances = sq_dist[:, :, :(n_classes * n_sample_set)].clone().view(batch, -1, n_sample_set) 189 | class_mean = sample_distances.mean(2).mean(1, keepdim=True).repeat(batch, sq_dist.size(1)) 190 | class_std = sample_distances.std(2).mean(1, keepdim=True).repeat(batch, sq_dist.size(1)) 191 | sample_mean = sq_dist.mean(2) 192 | sample_std = sq_dist.std(2) 193 | n = torch.ones_like(sample_mean) * n_classes 194 | k = torch.ones_like(sample_mean) * n_sample_set 195 | q = torch.ones_like(sample_mean) * n_query_set 196 | vin = torch.stack( 197 | [global_mean, global_std, sample_mean, sample_std, avg_mean, avg_std, class_mean, class_std, n, k, q], 2) 198 | scales = mlp(vin.view(-1, 11)).view(sq_dist.size(0), sq_dist.size(1), 1) 199 | return scale + F.softplus(scales) 200 | 201 | 202 | def label_prop(sample_set, query_set, unlabeled_set=None, alpha=0.1, scale_factor=1, label_offset=0, apply_log=False, 203 | topk=0, method="global_consistancy", epsilon=1e-8, mlp=None, return_all=False, normalize_weights=False, 204 | debug_plot_path=None, propagator=None, labels=None, weights=None): 205 | """Uses the laplacian graph to smooth the labels matrix by "propagating" labels. 206 | 207 | Args: 208 | sample_set: Tensor of shape (batch, n_classes, n_sample_per_classes, z_dim) containing the representation z of 209 | each images. 210 | query_set: Tensor of shape (batch, n_classes, n_query_per_classes, z_dim) containing the representation z of 211 | each images. 212 | unlabeled_set: Tensor of shape (batch, n_classes, n_unlabeled_per_classes, z_dim) containing the representation 213 | z of each images. 214 | alpha: Smoothing factor in the laplacian graph 215 | scale_factor: scale modifying the euclidean distance before exponential kernel. 
216 |         label_offset: Applies an offset to the labels before propagating. Has an effect on the degree of uncertainty
217 |             when apply_log is True.
218 |         apply_log: if True, it is assumed that the label propagation method returns un-normalized probabilities. Hence,
219 |             to return logits, applying the logarithm is necessary.
220 |         topk: limit the weight matrix to the topk most similar examples
221 |         method: "regularized_laplacian" or "global_consistancy".
222 |         epsilon: small value used when apply_log is True.
223 |         mlp: If 1, it trains an MLP to predict the scaling factor from weight matrix stats. If 2, it does so without
224 |             backpropagating into the main architecture.
225 |         return_all: For debugging purposes.
226 | 
227 |     Returns:
228 |         Tensor of shape (batch, n_total_query, n_classes) representing the logits for each class
229 | 
230 |     """
231 |     init_query_size = query_set.size()[2]
232 |     if unlabeled_set is not None:
233 |         init_unlabel_size = unlabeled_set.size()[2]
234 |         query_set = torch.cat((unlabeled_set, query_set), dim=2)  # XXX: Concat order is important to logit extraction
235 | 
236 |     # Get data shape
237 |     batch, n_classes, n_query_pc, z_dim = query_set.size()
238 |     batch, n_classes, n_sample_pc, z_dim = sample_set.size()
239 | 
240 |     if propagator is None:
241 |         # Compute the pairwise distance between the examples of the sample and query sets
242 |         # XXX: labels are set to a constant for the query set
243 |         sq_dist, labels = pw_sq_dist(sample_set, query_set, label_offset=label_offset)  # keyword arg: the third positional slot is unlabeled_set
244 | 
245 |         # Learn to adjust the scale
246 |         if mlp is not None:
247 |             scale_factor = _learned_scale_adjustment(scale_factor, mlp, sq_dist, batch, n_sample_pc, n_query_pc,
248 |                                                      n_classes)
249 | 
250 |         # Compute similarity between the examples -- inversely proportional to distance
251 |         weights = torch.exp(-0.5 * sq_dist / scale_factor ** 2)
252 | 
253 |         # Discard similarity for examples other than the top k most similar
254 |         if topk > 0:
255 |             weights, ind = torch.sort(weights, dim=2)  # ascending order
256 |             weights[:, :, :-int(topk + 1)] *= 0
257 |             weights = weights.scatter(2, ind, weights)
258 | 
259 |         # Normalize the weights
260 |         if normalize_weights:
261 |             weights = weights / torch.sum(weights, dim=2, keepdim=True)
262 | 
263 |         if (method == "regularized_laplacian") or (method is None):
264 |             logits, propagator = regularized_laplacian(weights, labels, alpha=alpha)
265 |         elif method == "global_consistancy":
266 |             logits, propagator = global_consistency(weights, labels, alpha=alpha)
267 |         else:
268 |             raise Exception("Unknown method %s." % method)
269 |     else:
270 |         logits = _propagate(labels, propagator)
271 | 
272 |     if debug_plot_path is not None:
273 |         print("Sample set:", n_sample_pc, " Query set:", init_query_size,
274 |               " unlabeled set:", 0 if unlabeled_set is None else init_unlabel_size)
275 |         # XXX: Only saves the first batch element
276 |         np.save(debug_plot_path + "_weights.npy", weights[0].detach().cpu().numpy())
277 |         np.save(debug_plot_path + "_propagator.npy", propagator[0].detach().cpu().numpy())
278 |         plt_labels = \
279 |             np.argmax(torch.cat((_make_aligned_labels(sample_set),
280 |                                  _make_aligned_labels(query_set)), dim=2).view(batch, -1, n_classes).cpu().numpy(),
281 |                       axis=-1)[0]
282 |         np.save(debug_plot_path + "_labels.npy", plt_labels)
283 | 
284 |     if apply_log:
285 |         logits = torch.log(logits + epsilon)
286 | 
287 |     logits = logits.reshape(batch, n_classes, -1, n_classes)
288 |     if return_all:
289 |         # Extracts only the logits for the query set
290 |         # TODO: verify that this works as expected. 
<<<------- *****
291 |         query_labels = _make_aligned_labels(query_set)
292 |         query_logits = logits[:, :, -init_query_size:, :].reshape(batch, -1, n_classes)
293 |         if unlabeled_set is not None:
294 |             unlabeled_labels = query_labels[:, :, :init_unlabel_size, :]
295 |             query_labels = query_labels[:, :, init_unlabel_size:, :]
296 |             unlabel_logits = logits[:, :, :init_unlabel_size, :].reshape(batch, -1, n_classes)
297 |         else:
298 |             unlabel_logits = None
299 |             unlabeled_labels = None
300 |         return query_logits, unlabel_logits, labels, weights, _make_aligned_labels(
301 |             sample_set), query_labels, unlabeled_labels, propagator
302 |     else:
303 |         # Extracts only the logits for the query set
304 |         # TODO: verify that this works as expected. <<<------- *****
305 |         logits = logits.reshape(batch, n_classes, -1, n_classes)[:, :, -init_query_size:, :].reshape(batch, -1,
306 |                                                                                                      n_classes)
307 |         return logits
308 | 
309 | def make_one_hot_labels(labels, label_offset=0.):
310 |     n_classes = labels.max() + 1
311 |     assert (n_classes.item() == 5)
312 |     to_one_hot = torch.eye(n_classes, device=labels.device, dtype=torch.float)
313 |     one_hot_labels = to_one_hot[labels]
314 |     mask = labels == -1
315 |     one_hot_labels[mask, :] = label_offset
316 |     return one_hot_labels
317 | 
318 | 
433 | def standarized_label_prop(*sets, gaussian_scale=1, alpha=1,
434 |                            propagator=None, weights=None, return_all=False,
435 |                            apply_log=False, scale_bound="", standarize="", kernel="", square_root=False,
436 |                            offset=0, epsilon=1e-6, dropout=0, n_classes=5):
437 |     if scale_bound == "softplus":
438 |         gaussian_scale = 0.01 + F.softplus(gaussian_scale)
439 |         alpha = 0.1 + F.softplus(alpha)
440 |     elif scale_bound == "square":
441 |         gaussian_scale = 1e-4 + gaussian_scale ** 2
442 |         alpha = 0.1 + alpha ** 2
443 |     elif scale_bound == "convex_relu":
444 |         # gaussian_scale = gaussian_scale ** 2
445 |         alpha = F.relu(alpha) + 0.1
446 |     elif scale_bound == "convex_square":
447 |         # gaussian_scale = gaussian_scale ** 2
448 |         alpha = 0.1 + alpha ** 2
449 |     elif scale_bound == "relu":
450 |         gaussian_scale = F.relu(gaussian_scale) + 0.01
451 |         alpha = F.relu(alpha) + 0.1
452 |     elif scale_bound == "constant":
453 |         gaussian_scale = 1
454 |         alpha = 1
455 |     elif scale_bound == "alpha_square":
456 |         alpha = 0.1 + F.relu(alpha)
457 | 
458 |     samples = []
459 |     labels = []
460 |     offsets = []
461 |     is_labeled = []
462 |     for data, _labels in sets:
463 |         b, _, sample_size, c = data.size()
464 |         n = n_classes
465 |         samples.append(data)
466 |         offsets.append(sample_size)
467 |         if isinstance(_labels, bool):
468 |             labels.append(_make_aligned_labels(data))
469 |             if not (_labels):
470 |                 labels[-1].data[:] = labels[-1].data * 0 + offset
471 |             is_labeled.append(_labels)
472 |         else:
473 |             labels.append(_labels)
474 |             is_labeled.append((_labels.cpu().numpy() > 0).any())
475 |     samples = torch.cat(samples, dim=2)
476 |     labels = torch.cat(labels, dim=2).view(b, -1, n)
477 | 
478 |     if propagator is None:
479 |         # Compute the pairwise distance between the examples of the sample and query sets
480 |         # XXX: labels are set to a constant for the query set
481 |         sq_dist = generalized_pw_sq_dist(samples, "euclidean")
482 |         if square_root:
483 |             sq_dist = (sq_dist + epsilon).sqrt()
484 |         if standarize == "all":
485 |             mask = sq_dist != 0
486 |             # sq_dist = sq_dist - sq_dist[mask].mean()
487 |             sq_dist = sq_dist / sq_dist[mask].std()
488 |         elif standarize == "median":
489 |             mask = sq_dist != 0
490 |             gaussian_scale = torch.sqrt(
491 |                 0.5 * torch.median(sq_dist[mask]) / torch.log(torch.ones(1, device=sq_dist.device) + sq_dist.size(1)))
492 |         elif standarize == "frobenius":
493 |             mask = sq_dist != 0
494 |             sq_dist = sq_dist / (sq_dist[mask] ** 2).sum().sqrt()
495 |         elif standarize == "percentile":
496 |             mask = sq_dist != 2
497 |             sorted, indices = torch.sort(sq_dist.data[mask])
498 |             total = sorted.size(0)
499 |             gaussian_scale = sorted[int(total * 0.1)].detach()
500 |         if kernel == "rbf":
501 |             weights = torch.exp(-sq_dist * gaussian_scale)
502 |         elif kernel == "convex_rbf":
503 |             scales = torch.linspace(0.1, 10, gaussian_scale.size(0), device=sq_dist.device, dtype=sq_dist.dtype)
504 |             weights = torch.exp(-sq_dist.unsqueeze(1) * scales.view(1, -1, 1, 1))
505 |             weights = (weights * F.softmax(gaussian_scale.view(1, -1, 1, 1), dim=1)).sum(1)
506 |             # checknan(timessoftmax=weights)
507 |         elif kernel == "euclidean":
508 |             # Compute similarity between the examples -- inversely proportional to distance
509 |             weights = 1 / (gaussian_scale + sq_dist)
510 |         elif kernel == "softmax":
511 |             weights = F.softmax(-sq_dist / gaussian_scale, -1)
512 | 
513 |         mask = (torch.eye(weights.size(1), dtype=weights.dtype, device=weights.device)[None, :, :]).view(1, weights.size(1), weights.size(2)).expand(weights.size(0), 
-1, -1) 514 | weights = weights * (1 - mask) 515 | # checknan(masking=weights) 516 | 517 | 518 | logits, propagator = global_consistency(weights, labels, alpha=alpha) 519 | else: 520 | logits = _propagate(labels, propagator) 521 | 522 | if apply_log: 523 | logits = torch.log(logits + epsilon) 524 | 525 | logits = logits.view(b, n, -1, n) 526 | logits_ret = [] 527 | start = 0 528 | for i, offset in enumerate(offsets): 529 | logits_ret.append(logits[:, :, start:(start + offset), :].contiguous().view(b, -1, n)) 530 | start += offset 531 | if return_all: 532 | # Extracts only the logits for the query set 533 | # TODO: verify that this works as expected. <<<------- ***** 534 | labels = labels.view(b, n, -1, n) 535 | labels_ret = [] 536 | start = 0 537 | for i, offset in enumerate(offsets): 538 | labels_ret.append(labels[:, :, start:(start + offset), :].contiguous().view(b, -1, n)) 539 | start += offset 540 | return tuple(logits_ret + labels_ret + [weights, propagator, labels]) 541 | else: 542 | # Extracts only the logits for the query set 543 | # TODO: verify that this works as expected. <<<------- ***** 544 | return tuple(logits_ret) 545 | 546 | 547 | def regularized_laplacian(weights, labels, alpha): 548 | """Uses the laplacian graph to smooth the labels matrix by "propagating" labels 549 | 550 | Args: 551 | weights: Tensor of shape (batch, n, n) 552 | labels: Tensor of shape (batch, n, n_classes) 553 | alpha: Scaler, acts as a smoothing factor 554 | apply_log: if True, it is assumed that the label propagation methods returns un-normalized probabilities. Hence 555 | to return logits, applying logarithm is necessary. 556 | epsilon: value added before applying log 557 | Returns: 558 | Tensor of shape (batch, n, n_classes) representing the logits of each classes 559 | """ 560 | n = weights.shape[1] 561 | diag = torch.diag_embed(torch.sum(weights, dim=2)) 562 | laplacian = diag - weights 563 | identity = torch.eye(n, dtype=laplacian.dtype, device=laplacian.device)[None, :, :] 564 | propagator = torch.inverse(identity + alpha * laplacian) 565 | 566 | return _propagate(labels, propagator), propagator 567 | 568 | 569 | def global_consistency(weights, labels, alpha=0.1): 570 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in TPN paper but without bug) 571 | 572 | Args: 573 | weights: Tensor of shape (batch, n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and 574 | s the scale parameter. 575 | labels: Tensor of shape (batch, n, n_classes) 576 | alpha: Scaler, acts as a smoothing factor 577 | Returns: 578 | Tensor of shape (batch, n, n_classes) representing the logits of each classes 579 | """ 580 | alpha_ = 1 / (1 + alpha) 581 | beta_ = alpha / (1 + alpha) 582 | n = weights.shape[1] 583 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device)[None, :, :] 584 | #weights = weights * (1. - identity) # zero out diagonal 585 | isqrt_diag = 1. 
/ torch.sqrt(1e-4 + torch.sum(weights, dim=2)) 586 | # checknan(laplacian=isqrt_diag) 587 | S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, :, None] 588 | # checknan(normalizedlaplacian=S) 589 | propagator = identity - alpha_ * S 590 | propagator = torch.inverse(propagator) * beta_ 591 | # checknan(propagator=propagator) 592 | 593 | return _propagate(labels, propagator, scaling=1), propagator 594 | 595 | 596 | def _propagate(labels, propagator, scaling=1.): 597 | return torch.matmul(propagator, labels) * scaling 598 | 599 | 600 | class MLP(torch.nn.Module): 601 | def __init__(self, detach=False): 602 | super().__init__() 603 | self.__detach = detach 604 | self.linear1 = torch.nn.Linear(11, 128) 605 | self.linear2 = torch.nn.Linear(128, 16) 606 | self.linear3 = torch.nn.Linear(16, 1) 607 | 608 | def forward(self, x): 609 | if self.__detach: 610 | x = x.detach() 611 | x = F.relu(self.linear1(x), True) 612 | x = F.relu(self.linear2(x), True) 613 | return self.linear3(x) 614 | 615 | class Distance(torch.nn.Module): 616 | def __init__(self, exp_params): 617 | """ Helper to obtain a distance function from a string 618 | Args: 619 | distance_type: string indicating the distance type 620 | """ 621 | super().__init__() 622 | self.exp_params = exp_params 623 | self.distance_type, *args = exp_params["distance_type"].split(',') 624 | if self.distance_type in ["euclidean", "prototypical"]: 625 | self.d = prototype_distance 626 | elif self.distance_type == "labelprop": 627 | self.register_buffer("moving_alpha", torch.ones(1) * self.exp_params["labelprop_alpha_prior"]) 628 | if self.exp_params["kernel_type"] == "convex_rbf": 629 | scale_size = 10 630 | else: 631 | scale_size = 1 632 | self.register_buffer("moving_gaussian_scale", torch.ones(scale_size) * exp_params["labelprop_scale_prior"]) 633 | self.d = partial(standarized_label_prop, scale_bound=self.exp_params["kernel_bound"], 634 | kernel=self.exp_params["kernel_type"], 635 | standarize=self.exp_params["kernel_standarization"], 636 | square_root=self.exp_params["kernel_square_root"], 637 | apply_log=True) 638 | elif self.distance_type == "matching": 639 | matching_distance, = args 640 | self.d = partial(matching_nets, distance_type=matching_distance) 641 | elif self.distance_type == "labelprop_boris": 642 | """ 643 | TODO(@boris) define here the relation network, and define a new "label_prop" above, alike the original 644 | one parameters inside Distance are taken into account by the optimizer 645 | """ 646 | self.d = partial(label_prop, apply_log=True) 647 | 648 | def forward(self, *sets, **kwargs): 649 | return self.d(*sets, **kwargs) -------------------------------------------------------------------------------- /src/models/base_ssl/oracle.py: -------------------------------------------------------------------------------- 1 | from . 
import utils as ut 2 | import numpy as np 3 | import torch 4 | import sys 5 | import copy 6 | from scipy.stats import pearsonr 7 | import h5py 8 | import os 9 | import numpy as np 10 | import pylab 11 | 12 | from sklearn.cluster import KMeans 13 | # from ipywidgets import interact, interactive, fixed, interact_manual 14 | # import ipywidgets as widgets 15 | import torch 16 | import torch.nn.functional as F 17 | from functools import partial 18 | import json 19 | from skimage.io import imsave 20 | import tqdm 21 | import pprint 22 | 23 | import torch 24 | import sys 25 | sys.path.insert(0, os.path.abspath('..')) 26 | from .distances import prototype_distance # here we import the labelpropagation algorithm inside the "Distance" class 27 | import pandas 28 | import json 29 | 30 | # loading data 31 | # from trainers.few_shot_parallel_alpha_scale import inner_loop_lbfgs2, inner_loop_lbfgs_bootstrap 32 | from torch.utils.data import DataLoader 33 | 34 | class Sampler(object): 35 | """ 36 | Samples few shot tasks from precomputed embeddings 37 | """ 38 | def __init__(self, embeddings_fname, n_classes, distract_flag): 39 | self.h5fp = h5py.File(embeddings_fname, 'r') 40 | self.labels = self.h5fp["test_targets"][...] 41 | indices = np.arange(self.labels.shape[0]) 42 | self.label_indices = {i: indices[self.labels == i] for i in set(self.labels)} 43 | self.nclasses = len(self.label_indices.keys()) 44 | self.n_classes = n_classes 45 | self.distract_flag = distract_flag 46 | 47 | def sample_episode_indices(self, support_size, 48 | query_size, unlabeled_size, ways): 49 | """ 50 | Returns the indices of the images of a random episode with predefined support, query and unlabeled sizes. 51 | the number of images is expressed in "ways" 52 | """ 53 | label_indices = {k: np.random.permutation(v) for k,v in self.label_indices.items()} 54 | #label_indices = self.label_indices 55 | 56 | if self.distract_flag: 57 | classes = np.random.permutation(self.nclasses) 58 | distract_classes = classes[ways:(ways+ways)] 59 | classes = classes[:ways] 60 | else: 61 | classes = np.random.permutation(self.nclasses)[:ways] 62 | 63 | support_indices = [] 64 | query_indices = [] 65 | unlabel_indices = [] 66 | 67 | for cls in classes: 68 | start = 0 69 | end = support_size 70 | support_indices.append(label_indices[cls][start:end]) 71 | start = end 72 | end += query_size 73 | query_indices.append(label_indices[cls][start:end]) 74 | start = end 75 | end += unlabeled_size 76 | assert(end < len(label_indices[cls])) 77 | unlabel_indices.append(label_indices[cls][start:end]) 78 | 79 | if self.distract_flag: 80 | for cls in distract_classes: 81 | unlabel_indices.append(label_indices[cls][:unlabeled_size]) 82 | 83 | return np.vstack(support_indices), np.vstack(query_indices), np.vstack(unlabel_indices) 84 | 85 | def _sample_field(self, field, *indices): 86 | features = self.h5fp["test_{}".format(field)] 87 | ret = [] 88 | for _indices in indices: 89 | _indices = _indices.ravel() 90 | argind = np.argsort(_indices) 91 | if len(argind) == 0: 92 | ret.append(None) 93 | else: 94 | ind = _indices[argind] 95 | dset = features[ind.tolist()] 96 | dset[argind] = dset.copy() 97 | ret.append(dset) 98 | 99 | return tuple(ret) 100 | 101 | def sample_features(self, support_indices, query_indices, unlabel_indices): 102 | return self._sample_field("features", support_indices, query_indices, unlabel_indices) 103 | 104 | def sample_labels(self, support_indices, query_indices, unlabel_indices): 105 | return self._sample_field("targets", support_indices, 
query_indices, unlabel_indices) 106 | 107 | def sample_episode(self, support_size, query_size, unlabeled_size, apply_ten_flag=False): 108 | """ 109 | Randomly samples an episode (features and labels) given the size of each set and the number of classes 110 | 111 | Returns: tuple(numpy array). Sets are of size (set_size * nclasses, 512), where 512 is the number 112 | of channels of the embeddings 113 | """ 114 | ways = self.n_classes 115 | support_indices, query_indices, unlabel_indices = self.sample_episode_indices(support_size, query_size, unlabeled_size, ways) 116 | support_set, query_set, unlabel_set = self.sample_features(support_indices, 117 | query_indices, 118 | unlabel_indices) 119 | support_labels = ut.make_labels(support_size, ways) 120 | query_labels = ut.make_labels(query_size, ways) 121 | unlabel_labels = ut.make_labels(unlabeled_size, ways) 122 | 123 | episode_dict = episode2dict(support_set, query_set, unlabel_set, support_labels, query_labels, unlabel_labels) 124 | 125 | if apply_ten_flag: 126 | episode_dict = ut.apply_ten_on_episode(episode_dict) 127 | 128 | return episode_dict 129 | 130 | def episode2dict(support_set, query_set, unlabel_set, support_labels, query_labels, unlabel_labels): 131 | n_classes = support_set.shape[0] 132 | 133 | support_dict = {"samples": support_set, "labels":support_labels} 134 | query_dict = {"samples": query_set, "labels":query_labels} 135 | unlabeled_dict = {"samples": unlabel_set, "labels":unlabel_labels} 136 | 137 | return {"support":support_dict, 138 | "query":query_dict, 139 | "unlabeled":unlabeled_dict} 140 | 141 | 142 | 143 | def compute_acc(pred_labels, true_labels): 144 | acc = (true_labels.flatten() == pred_labels.ravel()).astype(float).mean() 145 | 146 | return acc 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /src/models/base_ssl/predict_methods/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import label_prop as lp 3 | from . import prototypical 4 | from . import adaptive 5 | 6 | def get_predictions(predict_method, episode_dict): 7 | if predict_method == "labelprop": 8 | return lp.label_prop_predict(episode_dict) 9 | 10 | elif predict_method == "prototypical": 11 | return prototypical.prototypical_predict(episode_dict) 12 | elif predict_method == "adaptive": 13 | return adaptive.adaptive_predict(episode_dict) 14 | else: 15 | raise ValueError("Prediction method not found") -------------------------------------------------------------------------------- /src/models/base_ssl/predict_methods/adaptive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | # from models.wide_resnet_imagenet import WideResnetImagenet, Group 7 | # from modules.dynamic_residual_groups import get_group 8 | # from modules.ten import TEN 9 | # from modules.distances import Distance, _propagate, standarized_label_prop 10 | # from modules.output_heads import get_output_head 11 | # from modules.activations import get_activation 12 | # from modules.layers import MetricLinear 13 | from .
import label_prop as lp 14 | 15 | 16 | def adaptive_predict(episode_dict, double_flag=False): 17 | # Get variables 18 | S = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda() 19 | S_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda() 20 | Q = torch.from_numpy(episode_dict["query"]["samples"]).cuda() 21 | U = torch.from_numpy(episode_dict["unlabeled"]["samples"]).cuda() 22 | 23 | SUQ = torch.cat([S, U, Q])[None] 24 | Q_labels = -1*torch.ones(Q.shape[0], device=S_labels.device, dtype=S_labels.dtype).cuda() 25 | U_labels = -1*torch.ones(U.shape[0], device=S_labels.device, dtype=S_labels.dtype).cuda() 26 | 27 | SUQ_labels = torch.cat([S_labels, U_labels, Q_labels]) 28 | # Init Adaptive 29 | adaptive = AdaptiveTenInner(output_size=S.shape[1], 30 | nclasses=5, 31 | double_flag=double_flag).cuda() 32 | 33 | # Apply Adaptive 34 | # adaptive.train_step() 35 | UQ_logits = adaptive.forward(x=SUQ, 36 | support_size=S.shape[0], 37 | query_size=U.shape[0] + Q.shape[0], 38 | labels=SUQ_labels) 39 | 40 | return UQ_logits[-Q.shape[0]:].argmax(dim=1) 41 | 42 | 43 | class AdaptiveTenInner(torch.nn.Module): 44 | def __init__(self, output_size, nclasses, double_flag, mu_init=0, scale_init=0, precision_init=0, **exp_dict): 45 | super().__init__() 46 | self.nclasses = nclasses 47 | self.precision = torch.nn.Parameter(torch.randn(1, 1, output_size) * 0.1) 48 | self.classifier = torch.nn.Linear(output_size, nclasses) 49 | self.exp_dict = exp_dict 50 | self.optimizer = torch.optim.LBFGS([self.precision] + list(self.classifier.parameters()), tolerance_grad=1e-5, tolerance_change=1e-5, lr=0.1) 51 | if double_flag: 52 | self.label_prop = lp.LabelpropDouble() 53 | else: 54 | self.label_prop = lp.Labelprop() 55 | 56 | def train_step(self, _x, support_size, query_size, labels): 57 | x = _x.clone() 58 | self.optimizer.zero_grad() 59 | b, k, c = x.size() 60 | x = x * torch.sigmoid(1 + self.precision) 61 | 62 | zeros = torch.zeros(1, k, self.nclasses, device=x.device) 63 | 64 | logits, propagator = standarized_label_prop(x, zeros, 65 | 1, 1, 66 | apply_log=True, scale_bound="", 67 | standarize="all", kernel="rbf") 68 | x = _propagate(x, propagator) 69 | # support_set = x.view(-1, self.nclasses, c)[:support_size, ...].view(-1, c) 70 | support_set = x.view(-1, c)[:support_size, ...] 71 | support_labels = labels.view(-1, self.nclasses)[:support_size, ...].view(-1) 72 | logits = self.classifier(support_set) 73 | loss = F.cross_entropy(logits, support_labels) + 0.0001 * (self.precision ** 2).mean() 74 | loss.backward() 75 | return loss 76 | 77 | def forward(self, x, support_size, query_size, labels): 78 | self.train() 79 | self.optimizer.step(lambda: self.train_step(x, support_size, query_size, labels)) 80 | # self.train_step(x, support_size, query_size, labels) 81 | with torch.no_grad(): 82 | # mu, scale, precision = (0.5 + self.mu **2).detach(), (0.5 + self.scale ** 2).detach(), (1 + self.precision.detach()) 83 | mu, scale, precision = 1, 1, self.precision.detach() 84 | # mu, scale, precision = 1, 1, 1 85 | x = x * torch.sigmoid(1 + precision) 86 | one_hot_labels = F.one_hot(labels.view(-1)).float() 87 | one_hot_labels = one_hot_labels.view(1, 88 | support_size + query_size, 89 | self.nclasses) 90 | one_hot_labels[:, support_size:, ...] 
= 0 91 | 92 | logits, propagator = standarized_label_prop(x, one_hot_labels, 93 | scale, mu, apply_log=True, 94 | scale_bound="", standarize="all", 95 | kernel="rbf") 96 | x = _propagate(x, propagator) 97 | logits, propagator = standarized_label_prop(x, one_hot_labels, 98 | scale, mu, apply_log=True, 99 | scale_bound="", standarize="all", 100 | kernel="rbf") 101 | logits = logits.view((support_size + query_size), self.nclasses) 102 | return logits[support_size:].view(-1, self.nclasses) 103 | 104 | 105 | def standarized_label_prop(embeddings, 106 | labels, 107 | gaussian_scale=1, alpha=1, 108 | weights=None, 109 | apply_log=False, scale_bound="", standarize="", kernel="", square_root=False, 110 | epsilon=1e-6): 111 | if scale_bound == "softplus": 112 | gaussian_scale = 0.01 + F.softplus(gaussian_scale) 113 | alpha = 0.1 + F.softplus(alpha) 114 | elif scale_bound == "square": 115 | gaussian_scale = 1e-4 + gaussian_scale ** 2 116 | alpha = 0.1 + alpha ** 2 117 | elif scale_bound == "convex_relu": 118 | #gaussian_scale = gaussian_scale ** 2 119 | alpha = F.relu(alpha) + 0.1 120 | elif scale_bound == "convex_square": 121 | # gaussian_scale = gaussian_scale ** 2 122 | alpha = 0.1 + alpha ** 2 123 | elif scale_bound == "relu": 124 | gaussian_scale = F.relu(gaussian_scale) + 0.01 125 | alpha = F.relu(alpha) + 0.1 126 | elif scale_bound == "constant": 127 | gaussian_scale = 1 128 | alpha = 1 129 | elif scale_bound == "alpha_square": 130 | alpha = 0.1 + F.relu(alpha) 131 | 132 | # Compute the pairwise distance between the examples of the sample and query sets 133 | # XXX: labels are set to a constant for the query set 134 | sq_dist = generalized_pw_sq_dist(embeddings, "euclidean") 135 | if square_root: 136 | sq_dist = (sq_dist + epsilon).sqrt() 137 | if standarize == "all": 138 | mask = sq_dist != 0 139 | # sq_dist = sq_dist - sq_dist[mask].mean() 140 | sq_dist = sq_dist / sq_dist[mask].std() 141 | elif standarize == "median": 142 | mask = sq_dist != 0 143 | gaussian_scale = torch.sqrt( 144 | 0.5 * torch.median(sq_dist[mask]) / torch.log(torch.ones(1, device=sq_dist.device) + sq_dist.size(1))) 145 | elif standarize == "frobenius": 146 | mask = sq_dist != 0 147 | sq_dist = sq_dist / (sq_dist[mask] ** 2).sum().sqrt() 148 | elif standarize == "percentile": 149 | mask = sq_dist != 2 150 | sorted, indices = torch.sort(sq_dist.data[mask]) 151 | total = sorted.size(0) 152 | gaussian_scale = sorted[int(total * 0.1)].detach() 153 | 154 | if kernel == "rbf": 155 | weights = torch.exp(-sq_dist * gaussian_scale) 156 | elif kernel == "convex_rbf": 157 | scales = torch.linspace(0.1, 10, gaussian_scale.size(0), device=sq_dist.device, dtype=sq_dist.dtype) 158 | weights = torch.exp(-sq_dist.unsqueeze(1) * scales.view(1, -1, 1, 1)) 159 | weights = (weights * F.softmax(gaussian_scale.view(1, -1, 1, 1), dim=1)).sum(1) 160 | # checknan(timessoftmax=weights) 161 | elif kernel == "euclidean": 162 | # Compute similarity between the examples -- inversely proportional to distance 163 | weights = 1 / (gaussian_scale + sq_dist) 164 | elif kernel == "softmax": 165 | weights = F.softmax(-sq_dist / gaussian_scale, -1) 166 | 167 | mask = torch.eye(weights.size(1), dtype=torch.bool, device=weights.device)[None, :, :] 168 | weights = weights * (~mask).float() 169 | 170 | logits, propagator = global_consistency(weights, labels, alpha=alpha) 171 | 172 | if apply_log: 173 | logits = torch.log(logits + epsilon) 174 | 175 | return logits, propagator 176 | 177 | def generalized_pw_sq_dist(data, d_type="euclidean"): 178 | batch, samples, 
z_dim = data.size() 179 | if d_type == "euclidean": 180 | return torch.sum((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim) 181 | elif d_type == "l1": 182 | return torch.mean(torch.abs(data[:, :, None, :] - data[:, None, :, :]), dim=3) 183 | elif d_type == "stable_euclidean": 184 | return torch.sqrt(1e-6 + torch.mean((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim)) 185 | elif d_type == "cosine": 186 | data = F.normalize(data, dim=2) 187 | return torch.bmm(data, data.transpose(2, 1)) 188 | else: 189 | raise ValueError("Distance type not recognized") 190 | 191 | def global_consistency(weights, labels, alpha=0.1): 192 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in the TPN paper but without the bug) 193 | Args: 194 | weights: Tensor of shape (batch, n, n). Expected to be exp( -d^2/s^2 ), where d is the Euclidean distance and 195 | s the scale parameter. 196 | labels: Tensor of shape (batch, n, n_classes) 197 | alpha: Scalar, acts as a smoothing factor 198 | Returns: 199 | Tensor of shape (batch, n, n_classes) representing the logits for each class 200 | """ 201 | n = weights.shape[1] 202 | _alpha = 1 / (1 + alpha) 203 | _beta = alpha / (1 + alpha) 204 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device)[None, :, :] 205 | #weights = weights * (1. - identity) # zero out diagonal 206 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=2)) 207 | # checknan(laplacian=isqrt_diag) 208 | S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, :, None] 209 | # checknan(normalizedlaplacian=S) 210 | propagator = identity - _alpha * S 211 | propagator = torch.inverse(propagator) * _beta 212 | # checknan(propagator=propagator) 213 | 214 | return _propagate(labels, propagator, scaling=1), propagator 215 | 216 | 217 | def _propagate(labels, propagator, scaling=1.): 218 | return torch.matmul(propagator, labels) * scaling -------------------------------------------------------------------------------- /src/models/base_ssl/predict_methods/label_prop.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | # import seaborn as sns 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import warnings 8 | from functools import partial 9 | # from tools.meters import BasicMeter 10 | # from train import TensorLogger 11 | import sys 12 | from embedding_propagation import LabelPropagation 13 | 14 | def label_prop_predict(episode_dict): 15 | S = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda() 16 | S_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda() 17 | nclasses = int(S_labels.max() + 1) 18 | Q_labels = torch.zeros(episode_dict["query"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses 19 | U_labels = torch.zeros(episode_dict["unlabeled"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses 20 | A_labels = torch.cat([S_labels, Q_labels, U_labels], 0) 21 | Q = torch.from_numpy(episode_dict["query"]["samples"]).cuda() 22 | 23 | U = torch.from_numpy(episode_dict["unlabeled"]["samples"]).cuda() 24 | 25 | lp = LabelPropagation(balanced=True) 26 | 27 | SUQ = torch.cat([S, U, Q], dim=0) 28 | logits = lp(SUQ, A_labels, nclasses) 29 | logits_query = logits[-Q.shape[0]:] 30 | 31 | return logits_query.argmax(dim=1).cpu().numpy() -------------------------------------------------------------------------------- /src/models/base_ssl/predict_methods/prototypical.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def prototypical_predict(episode_dict): 6 | support_samples = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda() 7 | support_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda() 8 | query_samples = torch.from_numpy(episode_dict["query"]["samples"]).cuda() 9 | 10 | logits = prototype_distance(support_set=support_samples, 11 | query_set=query_samples, 12 | labels=support_labels, 13 | unlabeled_set=None) 14 | return logits.argmax(dim=1) 15 | 16 | 17 | 18 | 19 | def prototype_distance(support_set, query_set, labels, unlabeled_set=None): 20 | """Computes the distance from each element of the query set to the class prototypes built from the support set. 21 | Args: 22 | support_set: Tensor of shape (n_support, z_dim) containing the representation z of 23 | each support image. 24 | query_set: Tensor of shape (n_queries, z_dim) containing the representation z of 25 | each query image. 26 | labels: Long tensor of shape (n_support,) with the class of each support image 27 | unlabeled_set: unused, kept for interface compatibility with the other 28 | predict methods. 29 | Returns: 30 | Tensor of shape (n_queries, n_classes) containing the negative squared distance between 31 | each query and each prototype. 32 | """ 33 | n_queries, channels = query_set.size() 34 | n_support, channels = support_set.size() 35 | 36 | support_set = support_set.view(n_support, 1, channels) 37 | 38 | way = int(labels.data.max()) + 1 39 | one_hot_labels = torch.zeros(n_support, way, 1, dtype=support_set.dtype, device=support_set.device) 40 | one_hot_labels.scatter_(1, labels.view(n_support, 1, 1), 1) 41 | 42 | total_per_class = one_hot_labels.sum(0, keepdim=True) 43 | prototypes = (support_set * one_hot_labels).sum(0) / total_per_class 44 | prototypes = prototypes.view(1, way, channels) 45 | 46 | query_set = query_set.view(n_queries, 1, channels) 47 | d = query_set - prototypes 48 | return -torch.sum(d ** 2, 2) / np.sqrt(channels) -------------------------------------------------------------------------------- /src/models/base_ssl/selection_methods/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ssl 2 | import numpy as np 3 | 4 | def get_indices(selection_method, episode_dict, support_size_max=None): 5 | # random 6 | if selection_method == "random": 7 | ind = np.random.choice(episode_dict["unlabeled"]["samples"].shape[0], 1, replace=False) 8 | 9 | # random imbalanced 10 | if selection_method == "random_imbalanced": 11 | ind = np.random.choice(episode_dict["unlabeled"]["samples"].shape[0], 1, replace=False) 12 | 13 | # ssl 14 | if selection_method == "ssl": 15 | ind = ssl.ssl_get_next_best_indices(episode_dict, support_size_max=support_size_max) 16 | 17 | 18 | # episode_dict["selected_indices"] = ind 19 | return ind 20 | -------------------------------------------------------------------------------- /src/models/base_ssl/selection_methods/ssl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from ..predict_methods import label_prop as lp 4 | import numpy as np 5 | from ..
import utils as ut 6 | from scipy import stats 7 | from embedding_propagation import LabelPropagation 8 | 9 | def ssl_get_next_best_indices(episode_dict, support_size_max=None): 10 | S = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda() 11 | Q = torch.from_numpy(episode_dict["query"]["samples"]).cuda() 12 | U = torch.from_numpy(episode_dict["unlabeled"]["samples"]).cuda() 13 | S_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda() 14 | nclasses = int(S_labels.max() + 1) 15 | Q_labels = torch.zeros(episode_dict["query"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses 16 | U_labels = torch.zeros(episode_dict["unlabeled"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses 17 | A_labels = torch.cat([S_labels, Q_labels, U_labels], 0) 18 | 19 | SQU = torch.cat([S, Q, U], dim=0) # Information gain is measured in the whole system 20 | 21 | # train label_prop 22 | lp = LabelPropagation(balanced=True) 23 | logits = lp(SQU, A_labels, nclasses) 24 | 25 | U_logits = logits[-U.shape[0]:] 26 | # modify the labels of the unlabeled 27 | episode_dict["unlabeled"]["labels"] = U_logits.argmax(dim=1).cpu().numpy() 28 | 29 | # unlabeled_scores = U_logits.max(dim=1).cpu().numpy() 30 | 31 | if support_size_max is None: 32 | # choose all the unlabeled examples 33 | return np.arange(U.shape[0]) 34 | else: 35 | # score each 36 | score_list = U_logits.max(dim=1)[0].cpu().numpy() 37 | return score_list.argsort()[-support_size_max:] 38 | 39 | def predict(S, S_labels, UQ, U_shape): 40 | label_prop = lp.Labelprop() 41 | label_prop.fit(support_set=S, unlabeled_set=UQ) 42 | 43 | logits = label_prop.predict(support_labels=S_labels, 44 | unlabeled_pseudolabels=None, 45 | balanced_flag=True) 46 | 47 | U_logits = logits[S.shape[0]:S.shape[0]+U_shape[0]] 48 | return U_logits.argmax(dim=1).cpu().numpy() -------------------------------------------------------------------------------- /src/models/base_ssl/utils.py: -------------------------------------------------------------------------------- 1 | 2 | # %% Import libraries & loading data & Helper Sampling Class & Pytorch helpers and imports 3 | import torch 4 | import sys 5 | import copy 6 | from scipy.stats import pearsonr 7 | import h5py 8 | import os 9 | import numpy as np 10 | import pylab 11 | from sklearn.linear_model import LogisticRegression 12 | from sklearn.cluster import KMeans 13 | # from ipywidgets import interact, interactive, fixed, interact_manual 14 | # import ipywidgets as widgets 15 | import torch 16 | import torch.nn.functional as F 17 | from functools import partial 18 | import json 19 | from skimage.io import imsave 20 | import tqdm 21 | import pprint 22 | import torch 23 | import sys 24 | from .distances import prototype_distance # here we import the labelpropagation algorithm inside the "Distance" class 25 | import pandas 26 | import json 27 | 28 | # loading data 29 | # from trainers.few_shot_parallel_alpha_scale import inner_loop_lbfgs2, inner_loop_lbfgs_bootstrap 30 | from torch.utils.data import DataLoader 31 | 32 | 33 | 34 | def get_unlabeled_set(method, unlabeled_set,unlabel_labels, unlabeled_size, 35 | support_set_reshaped): 36 | support_labels = get_support_labels(support_set_reshaped).ravel().astype(int) 37 | support_size = support_set_reshaped.shape[1] 38 | n_classes = support_set_reshaped.shape[0] 39 | if method == "cheating": 40 | unlabeled_set_new = unlabeled_set 41 | unlabel_labels_new = unlabel_labels 42 | unlabeled_size_new = unlabeled_size 43 | 44 | 45 | elif method == 
"prototypical": 46 | support_labels_reshaped = support_labels.reshape((n_classes, support_size)) 47 | unlabeled_set_torch = torch.FloatTensor(unlabeled_set).cuda() 48 | unlabeled_set_view = unlabeled_set_torch.view(1, n_classes, unlabeled_size, a512) 49 | 50 | if True: 51 | unlabeled_set_view, tmp_unlabel_labels = predict_sort(support_set_reshaped, 52 | unlabeled_set_view, 53 | n_classes=n_classes) 54 | else: 55 | tmp_unlabel_labels = predict(support_set_reshaped, 56 | unlabeled_set_view, 57 | n_classes=n_classes) 58 | 59 | unlabeled_set_list = [] 60 | # label 61 | unlabeled_size_new = np.inf 62 | for c in range(n_classes): 63 | set = unlabeled_set_torch[tmp_unlabel_labels == c] 64 | unlabeled_set_list += [set] 65 | unlabeled_size_new = min(len(set), unlabeled_size_new) 66 | 67 | # cut to shortest 68 | unlabel_labels_new = [] 69 | for c in range(n_classes): 70 | unlabeled_set_list[c] = unlabeled_set_list[c][:unlabeled_size_new] 71 | unlabel_labels_new += [np.ones(unlabeled_size_new) * c] 72 | 73 | unlabeled_set_new = torch.cat(unlabeled_set_list).cpu().numpy().astype(unlabeled_set.dtype) 74 | unlabel_labels_new = np.vstack(unlabel_labels_new).ravel().astype("int64") 75 | 76 | return unlabeled_set_new, unlabel_labels_new, unlabeled_size_new 77 | 78 | 79 | 80 | 81 | def xlogy(x, y=None): 82 | z = torch.zeros(()) 83 | if y is None: 84 | y = x 85 | assert y.min() >= 0 86 | return x * torch.where(x == 0., z, torch.log(y)) 87 | 88 | 89 | def get_support_labels(support_set_features): 90 | support_labels=[] 91 | for c in range(5): 92 | support_labels += [np.ones(support_set_features.shape[1])*c] 93 | 94 | support_labels = np.vstack(support_labels) 95 | return support_labels 96 | 97 | 98 | def get_entropy_support_set(monitor, support_size): 99 | support_set_ind = [[],[],[],[],[]] 100 | for s in range(1, support_size+1): 101 | # Get best next support 102 | entropy_best = 0 103 | for i in range(monitor.unlabeled_size): 104 | support_set_tmp = copy.deepcopy(support_set_ind) 105 | if i in support_set_tmp[0]: 106 | continue 107 | for c in range(monitor.n_classes): 108 | s_ind = monitor.unlabeled_size*c 109 | ind = s_ind + i 110 | 111 | support_set_tmp[c] += [ind] 112 | 113 | entropy_tmp = monitor.compute_entropy(support_set_tmp) 114 | if entropy_tmp >= entropy_best: 115 | support_set_best = support_set_tmp 116 | entropy_best = entropy_tmp 117 | 118 | support_set_ind = support_set_best 119 | 120 | for c in range(monitor.n_classes): 121 | ind_c = np.arange(monitor.unlabeled_size*c, monitor.unlabeled_size*(c+1)) 122 | # 1. Within class ind sanity check 123 | assert False not in [sc in ind_c for sc in support_set_ind[c]] 124 | # 2. 
Uniqueness sanity check 125 | assert np.unique(support_set_ind[c]).size == np.array(support_set_ind[c]).size 126 | 127 | support_set_list = [monitor.unlabeled_set[i_list] for i_list in support_set_ind] 128 | support_set = np.vstack(support_set_list) 129 | return support_set 130 | 131 | def get_kmeans_support_set(monitor, support_size): 132 | # unlabeled_set = unlabeled_set 133 | # unlabel_labels = unlabel_labels 134 | 135 | # greedy 136 | support_set_list = [[],[],[],[],[]] 137 | for c in range(monitor.n_classes): 138 | s_ind = monitor.unlabeled_size*c 139 | 140 | X = monitor.unlabeled_set[s_ind:s_ind+monitor.unlabeled_size] 141 | k_means = KMeans(n_clusters=support_size, random_state=0).fit(X) 142 | support_set_list[c] = k_means.cluster_centers_ 143 | support_set = np.vstack(support_set_list) 144 | return support_set 145 | 146 | 147 | 148 | def get_greedy_support_set(monitor, support_size): 149 | # unlabeled_set = unlabeled_set 150 | # unlabel_labels = unlabel_labels 151 | 152 | # greedy 153 | support_set_ind = [[],[],[],[],[]] 154 | for s in range(1, support_size+1): 155 | # Get best next support 156 | acc_best = 0. 157 | for i in range(monitor.unlabeled_size): 158 | support_set_tmp = copy.deepcopy(support_set_ind) 159 | if i in support_set_tmp[0]: 160 | continue 161 | for c in range(monitor.n_classes): 162 | s_ind = monitor.unlabeled_size*c 163 | ind = s_ind + i 164 | 165 | support_set_tmp[c] += [ind] 166 | 167 | acc_tmp = monitor.compute_acc(support_set_tmp) 168 | if acc_tmp >= acc_best: 169 | support_set_best = support_set_tmp 170 | acc_best = acc_tmp 171 | 172 | support_set_ind = support_set_best 173 | 174 | for c in range(monitor.n_classes): 175 | ind_c = np.arange(monitor.unlabeled_size*c, monitor.unlabeled_size*(c+1)) 176 | # 1. Within class ind sanity check 177 | assert False not in [sc in ind_c for sc in support_set_ind[c]] 178 | # 2. Uniqueness sanity check 179 | assert np.unique(support_set_ind[c]).size == np.array(support_set_ind[c]).size 180 | 181 | support_set_list = [monitor.unlabeled_set[i_list] for i_list in support_set_ind] 182 | support_set = np.vstack(support_set_list) 183 | return support_set 184 | 185 | 186 | 187 | 188 | def get_random_support_set(monitor, support_size): 189 | support_set_list = [] 190 | for c in range(monitor.n_classes): 191 | ind = np.arange(monitor.unlabeled_size*c, monitor.unlabeled_size*(c+1)) 192 | ind_c = np.random.choice(ind, support_size, replace=False) 193 | support_set_list += [monitor.unlabeled_set[ind_c]] 194 | 195 | support_set = np.vstack(support_set_list) 196 | return support_set 197 | 198 | @torch.no_grad() 199 | def calc_accuracy(split, iters, distance_fn, support_size, 200 | query_size, unlabeled_size, method=None, model=None, n_classes=5): 201 | sampler = Sampler(split) 202 | accuracies = [] 203 | 204 | if method == "pairs": 205 | check_pairs(split, iters, distance_fn, support_size, 206 | query_size, unlabeled_size, n_classes=5) 207 | 208 | 209 | 210 | 211 | # Helper Sampling Class 212 | def make_labels(size, classes): 213 | """ 214 | Helper function. 
Generates the labels of a set: e.g 0000 1111 2222 for size=4, classes=3 215 | """ 216 | return (np.arange(classes).reshape((classes, 1)) + np.zeros((classes, size), dtype=np.int32)).flatten() 217 | 218 | # Pytorch helpers and imports 219 | def to_pytorch(datum, n, k, c): 220 | if k > 0: 221 | return torch.from_numpy(datum).view(-1, n, k, c) 222 | else: 223 | return None 224 | -------------------------------------------------------------------------------- /src/models/base_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | class BaseWrapper(torch.nn.Module): 5 | def get_state_dict(self): 6 | raise NotImplementedError 7 | def load_state_dict(self, state_dict): 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /src/models/finetuning.py: -------------------------------------------------------------------------------- 1 | """ 2 | Few-Shot Parallel: trains a model as a series of tasks computed in parallel on multiple GPUs 3 | 4 | """ 5 | import numpy as np 6 | import os 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from src.tools.meters import BasicMeter 11 | from src.modules.distances import prototype_distance 12 | from embedding_propagation import EmbeddingPropagation, LabelPropagation 13 | from .base_wrapper import BaseWrapper 14 | from haven import haven_utils as haven 15 | from scipy.stats import sem, t 16 | import shutil as sh 17 | import sys 18 | 19 | 20 | class FinetuneWrapper(BaseWrapper): 21 | """Finetunes a model using an episodic scheme on multiple GPUs""" 22 | 23 | def __init__(self, model, nclasses, exp_dict): 24 | """ Constructor 25 | Args: 26 | model: architecture to train 27 | nclasses: number of output classes 28 | exp_dict: reference to dictionary with the hyperparameters 29 | """ 30 | super().__init__() 31 | self.model = model 32 | self.exp_dict = exp_dict 33 | self.ngpu = self.exp_dict["ngpu"] 34 | 35 | self.embedding_propagation = EmbeddingPropagation() 36 | self.label_propagation = LabelPropagation() 37 | self.model.add_classifier(nclasses, modalities=0) 38 | self.nclasses = nclasses 39 | 40 | if self.exp_dict["rotation_weight"] > 0: 41 | self.model.add_classifier(4, "classifier_rot") 42 | 43 | best_accuracy = -1 44 | if self.exp_dict["pretrained_weights_root"] is not None: 45 | for exp_hash in os.listdir(self.exp_dict['pretrained_weights_root']): 46 | base_path = os.path.join(self.exp_dict['pretrained_weights_root'], exp_hash) 47 | exp_dict_path = os.path.join(base_path, 'exp_dict.json') 48 | if not os.path.exists(exp_dict_path): 49 | continue 50 | loaded_exp_dict = haven.load_json(exp_dict_path) 51 | pkl_path = os.path.join(base_path, 'score_list_best.pkl') 52 | if (loaded_exp_dict["model"]["name"] == 'pretraining' and 53 | loaded_exp_dict["dataset_train"].split('_')[-1] == exp_dict["dataset_train"].split('_')[-1] and 54 | loaded_exp_dict["model"]["backbone"] == exp_dict['model']["backbone"] and 55 | # loaded_exp_dict["labelprop_alpha"] == exp_dict["labelprop_alpha"] and 56 | # loaded_exp_dict["labelprop_scale"] == exp_dict["labelprop_scale"] and 57 | os.path.exists(pkl_path)): 58 | accuracy = haven.load_pkl(pkl_path)[-1]["val_accuracy"] 59 | try: 60 | self.model.load_state_dict(torch.load(os.path.join(base_path, 'checkpoint_best.pth'))['model'], strict=False) 61 | if accuracy > best_accuracy: 62 | best_path = os.path.join(base_path, 'checkpoint_best.pth') 63 | best_accuracy = accuracy 64 | except: 65 | continue 66 | 
assert(best_accuracy > 0.1) 67 | print("Finetuning %s with original accuracy : %f" %(best_path, best_accuracy)) 68 | self.model.load_state_dict(torch.load(best_path)['model'], strict=False) 69 | 70 | # Add optimizers here 71 | self.optimizer = torch.optim.SGD(self.model.parameters(), 72 | lr=self.exp_dict["lr"], 73 | momentum=0.9, 74 | weight_decay=self.exp_dict["weight_decay"], 75 | nesterov=True) 76 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 77 | mode="min" if "loss" in self.exp_dict["target_loss"] else "max", 78 | patience=self.exp_dict["patience"]) 79 | self.model.cuda() 80 | if self.ngpu > 1: 81 | self.parallel_model = torch.nn.DataParallel(self.model, device_ids=list(range(self.ngpu))) 82 | 83 | def get_logits(self, embeddings, support_size, query_size, nclasses): 84 | """Computes the logits from the queries of an episode 85 | 86 | Args: 87 | embeddings (torch.Tensor): episode embeddings 88 | support_size (int): size of the support set 89 | query_size (int): size of the query set 90 | nclasses (int): number of classes 91 | 92 | Returns: 93 | torch.Tensor: logits 94 | """ 95 | b, c = embeddings.size() 96 | 97 | propagator = None 98 | if self.exp_dict["embedding_prop"] == True: 99 | embeddings = self.embedding_propagation(embeddings) 100 | 101 | if self.exp_dict["distance_type"] == "labelprop": 102 | support_labels = torch.arange(nclasses, device=embeddings.device).view(1, nclasses).repeat(support_size, 1).view(support_size, nclasses) 103 | unlabeled_labels = nclasses * torch.ones(query_size * nclasses, dtype=support_labels.dtype, device=support_labels.device).view(query_size, nclasses) 104 | labels = torch.cat([support_labels, unlabeled_labels], 0).view(-1) 105 | logits = self.label_propagation(embeddings, labels, nclasses) 106 | logits = logits.view(-1, nclasses, nclasses)[support_size:(support_size + query_size), ...].view(-1, nclasses) 107 | 108 | elif self.exp_dict["distance_type"] == "prototypical": 109 | embeddings = embeddings.view(-1, nclasses, c) 110 | support_embeddings = embeddings[:support_size] 111 | query_embeddings = embeddings[support_size:] 112 | logits = prototype_distance((support_embeddings.view(1, support_size, nclasses, c), False), 113 | (query_embeddings.view(1, query_size, nclasses, c), False)).view(query_size * nclasses, nclasses) 114 | return logits 115 | 116 | def train_on_batch(self, batch): 117 | """Computes the loss on an episode 118 | 119 | Args: 120 | batch (dict): Episode dict 121 | 122 | Returns: 123 | torch.Tensor: loss on the episode 124 | """ 125 | episode = batch[0] 126 | nclasses = episode["nclasses"] 127 | support_size = episode["support_size"] 128 | query_size = episode["query_size"] 129 | labels = episode["targets"].view(support_size + query_size, nclasses, -1).cuda(non_blocking=True).long() 130 | k = (support_size + query_size) 131 | c = episode["channels"] 132 | h = episode["height"] 133 | w = episode["width"] 134 | 135 | tx = episode["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True) 136 | vx = episode["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True) 137 | x = torch.cat([tx, vx], 0) 138 | x = x.view(-1, c, h, w).cuda(non_blocking=True) 139 | if self.ngpu > 1: 140 | embeddings = self.parallel_model(x, is_support=True) 141 | else: 142 | embeddings = self.model(x, is_support=True) 143 | b, c = embeddings.size() 144 | 145 | logits = self.get_logits(embeddings, support_size, query_size, nclasses) 146 | 147 | loss = 0 148 | if
self.exp_dict["classification_weight"] > 0: 149 | loss += F.cross_entropy(self.model.classifier(embeddings.view(b, c)), labels.view(-1)) * self.exp_dict["classification_weight"] 150 | 151 | query_labels = torch.arange(nclasses, device=logits.device).view(1, nclasses).repeat(query_size, 1).view(-1) 152 | loss += F.cross_entropy(logits, query_labels) * self.exp_dict["few_shot_weight"] 153 | return loss 154 | 155 | def predict_on_batch(self, batch): 156 | """Computes the logits of an episode 157 | 158 | Args: 159 | batch (dict): episode dict 160 | 161 | Returns: 162 | tensor: logits for the queries of the current episode 163 | """ 164 | nclasses = batch["nclasses"] 165 | support_size = batch["support_size"] 166 | query_size = batch["query_size"] 167 | k = (support_size + query_size) 168 | c = batch["channels"] 169 | h = batch["height"] 170 | w = batch["width"] 171 | 172 | tx = batch["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True) 173 | vx = batch["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True) 174 | x = torch.cat([tx, vx], 0) 175 | x = x.view(-1, c, h, w).cuda(non_blocking=True) 176 | 177 | if self.ngpu > 1: 178 | embeddings = self.parallel_model(x, is_support=True) 179 | else: 180 | embeddings = self.model(x, is_support=True) 181 | 182 | return self.get_logits(embeddings, support_size, query_size, nclasses) 183 | 184 | def val_on_batch(self, batch): 185 | """Computes the loss and accuracy on a validation batch 186 | 187 | Args: 188 | batch (dict): Episode dict 189 | 190 | Returns: 191 | tuple: loss and accuracy of the episode 192 | """ 193 | nclasses = batch["nclasses"] 194 | query_size = batch["query_size"] 195 | 196 | logits = self.predict_on_batch(batch) 197 | 198 | query_labels = torch.arange(nclasses, device=logits.device).view(1, nclasses).repeat(query_size, 1).view(-1) 199 | loss = F.cross_entropy(logits, query_labels) 200 | accuracy = float(logits.max(-1)[1].eq(query_labels).float().mean()) 201 | 202 | return loss, accuracy 203 | 204 | def train_on_loader(self, data_loader, max_iter=None, debug_plot_path=None): 205 | """Iterate over the training set 206 | 207 | Args: 208 | data_loader: iterable training data loader 209 | max_iter: max number of iterations to perform if the end of the dataset is not reached 210 | """ 211 | self.model.train() 212 | train_loss_meter = BasicMeter.get("train_loss").reset() 213 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU 214 | self.optimizer.zero_grad() 215 | for batch_idx, batch in enumerate(data_loader): 216 | loss = self.train_on_batch(batch) / self.exp_dict["tasks_per_batch"] 217 | train_loss_meter.update(float(loss), 1) 218 | loss.backward() 219 | if ((batch_idx + 1) % self.exp_dict["tasks_per_batch"]) == 0: 220 | self.optimizer.step() 221 | self.optimizer.zero_grad() 222 | if batch_idx + 1 == max_iter: 223 | break 224 | return {"train_loss": train_loss_meter.mean()} 225 | 226 | 227 | @torch.no_grad() 228 | def val_on_loader(self, data_loader, max_iter=None): 229 | """Iterate over the validation set 230 | 231 | Args: 232 | data_loader: iterable validation data loader 233 | max_iter: max number of iterations to perform if the end of the dataset is not reached 234 | """ 235 | self.model.eval() 236 | val_loss_meter = BasicMeter.get("val_loss").reset() 237 | val_accuracy_meter = BasicMeter.get("val_accuracy").reset() 238 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU 239 | for batch_idx, _data in enumerate(data_loader): 240 | batch = 
_data[0] 241 | loss, accuracy = self.val_on_batch(batch) 242 | val_loss_meter.update(float(loss), 1) 243 | val_accuracy_meter.update(float(accuracy), 1) 244 | loss = BasicMeter.get(self.exp_dict["target_loss"], recursive=True, force=False).mean() 245 | self.scheduler.step(loss) # update the learning rate monitor 246 | return {"val_loss": val_loss_meter.mean(), "val_accuracy": val_accuracy_meter.mean()} 247 | 248 | @torch.no_grad() 249 | def test_on_loader(self, data_loader, max_iter=None): 250 | """Iterate over the validation set 251 | 252 | Args: 253 | data_loader: iterable validation data loader 254 | max_iter: max number of iterations to perform if the end of the dataset is not reached 255 | """ 256 | self.model.eval() 257 | test_loss_meter = BasicMeter.get("test_loss").reset() 258 | test_accuracy_meter = BasicMeter.get("test_accuracy").reset() 259 | test_accuracy = [] 260 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU 261 | for batch_idx, _data in enumerate(data_loader): 262 | batch = _data[0] 263 | loss, accuracy = self.val_on_batch(batch) 264 | test_loss_meter.update(float(loss), 1) 265 | test_accuracy_meter.update(float(accuracy), 1) 266 | test_accuracy.append(float(accuracy)) 267 | from scipy.stats import sem, t 268 | confidence = 0.95 269 | n = len(test_accuracy) 270 | std_err = sem(np.array(test_accuracy)) 271 | h = std_err * t.ppf((1 + confidence) / 2, n - 1) 272 | return {"test_loss": test_loss_meter.mean(), "test_accuracy": test_accuracy_meter.mean(), "test_confidence": h} 273 | 274 | def get_state_dict(self): 275 | """Obtains the state dict of this model including optimizer, scheduler, etc 276 | 277 | Returns: 278 | dict: state dict 279 | """ 280 | ret = {} 281 | ret["optimizer"] = self.optimizer.state_dict() 282 | ret["model"] = self.model.state_dict() 283 | ret["scheduler"] = self.scheduler.state_dict() 284 | return ret 285 | 286 | def load_state_dict(self, state_dict): 287 | """Loads the state of the model 288 | 289 | Args: 290 | state_dict (dict): The state to load 291 | """ 292 | self.optimizer.load_state_dict(state_dict["optimizer"]) 293 | self.model.load_state_dict(state_dict["model"]) 294 | self.scheduler.load_state_dict(state_dict["scheduler"]) 295 | 296 | def get_lr(self): 297 | ret = {} 298 | for i, param_group in enumerate(self.optimizer.param_groups): 299 | ret["current_lr_%d" % i] = float(param_group["lr"]) 300 | return ret 301 | 302 | def is_end_of_training(self): 303 | lr = self.get_lr()["current_lr_0"] 304 | return lr <= (self.exp_dict["lr"] * self.exp_dict["min_lr_decay"]) -------------------------------------------------------------------------------- /src/models/pretraining.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn.functional as F 5 | from src.tools.meters import BasicMeter 6 | from embedding_propagation import EmbeddingPropagation, LabelPropagation 7 | from .base_wrapper import BaseWrapper 8 | from src.modules.distances import prototype_distance 9 | 10 | class PretrainWrapper(BaseWrapper): 11 | """Trains a model using an episodic scheme on multiple GPUs""" 12 | 13 | def __init__(self, model, nclasses, exp_dict): 14 | """ Constructor 15 | Args: 16 | model: architecture to train 17 | nclasses: number of output classes 18 | exp_dict: reference to dictionary with the hyperparameters 19 | """ 20 | super().__init__() 21 | self.model = model 22 | self.exp_dict = exp_dict 23 | self.ngpu = self.exp_dict["ngpu"] 24 | 
self.embedding_propagation = EmbeddingPropagation() 25 | self.label_propagation = LabelPropagation() 26 | self.model.add_classifier(nclasses, modalities=0) 27 | self.nclasses = nclasses 28 | 29 | 30 | if self.exp_dict["rotation_weight"] > 0: 31 | self.model.add_classifier(4, "classifier_rot") 32 | 33 | # Add optimizers here 34 | self.optimizer = torch.optim.SGD(self.model.parameters(), 35 | lr=self.exp_dict["lr"], 36 | momentum=0.9, 37 | weight_decay=self.exp_dict["weight_decay"], 38 | nesterov=True) 39 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 40 | mode="min" if "loss" in self.exp_dict["target_loss"] else "max", 41 | patience=self.exp_dict["patience"]) 42 | self.model.cuda() 43 | if self.ngpu > 1: 44 | self.parallel_model = torch.nn.DataParallel(self.model, device_ids=list(range(self.ngpu))) 45 | 46 | def get_logits(self, embeddings, support_size, query_size, nclasses): 47 | """Computes the logits from the queries of an episode 48 | 49 | Args: 50 | embeddings (torch.Tensor): episode embeddings 51 | support_size (int): size of the support set 52 | query_size (int): size of the query set 53 | nclasses (int): number of classes 54 | 55 | Returns: 56 | torch.Tensor: logits 57 | """ 58 | b, c = embeddings.size() 59 | 60 | if self.exp_dict["embedding_prop"] == True: 61 | embeddings = self.embedding_propagation(embeddings) 62 | if self.exp_dict["distance_type"] == "labelprop": 63 | support_labels = torch.arange(nclasses, device=embeddings.device).view(1, nclasses).repeat(support_size, 1).view(support_size, nclasses) 64 | unlabeled_labels = nclasses * torch.ones(query_size * nclasses, dtype=support_labels.dtype, device=support_labels.device).view(query_size, nclasses) 65 | labels = torch.cat([support_labels, unlabeled_labels], 0).view(-1) 66 | logits = self.label_propagation(embeddings, labels, nclasses) 67 | logits = logits.view(-1, nclasses, nclasses)[support_size:(support_size + query_size), ...].view(-1, nclasses) 68 | 69 | elif self.exp_dict["distance_type"] == "prototypical": 70 | embeddings = embeddings.view(-1, nclasses, c) 71 | support_embeddings = embeddings[:support_size] 72 | query_embeddings = embeddings[support_size:] 73 | support_labels = torch.arange(nclasses, device=embeddings.device).view(1, nclasses).repeat(support_size, 1) # class of each support row (class-major order) 74 | logits = prototype_distance(support_embeddings.view(-1, c), query_embeddings.view(-1, c), 75 | support_labels.view(-1)) 76 | return logits 77 | 78 | def train_on_batch(self, batch): 79 | """Computes the loss of a batch 80 | 81 | Args: 82 | batch (tuple): Inputs and labels 83 | 84 | Returns: 85 | loss: Loss on the batch 86 | """ 87 | x, y, r = batch 88 | y = y.cuda(non_blocking=True).view(-1) 89 | r = r.cuda(non_blocking=True).view(-1) 90 | k, n, c, h, w = x.size() 91 | x = x.view(n*k, c, h, w).cuda(non_blocking=True) 92 | if self.ngpu > 1: 93 | embeddings = self.parallel_model(x, is_support=True) 94 | else: 95 | embeddings = self.model(x, is_support=True) 96 | b, c = embeddings.size() 97 | 98 | loss = 0 99 | if self.exp_dict["rotation_weight"] > 0: 100 | rot = self.model.classifier_rot(embeddings) 101 | loss += F.cross_entropy(rot, r) * self.exp_dict["rotation_weight"] 102 | 103 | if self.exp_dict["embedding_prop"] == True: 104 | embeddings = self.embedding_propagation(embeddings) 105 | logits = self.model.classifier(embeddings) 106 | loss += F.cross_entropy(logits, y) * self.exp_dict["cross_entropy_weight"] 107 | return loss 108 | 109 | def val_on_batch(self, batch): 110 | """Computes the loss and accuracy on a validation batch 111 | 112 | Args: 113 | batch (dict): Episode dict 114 | 115 | Returns: 116 | tuple: loss and
accuracy of the episode 117 | """ 118 | nclasses = batch["nclasses"] 119 | query_size = batch["query_size"] 120 | 121 | logits = self.predict_on_batch(batch) 122 | 123 | query_labels = torch.arange(nclasses, device=logits.device).view(1, nclasses).repeat(query_size, 1).view(-1) 124 | loss = F.cross_entropy(logits, query_labels) 125 | accuracy = float(logits.max(-1)[1].eq(query_labels).float().mean()) 126 | 127 | return loss, accuracy 128 | 129 | def predict_on_batch(self, batch): 130 | """Computes the logits of an episode 131 | 132 | Args: 133 | batch (dict): episode dict 134 | 135 | Returns: 136 | tensor: logits for the queries of the current episode 137 | """ 138 | nclasses = batch["nclasses"] 139 | support_size = batch["support_size"] 140 | query_size = batch["query_size"] 141 | k = (support_size + query_size) 142 | c = batch["channels"] 143 | h = batch["height"] 144 | w = batch["width"] 145 | 146 | tx = batch["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True) 147 | vx = batch["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True) 148 | x = torch.cat([tx, vx], 0) 149 | x = x.view(-1, c, h, w).cuda(non_blocking=True) 150 | 151 | if self.ngpu > 1: 152 | embeddings = self.parallel_model(x, is_support=True) 153 | else: 154 | embeddings = self.model(x, is_support=True) 155 | b, c = embeddings.size() 156 | 157 | return self.get_logits(embeddings, support_size, query_size, nclasses) 158 | 159 | def train_on_loader(self, data_loader, max_iter=None, debug_plot_path=None): 160 | """Iterate over the training set 161 | 162 | Args: 163 | data_loader (torch.utils.data.DataLoader): a pytorch dataloader 164 | max_iter (int, optional): Max number of iterations if the end of the dataset is not reached. Defaults to None. 
165 | 166 | Returns: 167 | metrics: dictionary with metrics of the training set 168 | """ 169 | self.model.train() 170 | train_loss_meter = BasicMeter.get("train_loss").reset() 171 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU 172 | for batch_idx, batch in enumerate(data_loader): 173 | self.optimizer.zero_grad() 174 | loss = self.train_on_batch(batch) 175 | train_loss_meter.update(float(loss), 1) 176 | loss.backward() 177 | self.optimizer.step() 178 | if batch_idx + 1 == max_iter: 179 | break 180 | return {"train_loss": train_loss_meter.mean()} 181 | 182 | 183 | @torch.no_grad() 184 | def val_on_loader(self, data_loader, max_iter=None): 185 | """Iterate over the validation set 186 | 187 | Args: 188 | data_loader: iterable validation data loader 189 | max_iter: max number of iterations to perform if the end of the dataset is not reached 190 | """ 191 | self.model.eval() 192 | val_loss_meter = BasicMeter.get("val_loss").reset() 193 | val_accuracy_meter = BasicMeter.get("val_accuracy").reset() 194 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU 195 | for batch_idx, _data in enumerate(data_loader): 196 | batch = _data[0] 197 | loss, accuracy = self.val_on_batch(batch) 198 | val_loss_meter.update(float(loss), 1) 199 | val_accuracy_meter.update(float(accuracy), 1) 200 | loss = BasicMeter.get(self.exp_dict["target_loss"], recursive=True, force=False).mean() 201 | self.scheduler.step(loss) # update the learning rate monitor 202 | return {"val_loss": val_loss_meter.mean(), "val_accuracy": val_accuracy_meter.mean()} 203 | 204 | @torch.no_grad() 205 | def test_on_loader(self, data_loader, max_iter=None): 206 | """Iterate over the validation set 207 | 208 | Args: 209 | data_loader: iterable validation data loader 210 | max_iter: max number of iterations to perform if the end of the dataset is not reached 211 | """ 212 | self.model.eval() 213 | test_loss_meter = BasicMeter.get("test_loss").reset() 214 | test_accuracy_meter = BasicMeter.get("test_accuracy").reset() 215 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU 216 | for batch_idx, _data in enumerate(data_loader): 217 | batch = _data[0] 218 | loss, accuracy = self.val_on_batch(batch) 219 | test_loss_meter.update(float(loss), 1) 220 | test_accuracy_meter.update(float(accuracy), 1) 221 | return {"test_loss": test_loss_meter.mean(), "test_accuracy": test_accuracy_meter.mean()} 222 | 223 | def get_state_dict(self): 224 | """Obtains the state dict of this model including optimizer, scheduler, etc 225 | 226 | Returns: 227 | dict: state dict 228 | """ 229 | ret = {} 230 | ret["optimizer"] = self.optimizer.state_dict() 231 | ret["model"] = self.model.state_dict() 232 | ret["scheduler"] = self.scheduler.state_dict() 233 | return ret 234 | 235 | def load_state_dict(self, state_dict): 236 | """Loads the state of the model 237 | 238 | Args: 239 | state_dict (dict): The state to load 240 | """ 241 | self.optimizer.load_state_dict(state_dict["optimizer"]) 242 | self.model.load_state_dict(state_dict["model"]) 243 | self.scheduler.load_state_dict(state_dict["scheduler"]) 244 | 245 | def get_lr(self): 246 | ret = {} 247 | for i, param_group in enumerate(self.optimizer.param_groups): 248 | ret["current_lr_%d" % i] = float(param_group["lr"]) 249 | return ret 250 | 251 | def is_end_of_training(self): 252 | lr = self.get_lr()["current_lr_0"] 253 | return lr <= (self.exp_dict["lr"] * self.exp_dict["min_lr_decay"]) 
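A minimal usage sketch (not part of the repository) of how the wrappers above chain EmbeddingPropagation and LabelPropagation on a single episode. The episode sizes, the 512-dimensional random features, and CPU execution are placeholder assumptions; only the two module calls mirror the usage seen in pretraining.py and finetuning.py:

import torch
from embedding_propagation import EmbeddingPropagation, LabelPropagation

# Assumed episode shape: rows are class-major, i.e. each block of `nclasses`
# consecutive rows holds one example per class, support rows first, then queries.
nclasses, support_size, query_size, channels = 5, 1, 15, 512
embeddings = torch.randn((support_size + query_size) * nclasses, channels)

# Smooth the features over the episode similarity graph.
embeddings = EmbeddingPropagation()(embeddings)

# Support rows carry their class ids; query rows get the out-of-range label
# `nclasses`, which marks them as unlabeled for label propagation, exactly as
# get_logits builds them above.
support_labels = torch.arange(nclasses).repeat(support_size)
query_labels = torch.full((query_size * nclasses,), nclasses, dtype=torch.long)
labels = torch.cat([support_labels, query_labels])

logits = LabelPropagation()(embeddings, labels, nclasses)
query_logits = logits[support_size * nclasses:]  # (query_size * nclasses, nclasses)
predictions = query_logits.argmax(dim=1)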
-------------------------------------------------------------------------------- /src/models/ssl_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Few-Shot Parallel: trains a model as a series of tasks computed in parallel on multiple GPUs 3 | 4 | """ 5 | import copy 6 | import numpy as np 7 | import os 8 | from .base_ssl import oracle 9 | from scipy.stats import sem, t 10 | import torch 11 | import pandas as pd 12 | import torch.nn.functional as F 13 | import tqdm 14 | from src.tools.meters import BasicMeter 15 | from src.modules.distances import standarized_label_prop, _propagate, prototype_distance 16 | from .base_wrapper import BaseWrapper 17 | from haven import haven_utils as hu 18 | import glob 19 | from scipy.stats import sem, t 20 | import shutil as sh 21 | from .base_ssl import selection_methods as sm 22 | from .base_ssl import predict_methods as pm 23 | from embedding_propagation import EmbeddingPropagation 24 | 25 | class SSLWrapper(BaseWrapper): 26 | """Trains a model using an episodic scheme on multiple GPUs""" 27 | 28 | def __init__(self, model, n_classes, exp_dict, pretrained_savedir=None, savedir_base=None): 29 | """ Constructor 30 | Args: 31 | model: architecture to train 32 | exp_dict: reference to dictionary with the global state of the application 33 | """ 34 | super().__init__() 35 | self.model = model 36 | self.exp_dict = exp_dict 37 | self.ngpu = self.exp_dict["ngpu"] 38 | self.predict_method = exp_dict['predict_method'] 39 | 40 | self.model.add_classifier(n_classes, modalities=0) 41 | self.nclasses = n_classes 42 | 43 | best_accuracy = -1 44 | self.label = exp_dict['model']['backbone'] + "_" + exp_dict['dataset_test'].split('_')[1].replace('-imagenet','') 45 | print('=============') 46 | print('dataset:', exp_dict["dataset_train"].split('_')[-1]) 47 | print('backbone:', exp_dict['model']["backbone"]) 48 | print('n_classes:', exp_dict['n_classes']) 49 | print('support_size_train:', exp_dict['support_size_train']) 50 | 51 | if pretrained_savedir is None: 52 | # find the best checkpoint 53 | savedir_base = exp_dict["finetuned_weights_root"] 54 | if not os.path.exists(savedir_base): 55 | raise ValueError("Please set the variable named \ 56 | 'finetuned_weights_root' with the path of the folder \ 57 | with the episodic finetuning experiments") 58 | for exp_hash in os.listdir(savedir_base): 59 | base_path = os.path.join(savedir_base, exp_hash) 60 | exp_dict_path = os.path.join(base_path, 'exp_dict.json') 61 | if not os.path.exists(exp_dict_path): 62 | continue 63 | loaded_exp_dict = hu.load_json(exp_dict_path) 64 | pkl_path = os.path.join(base_path, 'score_list_best.pkl') 65 | 66 | if exp_dict['support_size_train'] in [2,3,4]: 67 | support_size_needed = 1 68 | else: 69 | support_size_needed = exp_dict['support_size_train'] 70 | 71 | if (loaded_exp_dict["model"]["name"] == 'finetuning' and 72 | loaded_exp_dict["dataset_train"].split('_')[-1] == exp_dict["dataset_train"].split('_')[-1] and 73 | loaded_exp_dict["model"]["backbone"] == exp_dict['model']["backbone"] and 74 | loaded_exp_dict['n_classes'] == exp_dict["n_classes"] and 75 | loaded_exp_dict['support_size_train'] == support_size_needed and 76 | loaded_exp_dict["embedding_prop"] == exp_dict["embedding_prop"]): 77 | 78 | model_path = os.path.join(base_path, 'checkpoint_best.pth') 79 | 80 | try: 81 | print("Attempting to load ", model_path) 82 | accuracy = hu.load_pkl(pkl_path)[-1]["val_accuracy"] 83 | self.model.load_state_dict(torch.load(model_path)['model'],
strict=False) 84 | if accuracy > best_accuracy: 85 | best_path = os.path.join(base_path, 'checkpoint_best.pth') 86 | best_accuracy = accuracy 87 | except Exception as e: 88 | print(e) 89 | 90 | assert(best_accuracy > 0.1) 91 | print("Finetuning %s with original accuracy : %f" %(best_path, best_accuracy)) 92 | self.model.load_state_dict(torch.load(best_path)['model'], strict=False) 93 | self.best_accuracy = best_accuracy 94 | self.acc_sum = 0.0 95 | self.n_count = 0 96 | self.model.cuda() 97 | 98 | def get_embeddings(self, embeddings, support_size, query_size, nclasses): 99 | b, c = embeddings.size() 100 | 101 | if self.exp_dict["embedding_prop"] == True: 102 | embeddings = EmbeddingPropagation()(embeddings) 103 | return embeddings.view(b, c) 104 | 105 | def get_episode_dict(self, batch): 106 | nclasses = batch["nclasses"] 107 | support_size = batch["support_size"] 108 | query_size = batch["query_size"] 109 | k = (support_size + query_size) 110 | c = batch["channels"] 111 | h = batch["height"] 112 | w = batch["width"] 113 | 114 | tx = batch["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True) 115 | vx = batch["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True) 116 | ux = batch["unlabeled_set"].view(batch["unlabeled_size"], nclasses, c, h, w).cuda(non_blocking=True) 117 | x = torch.cat([tx, vx, ux], 0) 118 | x = x.view(-1, c, h, w).cuda(non_blocking=True) 119 | 120 | if self.ngpu > 1: 121 | features = self.parallel_model(x, is_support=True) 122 | else: 123 | features = self.model(x, is_support=True) 124 | 125 | embeddings = self.get_embeddings(features, 126 | support_size, 127 | query_size+ 128 | batch['unlabeled_size'], 129 | nclasses) # (b, channels) 130 | 131 | uniques = np.unique(batch['targets']) 132 | labels = torch.zeros(batch['targets'].shape[0]) 133 | for i, u in enumerate(uniques): 134 | labels[batch['targets']==u] = i 135 | 136 | ## perform ssl 137 | # 1.
indices 138 | episode_dict = {} 139 | ns = support_size*nclasses 140 | nq = query_size*nclasses 141 | episode_dict["support"] = {'samples':embeddings[:ns], 142 | 'labels':labels[:ns]} 143 | episode_dict["query"] = {'samples':embeddings[ns:ns+nq], 144 | 'labels':labels[ns:ns+nq]} 145 | episode_dict["unlabeled"] = {'samples':embeddings[ns+nq:]} 146 | # batch["support_so_far"] = {'samples':embeddings, 147 | # 'labels':labels} 148 | 149 | 150 | for k, v in episode_dict.items(): 151 | episode_dict[k]['samples'] = episode_dict[k]['samples'].cpu().numpy() 152 | if 'labels' in episode_dict[k]: 153 | episode_dict[k]['labels'] = episode_dict[k]['labels'].cpu().numpy().astype(int) 154 | return episode_dict 155 | 156 | def predict_on_batch(self, episode_dict, support_size_max=None): 157 | ind_selected = sm.get_indices(selection_method="ssl", 158 | episode_dict=episode_dict, 159 | support_size_max=support_size_max) 160 | episode_dict = update_episode_dict(ind_selected, episode_dict) 161 | pred_labels = pm.get_predictions(predict_method=self.predict_method, 162 | episode_dict=episode_dict) 163 | 164 | return pred_labels 165 | 166 | def val_on_batch(self, batch): 167 | # if self.exp_dict['ora'] 168 | if self.exp_dict.get("pretrained_weights_root") == 'hdf5': 169 | episode_dict = self.sampler.sample_episode(int(self.exp_dict['support_size_test']), 170 | self.exp_dict['query_size_test'], 171 | self.exp_dict['unlabeled_size_test'], 172 | apply_ten_flag=self.exp_dict.get("apply_ten_flag")) 173 | else: 174 | episode_dict = self.get_episode_dict(batch) 175 | episode_dict["support_so_far"] = copy.deepcopy(episode_dict["support"]) 176 | episode_dict["n_classes"] = 5 177 | 178 | pred_labels = self.predict_on_batch(episode_dict, support_size_max=self.exp_dict['unlabeled_size_test']*self.exp_dict['classes_test']) 179 | accuracy = oracle.compute_acc(pred_labels=pred_labels, 180 | true_labels=episode_dict["query"]["labels"]) 181 | 182 | # query_labels = episode_dict["query"]["labels"] 183 | # accuracy = float((pred_labels == query_labels.cuda()).float().mean()) 184 | 185 | self.acc_sum += accuracy 186 | self.n_count += 1 187 | return -1, accuracy 188 | 189 | @torch.no_grad() 190 | def test_on_loader(self, data_loader, max_iter=None): 191 | """Iterate over the validation set 192 | 193 | Args: 194 | data_loader: iterable validation data loader 195 | max_iter: max number of iterations to perform if the end of the dataset is not reached 196 | """ 197 | self.model.eval() 198 | 199 | test_accuracy_meter = BasicMeter.get("test_accuracy").reset() 200 | test_accuracy = [] 201 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU 202 | # dirname = os.path.split(self.exp_dict["pretrained_weights_root"])[-1] 203 | with tqdm.tqdm(total=len(data_loader)) as pbar: 204 | for batch_all in data_loader: 205 | batch = batch_all[0] 206 | loss, accuracy = self.val_on_batch(batch) 207 | 208 | test_accuracy_meter.update(float(accuracy), 1) 209 | test_accuracy.append(float(accuracy)) 210 | 211 | string = ("'%s' - ssl: %.3f" % 212 | (self.label, 213 | # dirname, 214 | test_accuracy_meter.mean())) 215 | # print(string) 216 | pbar.update(1) 217 | pbar.set_description(string) 218 | 219 | confidence = 0.95 220 | n = len(test_accuracy) 221 | std_err = sem(np.array(test_accuracy)) 222 | h = std_err * t.ppf((1 + confidence) / 2, n - 1) 223 | return {"test_loss": -1, 224 | "ssl_accuracy": test_accuracy_meter.mean(), 225 | "ssl_confidence": h, 226 | 'finetuned_accuracy': self.best_accuracy} 227 | 228 | def 
def update_episode_dict(ind, episode_dict):
    # 1. update supports so far
    selected_samples = episode_dict["unlabeled"]["samples"][ind]
    selected_labels = episode_dict["unlabeled"]["labels"][ind]

    selected_support_dict = {"samples": selected_samples, "labels": selected_labels}

    for k, v in episode_dict["support_so_far"].items():
        episode_dict["support_so_far"][k] = np.concatenate([v, selected_support_dict[k]], axis=0)

    # 2. update unlabeled samples
    n_unlabeled = episode_dict["unlabeled"]["samples"].shape[0]
    ind_rest = np.setdiff1d(np.arange(n_unlabeled), ind)

    new_unlabeled_dict = {}
    for k, v in episode_dict["unlabeled"].items():
        new_unlabeled_dict[k] = v[ind_rest]

    episode_dict["unlabeled"] = new_unlabeled_dict

    return episode_dict
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/modules/__init__.py
--------------------------------------------------------------------------------
/src/modules/distances.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn.functional as F
def _make_aligned_labels(inputs):
    batch, n_sample_pc, n_classes, z_dim = inputs.shape
    identity = torch.eye(n_classes, dtype=inputs.dtype, device=inputs.device)
    return identity[None, None, :, :].expand(batch, n_sample_pc, -1, -1).contiguous()
def generalized_pw_sq_dist(data, d_type="euclidean"):
    batch, samples, z_dim = data.size()
    if d_type == "euclidean":
        return torch.sum((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim)
    elif d_type == "l1":
        return torch.mean(torch.abs(data[:, :, None, :] - data[:, None, :, :]), dim=3)
    elif d_type == "stable_euclidean":
        return torch.sqrt(1e-6 + torch.mean((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim))
    elif d_type == "cosine":
        data = F.normalize(data, dim=2)
        return torch.bmm(data, data.transpose(2, 1))
    else:
        raise ValueError("Distance type not recognized")
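For the default "euclidean" branch, the broadcasting above is equivalent to squaring torch.cdist and dividing by sqrt(z_dim). A quick sanity check (assuming generalized_pw_sq_dist is in scope):

import torch
import numpy as np

data = torch.randn(2, 10, 16)                     # (batch, samples, z_dim)
ref = torch.cdist(data, data) ** 2 / np.sqrt(16)  # same quantity via cdist
out = generalized_pw_sq_dist(data, "euclidean")
assert torch.allclose(out, ref, atol=1e-4)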
def standarized_label_prop(embeddings,
                           labels,
                           gaussian_scale=1, alpha=0.5,
                           weights=None,
                           apply_log=False,
                           scale_bound="",
                           standarize="all",
                           kernel="rbf",
                           square_root=False,
                           norm_prop=0,
                           epsilon=1e-6):
    propagator_scale = gaussian_scale
    gaussian_scale = 1
    if scale_bound == "softplus":
        gaussian_scale = 0.01 + F.softplus(gaussian_scale)
        alpha = 0.1 + F.softplus(alpha)
    elif scale_bound == "square":
        gaussian_scale = 1e-4 + gaussian_scale ** 2
        alpha = 0.1 + alpha ** 2
    elif scale_bound == "convex_relu":
        # gaussian_scale = gaussian_scale ** 2
        alpha = F.relu(alpha) + 0.1
    elif scale_bound == "convex_square":
        # gaussian_scale = gaussian_scale ** 2
        alpha = 0.1 + alpha ** 2
    elif scale_bound == "relu":
        gaussian_scale = F.relu(gaussian_scale) + 0.01
        alpha = F.relu(alpha) + 0.1
    elif scale_bound == "constant":
        gaussian_scale = 1
        alpha = 1
    elif scale_bound == "alpha_square":
        alpha = 0.1 + F.relu(alpha)
    # Compute the pairwise distance between the examples of the sample and query sets
    # XXX: labels are set to a constant for the query set
    sq_dist = generalized_pw_sq_dist(embeddings, "euclidean")
    if square_root:
        sq_dist = (sq_dist + epsilon).sqrt()
    if standarize == "all":
        mask = sq_dist != 0
        sq_dist = sq_dist / sq_dist[mask].std()
    elif standarize == "median":
        mask = sq_dist != 0
        gaussian_scale = torch.sqrt(
            0.5 * torch.median(sq_dist[mask]) / torch.log(torch.ones(1, device=sq_dist.device) + sq_dist.size(1)))
    elif standarize == "frobenius":
        mask = sq_dist != 0
        sq_dist = sq_dist / (sq_dist[mask] ** 2).sum().sqrt()
    elif standarize == "percentile":
        mask = sq_dist != 0  # exclude the zero self-distances
        sorted_dists, indices = torch.sort(sq_dist.data[mask])
        total = sorted_dists.size(0)
        gaussian_scale = sorted_dists[int(total * 0.1)].detach()
    if kernel == "rbf":
        weights = torch.exp(-sq_dist * gaussian_scale)
    elif kernel == "convex_rbf":
        scales = torch.linspace(0.1, 10, gaussian_scale.size(0), device=sq_dist.device, dtype=sq_dist.dtype)
        weights = torch.exp(-sq_dist.unsqueeze(1) * scales.view(1, -1, 1, 1))
        weights = (weights * F.softmax(gaussian_scale.view(1, -1, 1, 1), dim=1)).sum(1)
    elif kernel == "euclidean":
        # Compute similarity between the examples -- inversely proportional to distance
        weights = 1 / (gaussian_scale + sq_dist)
    elif kernel == "softmax":
        weights = F.softmax(-sq_dist / gaussian_scale, -1)
    # zero out the diagonal so that no example propagates to itself
    mask = torch.eye(weights.size(1), dtype=torch.bool, device=weights.device)[None, :, :]
    weights = weights * (~mask).float()
    logits, propagator = global_consistency(weights, labels, alpha=alpha, norm_prop=norm_prop, scale=propagator_scale)
    if apply_log:
        logits = torch.log(logits + epsilon)
    return logits, propagator
def global_consistency(weights, labels, alpha=1, norm_prop=0, scale=1):
    """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in the TPN paper, but without the bug)

    Args:
        weights: Tensor of shape (batch, n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and
            s the scale parameter.
        labels: Tensor of shape (batch, n, n_classes)
        alpha: Scalar, acts as a smoothing factor

    Returns:
        Tensor of shape (batch, n, n_classes) representing the logits of each class
    """
    n = weights.shape[1]
    identity = torch.eye(n, dtype=weights.dtype, device=weights.device)[None, :, :]
    isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=2))
    # symmetrically normalized adjacency S = D^-1/2 W D^-1/2
    S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, :, None]
    # closed-form propagator (I - alpha * S)^-1
    propagator = identity - alpha * S
    propagator = torch.inverse(propagator)
    if norm_prop > 0:
        propagator = F.normalize(propagator, p=norm_prop, dim=-1)
    elif norm_prop < 0:
        propagator = F.softmax(propagator, dim=-1)
    propagator = propagator * scale
    return _propagate(labels, propagator, scaling=1), propagator
def _propagate(labels, propagator, scaling=1.):
    return torch.matmul(propagator, labels) * scaling
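global_consistency is the closed-form label propagation of Zhou et al.: given the similarity graph W, it forms S = D^-1/2 W D^-1/2 and returns (I - alpha*S)^-1 Y, where the rows of Y are one-hot for labeled points and zero otherwise. A minimal demo with hypothetical sizes:

import torch

n, n_classes = 8, 2
z = torch.randn(1, n, 16)                     # one episode of 8 embeddings
weights = torch.exp(-torch.cdist(z, z) ** 2)  # rbf graph, shape (1, n, n)
labels = torch.zeros(1, n, n_classes)
labels[0, 0, 0] = labels[0, 1, 1] = 1         # two labeled points, the rest are zero rows
logits, propagator = global_consistency(weights, labels, alpha=0.5)
pred = logits[0].argmax(-1)                   # propagated class assignment per point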
def prototype_distance(support_set, query_set, labels, unlabeled_set=None):
    """Computes the distance from each element of the query set to the class prototypes of the support set.

    Args:
        support_set: Tensor of shape (n_support, z_dim) containing the representation z of each support image.
        query_set: Tensor of shape (n_queries, z_dim) containing the representation z of each query image.
        labels: Tensor of shape (n_support,) containing the class index of each support image.
        unlabeled_set: unused; kept for interface compatibility.

    Returns:
        Tensor of shape (n_queries, n_classes) containing the similarity (negative squared distance) between
        each query and each class prototype.
    """
    n_queries, channels = query_set.size()
    n_support, channels = support_set.size()
    support_set = support_set.view(n_support, 1, channels)
    way = int(labels.data.max()) + 1
    one_hot_labels = torch.zeros(n_support, way, 1, dtype=support_set.dtype, device=support_set.device)
    one_hot_labels.scatter_(1, labels.view(n_support, 1, 1), 1)
    # class prototypes are the mean support embedding of each class
    total_per_class = one_hot_labels.sum(0, keepdim=True)
    prototypes = (support_set * one_hot_labels).sum(0) / total_per_class
    prototypes = prototypes.view(1, way, channels)
    query_set = query_set.view(n_queries, 1, channels)
    d = query_set - prototypes
    return -torch.sum(d ** 2, 2) / np.sqrt(channels)
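prototype_distance is the standard prototypical-network head: prototypes are per-class means of the support embeddings, and the logits are negative squared distances to them. A short usage sketch with hypothetical sizes:

import torch

support = torch.randn(10, 64)              # 5-way, 2-shot support embeddings
labels = torch.arange(5).repeat(2)         # class index of each support example
query = torch.randn(15, 64)
logits = prototype_distance(support, query, labels)  # (15, 5)
pred = logits.argmax(-1)                   # nearest prototype per query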
--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/tools/__init__.py
--------------------------------------------------------------------------------
/src/tools/meters.py:
--------------------------------------------------------------------------------
class BasicMeter(object):
    """
    Basic class to monitor scores
    """
    meters = {}
    submeters = {}

    @staticmethod
    def get(name, recursive=False, tag=None, force=True):
        """ Creates a new meter or returns an already existing one with the given name.

        Args:
            name: meter name

        Returns: BasicMeter instance

        """
        if name not in BasicMeter.meters:
            if recursive:
                for supername, meter in BasicMeter.meters.items():
                    for subname in meter.submeters:
                        if "%s_%s" % (supername, subname) == name:
                            return meter.submeters[subname]
            if force:
                BasicMeter.meters[name] = BasicMeter(name)
            else:
                raise ModuleNotFoundError

        if tag is not None:
            if force:
                return BasicMeter.meters[name].get_submeter(tag)
            else:
                raise ModuleNotFoundError

        return BasicMeter.meters[name]

    @staticmethod
    def dict():
        """ Obtain meters in a dictionary

        Returns: dictionary of BasicMeter

        """
        return BasicMeter.meters

    def __init__(self, name=""):
        """
        Constructor
        """
        self.count = 0.
        self.total = 0.
        self.name = name
        self.submeters = {}

    def get_submeter(self, name):
        if name not in self.submeters:
            name_ = "%s_%s" % (self.name, name)
            self.submeters[name] = BasicMeter(name_)
        return self.submeters[name]

    def update(self, v, count, tag=None):
        """ Update meter values

        Args:
            v: current value (the sum of `count` values)
            count: number of values aggregated in v; pass 1 when v is a single value

        Returns: self

        """
        self.count += count
        self.total += v

        if tag is not None:
            if not isinstance(tag, list):
                tag = [tag]
            for t in tag:
                self.get_submeter(t).update(v, count)
        return self

    def mean(self, tag=None, recursive=False):
        """ Computes the mean of the current values

        Returns: mean of the current values (float)

        """
        if recursive:
            try:
                ret = {self.name: self.total / self.count}
            except ZeroDivisionError:
                return {self.name: 0}
            for submeter in self.submeters:
                ret.update(self.get_submeter(submeter).mean(recursive=True))
            return ret
        if tag is not None:
            return self.get_submeter(tag).mean()
        else:
            return self.total / self.count

    def reset(self):
        """ Resets the meter.

        Returns: self

        """
        for submeter in self.submeters:
            self.submeters[submeter].reset()
        self.count = 0
        self.total = 0

        return self
--------------------------------------------------------------------------------
/src/tools/plot_episode.py:
--------------------------------------------------------------------------------
import pylab

def plot_episode(episode, classes_first=True):
    sample_set = episode["support_set"].cpu()
    query_set = episode["query_set"].cpu()
    support_size = episode["support_size"]
    query_size = episode["query_size"]
    if not classes_first:
        sample_set = sample_set.permute(1, 0, 2, 3, 4)
        query_set = query_set.permute(1, 0, 2, 3, 4)
    n, support_size, c, h, w = sample_set.size()
    n, query_size, c, h, w = query_set.size()
    sample_set = ((sample_set / 2 + 0.5) * 255).numpy().astype('uint8').transpose((0, 3, 1, 4, 2)).reshape((n * h, support_size * w, c))
    pylab.imsave('support_set.png', sample_set)
    query_set = ((query_set / 2 + 0.5) * 255).numpy().astype('uint8').transpose((0, 3, 1, 4, 2)).reshape((n * h, query_size * w, c))
    pylab.imsave('query_set.png', query_set)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
from torch.utils.data.dataloader import default_collate

def get_collate(name):
    if name == "identity":
        return lambda x: x
    else:
        return default_collate
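BasicMeter keeps its meters in a class-level registry, so any module can fetch the same meter by name, update it, and read the running mean. A short usage sketch:

meter = BasicMeter.get("val_accuracy").reset()
for acc in [0.8, 0.9, 0.7]:
    meter.update(acc, 1)  # v is a single value, so count=1
print(meter.mean())       # ~0.8
assert BasicMeter.get("val_accuracy") is meter  # same object from the shared registry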
--------------------------------------------------------------------------------
/trainval.py:
--------------------------------------------------------------------------------
import torch
import argparse
import pandas as pd
import sys
import os
from torch import nn
from torch.nn import functional as F
import tqdm
import pprint
from src import utils as ut
import torchvision
import numpy as np

from src import datasets, models
from src.models import backbones
from torch.utils.data import DataLoader
import exp_configs
from torch.utils.data.sampler import RandomSampler

from haven import haven_utils as hu
from haven import haven_results as hr
from haven import haven_chk as hc
from haven import haven_jupyter as hj


def trainval(exp_dict, savedir_base, datadir, reset=False,
             num_workers=0, pretrained_weights_dir=None):
    # bookkeeping
    # ---------------

    # get experiment directory
    exp_id = hu.hash_dict(exp_dict)
    savedir = os.path.join(savedir_base, exp_id)

    if reset:
        # delete and backup experiment
        hc.delete_experiment(savedir, backup_flag=True)

    # create folder and save the experiment dictionary
    os.makedirs(savedir, exist_ok=True)
    hu.save_json(os.path.join(savedir, 'exp_dict.json'), exp_dict)
    pprint.pprint(exp_dict)
    print('Experiment saved in %s' % savedir)

    # load datasets
    # ==========================
    train_set = datasets.get_dataset(dataset_name=exp_dict["dataset_train"],
                                     data_root=os.path.join(datadir, exp_dict["dataset_train_root"]),
                                     split="train",
                                     transform=exp_dict["transform_train"],
                                     classes=exp_dict["classes_train"],
                                     support_size=exp_dict["support_size_train"],
                                     query_size=exp_dict["query_size_train"],
                                     n_iters=exp_dict["train_iters"],
                                     unlabeled_size=exp_dict["unlabeled_size_train"])

    val_set = datasets.get_dataset(dataset_name=exp_dict["dataset_val"],
                                   data_root=os.path.join(datadir, exp_dict["dataset_val_root"]),
                                   split="val",
                                   transform=exp_dict["transform_val"],
                                   classes=exp_dict["classes_val"],
                                   support_size=exp_dict["support_size_val"],
                                   query_size=exp_dict["query_size_val"],
                                   n_iters=exp_dict.get("val_iters", None),
                                   unlabeled_size=exp_dict["unlabeled_size_val"])

    test_set = datasets.get_dataset(dataset_name=exp_dict["dataset_test"],
                                    data_root=os.path.join(datadir, exp_dict["dataset_test_root"]),
                                    split="test",
                                    transform=exp_dict["transform_val"],
                                    classes=exp_dict["classes_test"],
                                    support_size=exp_dict["support_size_test"],
                                    query_size=exp_dict["query_size_test"],
                                    n_iters=exp_dict["test_iters"],
                                    unlabeled_size=exp_dict["unlabeled_size_test"])

    # get dataloaders
    # ==========================
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=exp_dict["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        collate_fn=ut.get_collate(exp_dict["collate_fn"]),
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=lambda x: x,
        drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=lambda x: x,
        drop_last=True)

    # create model and trainer
    # ==========================

    # Create model, opt, wrapper
    backbone = backbones.get_backbone(backbone_name=exp_dict['model']["backbone"], exp_dict=exp_dict)
    model = models.get_model(model_name=exp_dict["model"]['name'], backbone=backbone,
                             n_classes=exp_dict["n_classes"],
                             exp_dict=exp_dict,
                             pretrained_weights_dir=pretrained_weights_dir,
                             savedir_base=savedir_base)

    # Pretrain or Fine-tune or run SSL
    if exp_dict["model"]['name'] == 'ssl':
        # runs the SSL experiments
        score_list_path = os.path.join(savedir, 'score_list.pkl')
        if not os.path.exists(score_list_path):
            test_dict = model.test_on_loader(test_loader, max_iter=None)
            hu.save_pkl(score_list_path, [test_dict])
        return

    # Checkpoint
    # -----------
    checkpoint_path = os.path.join(savedir, 'checkpoint.pth')
    score_list_path = os.path.join(savedir, 'score_list.pkl')

    if os.path.exists(score_list_path):
        # resume experiment
        model.load_state_dict(hu.torch_load(checkpoint_path))
        score_list = hu.load_pkl(score_list_path)
        s_epoch = score_list[-1]['epoch'] + 1
    else:
        # restart experiment
        score_list = []
        s_epoch = 0

    # Run training and validation
    for epoch in range(s_epoch, exp_dict["max_epoch"]):
        score_dict = {"epoch": epoch}
        score_dict.update(model.get_lr())

        # train
        score_dict.update(model.train_on_loader(train_loader))

        # validate
        score_dict.update(model.val_on_loader(val_loader))
        score_dict.update(model.test_on_loader(test_loader))

        # Add score_dict to score_list
        score_list += [score_dict]

        # Report
        score_df = pd.DataFrame(score_list)
        print(score_df.tail())

        # Save checkpoint
        hu.save_pkl(score_list_path, score_list)
        hu.torch_save(checkpoint_path, model.get_state_dict())
        print("Saved: %s" % savedir)

        if "accuracy" in exp_dict["target_loss"]:
            is_best = score_dict[exp_dict["target_loss"]] >= score_df[exp_dict["target_loss"]][:-1].max()
        else:
            is_best = score_dict[exp_dict["target_loss"]] <= score_df[exp_dict["target_loss"]][:-1].min()

        # Save best checkpoint
        if is_best:
            hu.save_pkl(os.path.join(savedir, "score_list_best.pkl"), score_list)
            hu.torch_save(os.path.join(savedir, "checkpoint_best.pth"), model.get_state_dict())
            print("Saved Best: %s" % savedir)

        # Check for end of training conditions
        if model.is_end_of_training():
            break
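trainval keys every run by hu.hash_dict(exp_dict), so launching the script twice with an identical config resolves to the same savedir and resumes from its checkpoint. The sketch below mimics that bookkeeping with hashlib; it is an illustration of the pattern, not haven's implementation, and its digest will differ from hu.hash_dict's:

import hashlib
import json
import os

def hash_dict_sketch(exp_dict):
    # deterministic id: hash the sorted-key JSON encoding of the config
    encoded = json.dumps(exp_dict, sort_keys=True).encode()
    return hashlib.md5(encoded).hexdigest()

exp_dict = {"model": {"name": "ssl", "backbone": "resnet12"}, "batch_size": 1}
savedir = os.path.join("/tmp/results", hash_dict_sketch(exp_dict))
resuming = os.path.exists(os.path.join(savedir, "score_list.pkl"))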
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-e', '--exp_group_list', nargs='+')
    parser.add_argument('-sb', '--savedir_base', required=True)
    parser.add_argument('-d', '--datadir', default='')
    parser.add_argument('-r', '--reset', default=0, type=int)
    parser.add_argument('-ei', '--exp_id', type=str, default=None)
    parser.add_argument('-j', '--run_jobs', type=int, default=0)
    parser.add_argument('-nw', '--num_workers', default=0, type=int)
    parser.add_argument('-p', '--pretrained_weights_dir', type=str, default=None)

    args = parser.parse_args()

    # Collect experiments
    # -------------------
    if args.exp_id is not None:
        # select one experiment
        savedir = os.path.join(args.savedir_base, args.exp_id)
        exp_dict = hu.load_json(os.path.join(savedir, 'exp_dict.json'))

        exp_list = [exp_dict]
    else:
        # select exp group
        exp_list = []
        for exp_group_name in args.exp_group_list:
            exp_list += exp_configs.EXP_GROUPS[exp_group_name]

    # Run experiments or view them
    # ----------------------------
    if args.run_jobs:
        pass
    else:
        # run experiments
        for exp_dict in exp_list:
            # do trainval
            trainval(exp_dict=exp_dict,
                     savedir_base=args.savedir_base,
                     datadir=args.datadir,
                     reset=args.reset,
                     num_workers=args.num_workers,
                     pretrained_weights_dir=args.pretrained_weights_dir)
--------------------------------------------------------------------------------
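The best-checkpoint rule in trainval compares the current epoch's score against the best of all previous epochs: score_df[...][:-1] drops the row that was just appended. A standalone illustration with made-up scores:

import pandas as pd

score_list = [{"epoch": 0, "val_accuracy": 0.71},
              {"epoch": 1, "val_accuracy": 0.74},
              {"epoch": 2, "val_accuracy": 0.73}]
score_df = pd.DataFrame(score_list)
target = "val_accuracy"
# current epoch (0.73) vs best earlier epoch (0.74) -> not a new best
is_best = score_list[-1][target] >= score_df[target][:-1].max()
print(is_best)  # False

Note that on the very first epoch the [:-1] slice is empty, .max() returns NaN, and the comparison evaluates to False, so no best checkpoint is written until the second epoch.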