├── .gitignore
├── LICENSE
├── NOTICE
├── README.md
├── embedding_prop.jpeg
├── embedding_propagation
│   ├── __init__.py
│   ├── batch_embedding_propagation.py
│   └── embedding_propagation.py
├── exp_configs
│   ├── __init__.py
│   ├── finetune_exps.py
│   ├── pretrain_exps.py
│   └── ssl_exps.py
├── requirements.txt
├── setup.py
├── src
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cub.py
│   │   ├── episodic_cub.py
│   │   ├── episodic_dataset.py
│   │   ├── episodic_miniimagenet.py
│   │   ├── episodic_tiered_imagenet.py
│   │   ├── miniimagenet.py
│   │   └── tiered_imagenet.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── conv4.py
│   │   │   ├── resnet12.py
│   │   │   └── wrn.py
│   │   ├── base_ssl
│   │   │   ├── __init__.py
│   │   │   ├── distances.py
│   │   │   ├── oracle.py
│   │   │   ├── predict_methods
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adaptive.py
│   │   │   │   ├── label_prop.py
│   │   │   │   └── prototypical.py
│   │   │   ├── selection_methods
│   │   │   │   ├── __init__.py
│   │   │   │   └── ssl.py
│   │   │   └── utils.py
│   │   ├── base_wrapper.py
│   │   ├── finetuning.py
│   │   ├── pretraining.py
│   │   └── ssl_wrapper.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── distances.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── meters.py
│   │   └── plot_episode.py
│   └── utils.py
└── trainval.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | results/
6 | user_configs.py
7 | user_config.py
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
108 | *.vscode
109 | *.npy
110 | *.png
111 | *.jpg
112 | *.svg
113 | *.eps
114 | *.pdf
115 |
116 | data/
117 |
118 |
119 | *_eai.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2021 ServiceNow
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2021 ServiceNow, Inc.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | *ServiceNow completed its acquisition of Element AI on January 8, 2021. All references to Element AI in the materials that are part of this project should refer to ServiceNow.*
2 |
3 |
4 | [License](LICENSE)
5 |
6 |
7 | # Embedding Propagation: Smoother Manifold for Few-Shot Classification
8 | [Paper](https://arxiv.org/abs/2003.04151) (ECCV 2020)
9 |
10 |
11 |
12 | Embedding propagation regularizes the intermediate features of a network, which improves generalization performance in few-shot classification.
13 |
14 | ![Embedding propagation](embedding_prop.jpeg)
15 |
16 | ## Usage
17 |
18 | Add an embedding propagation layer to your network.
19 |
20 | ```
21 | pip install git+https://github.com/ElementAI/embedding-propagation
22 | ```
23 |
24 | ```python
25 | import torch
26 | from embedding_propagation import EmbeddingPropagation
27 |
28 | ep = EmbeddingPropagation()
29 | features = torch.randn(32, 32)
30 | embeddings = ep(features)
31 | ```
32 |
33 | ## Experiments
34 |
35 | Reproduce the results from the [paper](https://arxiv.org/abs/2003.04151).
36 |
37 | ### Install requirements
38 |
39 | `pip install -r requirements.txt`
40 |
41 | This command installs the [Haven library](https://github.com/haven-ai/haven-ai) which helps in managing the experiments.
42 |
43 | ### Download the Datasets
44 |
45 | * [mini-imagenet](https://github.com/renmengye/few-shot-ssl-public#miniimagenet) ([pre-processing](https://github.com/ElementAI/TADAM/tree/master/datasets))
46 | * [tiered-imagenet](https://github.com/renmengye/few-shot-ssl-public#tieredimagenet)
47 | * [CUB](https://github.com/wyharveychen/CloserLookFewShot/tree/master/filelists/CUB)
48 |
49 | If you have the `pkl` version of miniimagenet, you can still use it by setting the dataset name to "episodic_miniimagenet_pkl" in each of the files in `exp_configs`.
50 |
51 |
52 |
53 | ### Reproduce the results in the paper
54 |
55 | #### 1. Pre-training
56 |
57 | ```
58 | python3 trainval.py -e pretrain -sb ./logs/pretraining -d <datadir>
59 | ```
60 | where `<datadir>` is the directory where the data is saved.
61 |
62 | #### 2. Fine-tuning
63 |
64 | In `exp_configs/finetune_exps.py`, set `"pretrained_weights_root": "./logs/pretraining"`
65 |
66 | ```
67 | python3 trainval.py -e finetune -sb ./logs/finetuning -d <datadir>
68 | ```
69 |
70 | #### 3. SSL experiments with 100 unlabeled samples
71 |
72 | In `exp_configs/ssl_exps.py`, set `"pretrained_weights_root": "./logs/finetuning"`
73 |
74 | ```
75 | python3 trainval.py -e ssl_large -sb ./logs/ssl/ -d <datadir>
76 | ```
77 |
78 | #### 4. SSL experiments with 20-100% unlabeled samples
79 |
80 | In `exp_configs/ssl_exps.py`, set `"pretrained_weights_root": "./logs/finetuning"`
81 |
82 | ```
83 | python3 trainval.py -e ssl_small -sb ./logs/ssl/ -d <datadir>
84 | ```
85 |
86 | ### Results
87 |
88 | |dataset|model|1-shot|5-shot|
89 | |-------|-----|------|------|
90 | |episodic_cub|conv4|65.94 ± 0.93|78.80 ± 0.64|
91 | |episodic_cub|resnet12|81.32 ± 0.84|91.02 ± 0.44|
92 | |episodic_cub|wrn|87.48 ± 0.68|93.74 ± 0.35|
93 | |episodic_miniimagenet|conv4|57.41 ± 0.85|72.35 ± 0.62|
94 | |episodic_miniimagenet|resnet12|64.82 ± 0.89|80.59 ± 0.64|
95 | |episodic_miniimagenet|wrn|69.92 ± 0.81|83.64 ± 0.54|
96 | |episodic_tiered-imagenet|conv4|58.63 ± 0.92|72.80 ± 0.78|
97 | |episodic_tiered-imagenet|resnet12|75.90 ± 0.90|86.83 ± 0.58|
98 | |episodic_tiered-imagenet|wrn|78.46 ± 0.90|87.46 ± 0.62|
99 |
100 | Unlike in the paper, these results were obtained from a single run with fixed hyperparameters during fine-tuning: lr=0.001, alpha=0.2 (now the default), train_iters=600, classification_weight=0.1.
101 |
102 | ### Pre-trained weights
103 | https://zenodo.org/record/5552602#.YV2b-UbMKvU
104 |
105 | ## Citation
106 | ```
107 | @article{rodriguez2020embedding,
108 | title={Embedding Propagation: Smoother Manifold for Few-Shot Classification},
109 | author={Pau Rodríguez and Issam Laradji and Alexandre Drouin and Alexandre Lacoste},
110 | year={2020},
111 | journal={arXiv preprint arXiv:2003.04151},
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
/embedding_prop.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/embedding_prop.jpeg
--------------------------------------------------------------------------------
/embedding_propagation/__init__.py:
--------------------------------------------------------------------------------
1 | from .embedding_propagation import *
2 | from .batch_embedding_propagation import *
3 |
--------------------------------------------------------------------------------
/embedding_propagation/batch_embedding_propagation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 |
6 | class BatchEmbeddingPropagation(torch.nn.Module):
7 | def __init__(self, alpha=0.5, rbf_scale=1, norm_prop=False):
8 | super().__init__()
9 | self.alpha = alpha
10 | self.rbf_scale = rbf_scale
11 | self.norm_prop = norm_prop
12 |
13 | def forward(self, x, propagator=None):
14 | return batch_embedding_propagation(x, self.alpha, self.rbf_scale, self.norm_prop, propagator=propagator)
15 |
16 |
17 | class BatchLabelPropagation(torch.nn.Module):
18 | def __init__(self, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True, balanced=False):
19 | super().__init__()
20 | self.alpha = alpha
21 | self.rbf_scale = rbf_scale
22 | self.norm_prop = norm_prop
23 | self.apply_log = apply_log
24 | self.balanced = balanced
25 |
26 | def forward(self, x, labels, nclasses, propagator=None):
27 | """Applies label propagation given a set of embeddings and labels
28 |
29 | Arguments:
30 | x {Tensor} -- Input embeddings
31 |             labels {Tensor} -- Input labels in the range [0, nclasses]; the value nclasses marks unlabeled samples.
32 | nclasses {int} -- Total number of classes
33 |
34 | Keyword Arguments:
35 | propagator {Tensor} -- A pre-computed propagator (default: {None})
36 |
37 | Returns:
38 |             Tensor -- Logits of shape (b, n, nclasses)
39 | """
40 | return batch_label_propagation(x, labels, nclasses, self.alpha, self.rbf_scale,
41 | self.norm_prop, self.apply_log, propagator=propagator,
42 | balanced=self.balanced)
43 |
44 | def batch_get_similarity_matrix(x, rbf_scale):
45 | b, e, c = x.size()
46 | sq_dist = ((x.view(b, e, 1, c) - x.view(b, 1, e, c))**2).sum(-1) / np.sqrt(c)
47 | mask = sq_dist != 0
48 | sq_dist = sq_dist / sq_dist[mask].std()
49 | weights = torch.exp(-sq_dist * rbf_scale)
50 | mask = torch.eye(weights.size(-1), dtype=torch.bool, device=weights.device)
51 | weights = weights * (~mask).float()
52 | return weights
53 |
54 |
55 | def batch_embedding_propagation(x, alpha, rbf_scale, norm_prop, propagator=None):
56 | if propagator is None:
57 | weights = batch_get_similarity_matrix(x, rbf_scale)
58 | propagator = batch_global_consistency(
59 | weights, alpha=alpha, norm_prop=norm_prop)
60 | return torch.bmm(propagator, x)
61 |
62 |
63 | def batch_global_consistency(weights, alpha=1, norm_prop=False):
64 |     """Implements D. Zhou et al., "Learning with local and global consistency" (as in the TPN paper, but without the bug).
65 |
66 |     Args:
67 |         weights: Tensor of shape (b, n, n). Expected to be exp(-d^2/s^2), where d is the
68 |             euclidean distance and s the scale parameter.
69 |         alpha: Scalar, acts as a smoothing factor.
70 |     Returns:
71 |         Tensor of shape (b, n, n): the batched propagator (I - alpha*S)^-1, with S = D^-1/2 W D^-1/2.
72 |     """
73 | n = weights.shape[-1]
74 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device)
75 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=-1))
76 | # checknan(laplacian=isqrt_diag)
77 |     S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, :, None]  # D^-1/2 W D^-1/2, batched
78 | # checknan(normalizedlaplacian=S)
79 |
80 | propagator = identity[None] - alpha * S
81 | propagator = torch.inverse(propagator)
82 | # checknan(propagator=propagator)
83 | if norm_prop:
84 | propagator = F.normalize(propagator, p=1, dim=-1)
85 | return propagator
86 |
87 |
88 | def batch_label_propagation(x, labels, nclasses, alpha, rbf_scale, norm_prop, apply_log, propagator=None, balanced=False, epsilon=1e-6):
89 | labels = F.one_hot(labels, nclasses + 1)
90 | labels = labels[..., :nclasses].float() # the max label is unlabeled
91 | if balanced:
92 | labels = labels / labels.sum(-1, keepdim=True)
93 | if propagator is None:
94 | weights = batch_get_similarity_matrix(x, rbf_scale)
95 | propagator = batch_global_consistency(
96 | weights, alpha=alpha, norm_prop=norm_prop)
97 | y_pred = torch.bmm(propagator, labels)
98 | if apply_log:
99 | y_pred = torch.log(y_pred + epsilon)
100 |
101 | return y_pred
102 |
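103 |
104 | if __name__ == '__main__':
105 |     # Minimal usage sketch (illustrative only, not part of the published
106 |     # experiments): propagate a batch of 2 episodes with 20 embeddings of
107 |     # dimension 16 each, then label-propagate over 5 classes, marking the
108 |     # last 5 samples of each episode as unlabeled.
109 |     b, n, c, nclasses = 2, 20, 16, 5
110 |     x = torch.randn(b, n, c)
111 |     ep = BatchEmbeddingPropagation(alpha=0.5)
112 |     z = ep(x)
113 |     print(z.shape)  # torch.Size([2, 20, 16])
114 |     labels = torch.randint(0, nclasses, (b, n))
115 |     labels[:, -5:] = nclasses  # the value nclasses marks unlabeled samples
116 |     lp = BatchLabelPropagation(alpha=0.2)
117 |     logits = lp(z, labels, nclasses)
118 |     print(logits.shape)  # torch.Size([2, 20, 5])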
--------------------------------------------------------------------------------
/embedding_propagation/embedding_propagation.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | class EmbeddingPropagation(torch.nn.Module):
8 | def __init__(self, alpha=0.5, rbf_scale=1, norm_prop=False):
9 | super().__init__()
10 | self.alpha = alpha
11 | self.rbf_scale = rbf_scale
12 | self.norm_prop = norm_prop
13 |
14 | def forward(self, x, propagator=None):
15 | return embedding_propagation(x, self.alpha, self.rbf_scale, self.norm_prop, propagator=propagator)
16 |
17 |
18 | class LabelPropagation(torch.nn.Module):
19 | def __init__(self, alpha=0.2, rbf_scale=1, norm_prop=True, apply_log=True, balanced=False):
20 | super().__init__()
21 | self.alpha = alpha
22 | self.rbf_scale = rbf_scale
23 | self.norm_prop = norm_prop
24 | self.apply_log = apply_log
25 | self.balanced = balanced
26 |
27 | def forward(self, x, labels, nclasses, propagator=None):
28 | """Applies label propagation given a set of embeddings and labels
29 |
30 | Arguments:
31 | x {Tensor} -- Input embeddings
32 |             labels {Tensor} -- Input labels in the range [0, nclasses]; the value nclasses marks unlabeled samples.
33 | nclasses {int} -- Total number of classes
34 |
35 | Keyword Arguments:
36 | propagator {Tensor} -- A pre-computed propagator (default: {None})
37 |
38 | Returns:
39 |             Tensor -- Logits of shape (n, nclasses)
40 | """
41 | return label_propagation(x, labels, nclasses, self.alpha, self.rbf_scale,
42 | self.norm_prop, self.apply_log, propagator=propagator,
43 | balanced=self.balanced)
44 |
45 |
46 | def get_similarity_matrix(x, rbf_scale):
47 | b, c = x.size()
48 | sq_dist = ((x.view(b, 1, c) - x.view(1, b, c))**2).sum(-1) / np.sqrt(c)
49 | mask = sq_dist != 0
50 | sq_dist = sq_dist / sq_dist[mask].std()
51 | weights = torch.exp(-sq_dist * rbf_scale)
52 | mask = torch.eye(weights.size(1), dtype=torch.bool, device=weights.device)
53 | weights = weights * (~mask).float()
54 | return weights
55 |
56 |
57 | def embedding_propagation(x, alpha, rbf_scale, norm_prop, propagator=None):
58 | if propagator is None:
59 | weights = get_similarity_matrix(x, rbf_scale)
60 | propagator = global_consistency(
61 | weights, alpha=alpha, norm_prop=norm_prop)
62 | return torch.mm(propagator, x)
63 |
64 |
65 | def label_propagation(x, labels, nclasses, alpha, rbf_scale, norm_prop, apply_log, propagator=None, balanced=False, epsilon=1e-6):
66 | labels = F.one_hot(labels, nclasses + 1)
67 | labels = labels[:, :nclasses].float() # the max label is unlabeled
68 | if balanced:
69 | labels = labels / labels.sum(0, keepdim=True)
70 | if propagator is None:
71 | weights = get_similarity_matrix(x, rbf_scale)
72 | propagator = global_consistency(
73 | weights, alpha=alpha, norm_prop=norm_prop)
74 | y_pred = torch.mm(propagator, labels)
75 | if apply_log:
76 | y_pred = torch.log(y_pred + epsilon)
77 |
78 | return y_pred
79 |
80 |
81 | def global_consistency(weights, alpha=1, norm_prop=False):
82 |     """Implements D. Zhou et al., "Learning with local and global consistency" (as in the TPN paper, but without the bug).
83 |
84 |     Args:
85 |         weights: Tensor of shape (n, n). Expected to be exp(-d^2/s^2), where d is the
86 |             euclidean distance and s the scale parameter.
87 |         alpha: Scalar, acts as a smoothing factor.
88 |
89 |     Returns:
90 |         Tensor of shape (n, n): the propagator matrix (I - alpha*S)^-1, with S = D^-1/2 W D^-1/2.
91 |     """
92 | n = weights.shape[1]
93 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device)
94 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=-1))
95 | # checknan(laplacian=isqrt_diag)
96 | S = weights * isqrt_diag[None, :] * isqrt_diag[:, None]
97 | # checknan(normalizedlaplacian=S)
98 | propagator = identity - alpha * S
99 | propagator = torch.inverse(propagator[None, ...])[0]
100 | # checknan(propagator=propagator)
101 | if norm_prop:
102 | propagator = F.normalize(propagator, p=1, dim=-1)
103 | return propagator
104 |
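105 |
106 | if __name__ == '__main__':
107 |     # Minimal usage sketch (illustrative only, not part of the published
108 |     # experiments): smooth 20 embeddings of dimension 16 with embedding
109 |     # propagation, then classify them with label propagation over 5 classes,
110 |     # marking the last 5 samples as unlabeled.
111 |     n, c, nclasses = 20, 16, 5
112 |     x = torch.randn(n, c)
113 |     ep = EmbeddingPropagation(alpha=0.5)
114 |     z = ep(x)
115 |     print(z.shape)  # torch.Size([20, 16])
116 |     labels = torch.randint(0, nclasses, (n,))
117 |     labels[-5:] = nclasses  # the value nclasses marks unlabeled samples
118 |     lp = LabelPropagation(alpha=0.2)
119 |     logits = lp(z, labels, nclasses)
120 |     print(logits.shape)  # torch.Size([20, 5])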
--------------------------------------------------------------------------------
/exp_configs/__init__.py:
--------------------------------------------------------------------------------
1 | from . import pretrain_exps
2 | from . import ssl_exps
3 | from . import finetune_exps
4 |
5 | EXP_GROUPS = {}
6 | EXP_GROUPS.update(pretrain_exps.EXP_GROUPS)
7 | EXP_GROUPS.update(ssl_exps.EXP_GROUPS)
8 | EXP_GROUPS.update(finetune_exps.EXP_GROUPS)
9 |
10 |
11 |
12 |
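13 | # Usage sketch (hypothetical, for illustration): trainval.py selects one of
14 | # these groups by name, e.g.
15 | #
16 | #   from exp_configs import EXP_GROUPS
17 | #   for exp_dict in EXP_GROUPS["pretrain"]:
18 | #       print(exp_dict["model"]["backbone"], exp_dict["lr"])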
--------------------------------------------------------------------------------
/exp_configs/finetune_exps.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | conv4 = {
4 | "name": "finetuning",
5 | 'backbone':'conv4',
6 | "depth": 4,
7 | "width": 1,
8 | "transform_train": "basic",
9 | "transform_val": "basic",
10 | "transform_test": "basic"
11 | }
12 |
13 | wrn = {
14 | "name": "finetuning",
15 | "backbone": 'wrn',
16 | "depth": 28,
17 | "width": 10,
18 | "transform_train": "wrn_finetune_train",
19 | "transform_val": "wrn_val",
20 | "transform_test": "wrn_val"
21 | }
22 |
23 | resnet12 = {
24 | "name": "finetuning",
25 | "backbone": 'resnet12',
26 | "depth": 12,
27 | "width": 1,
28 | "transform_train": "basic",
29 | "transform_val": "basic",
30 | "transform_test": "basic"
31 | }
32 |
33 | miniimagenet = {
34 | "dataset": "miniimagenet",
35 | "dataset_train": "episodic_miniimagenet",
36 | "dataset_val": "episodic_miniimagenet",
37 | "dataset_test": "episodic_miniimagenet",
38 | "data_root": "mini-imagenet",
39 | "n_classes": 64
40 | }
41 |
42 | tiered_imagenet = {
43 | "dataset": "tiered-imagenet",
44 | "n_classes": 351,
45 | "dataset_train": "episodic_tiered-imagenet",
46 | "dataset_val": "episodic_tiered-imagenet",
47 | "dataset_test": "episodic_tiered-imagenet",
48 | "data_root": "tiered-imagenet",
49 | }
50 |
51 | cub = {
52 | "dataset": "cub",
53 | "n_classes": 100,
54 | "dataset_train": "episodic_cub",
55 | "dataset_val": "episodic_cub",
56 | "dataset_test": "episodic_cub",
57 | "data_root": "CUB_200_2011"
58 | }
59 |
60 | EXP_GROUPS = {"finetune": []}
61 |
62 | for dataset in [miniimagenet, tiered_imagenet, cub]:
63 | for backbone in [conv4, resnet12, wrn]:
64 | for lr in [0.01, 0.001]:
65 | for shot in [1, 5]:
66 | for train_iters in [100, 600]:
67 | for classification_weight in [0, 0.1, 0.5]:
68 | EXP_GROUPS['finetune'] += [{"model": backbone,
69 |
70 | # Hardware
71 | "ngpu": 2,
72 | "random_seed": 42,
73 |
74 | # Optimization
75 | "batch_size": 1,
76 | "target_loss": "val_accuracy",
77 | "lr": lr,
78 | "min_lr_decay": 0.0001,
79 | "weight_decay": 0.0005,
80 | "patience": 10,
81 | "max_epoch": 200,
82 | "train_iters": train_iters,
83 | "val_iters": 600,
84 | "test_iters": 1000,
85 | "tasks_per_batch": 1,
86 | "pretrained_weights_root": "./logs/pretraining",
87 |
88 | # Model
89 | "dropout": 0.1,
90 | "avgpool": True,
91 |
92 | # Data
93 | 'n_classes': dataset["n_classes"],
94 | "collate_fn": "identity",
95 | "transform_train": backbone["transform_train"],
96 | "transform_val": backbone["transform_val"],
97 | "transform_test": backbone["transform_test"],
98 |
99 | "dataset_train": dataset["dataset_train"],
100 | 'dataset_train_root': dataset["data_root"],
101 | "classes_train": 5,
102 | "support_size_train": shot,
103 | "query_size_train": 15,
104 | "unlabeled_size_train": 0,
105 |
106 | "dataset_val": dataset["dataset_val"],
107 | 'dataset_val_root': dataset["data_root"],
108 | "classes_val": 5,
109 | "support_size_val": shot,
110 | "query_size_val": 15,
111 | "unlabeled_size_val": 0,
112 |
113 | "dataset_test": dataset["dataset_test"],
114 | 'dataset_test_root': dataset["data_root"],
115 | "classes_test": 5,
116 | "support_size_test": shot,
117 | "query_size_test": 15,
118 | "unlabeled_size_test": 0,
119 |
120 | # Hparams
121 | "embedding_prop" : True,
122 | "few_shot_weight": 1,
123 | "classification_weight": classification_weight,
124 | "rotation_weight": 0,
125 | "active_size": 0,
126 | "distance_type": "labelprop",
127 | "rotation_labels": [0],
128 | }]
129 |
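130 | # In total, the grid above yields 216 configs:
131 | # 3 datasets x 3 backbones x 2 lrs x 2 shots x 2 train_iters x 3 classification_weights.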
--------------------------------------------------------------------------------
/exp_configs/pretrain_exps.py:
--------------------------------------------------------------------------------
1 | from haven import haven_utils as hu
2 |
3 | conv4 = {
4 | "name": "pretraining",
5 | 'backbone':'conv4',
6 | "depth": 4,
7 | "width": 1,
8 | "transform_train": "basic",
9 | "transform_val": "basic",
10 | "transform_test": "basic"
11 | }
12 |
13 | wrn = {
14 | "name": "pretraining",
15 | "backbone": 'wrn',
16 | "depth": 28,
17 | "width": 10,
18 | "transform_train": "wrn_pretrain_train",
19 | "transform_val": "wrn_val",
20 | "transform_test": "wrn_val"
21 | }
22 |
23 | resnet12 = {
24 | "name": "pretraining",
25 | "backbone": 'resnet12',
26 | "depth": 12,
27 | "width": 1,
28 | "transform_train": "basic",
29 | "transform_val": "basic",
30 | "transform_test": "basic"
31 | }
32 |
33 | miniimagenet = {
34 | "dataset": "miniimagenet",
35 | "dataset_train": "rotated_episodic_miniimagenet_pkl",
36 | "dataset_val": "episodic_miniimagenet_pkl",
37 | "dataset_test": "episodic_miniimagenet_pkl",
38 | "n_classes": 64,
39 | "data_root": "mini-imagenet"
40 | }
41 |
42 | tiered_imagenet = {
43 | "dataset": "tiered-imagenet",
44 | "n_classes": 351,
45 | "dataset_train": "rotated_tiered-imagenet",
46 | "dataset_val": "episodic_tiered-imagenet",
47 | "dataset_test": "episodic_tiered-imagenet",
48 | "data_root": "tiered-imagenet",
49 | }
50 |
51 | cub = {
52 | "dataset": "cub",
53 | "n_classes": 100,
54 | "dataset_train": "rotated_cub",
55 | "dataset_val": "episodic_cub",
56 | "dataset_test": "episodic_cub",
57 | "data_root": "CUB_200_2011"
58 | }
59 |
60 | EXP_GROUPS = {"pretrain": []}
61 |
62 | for dataset in [miniimagenet, tiered_imagenet, cub]:
63 | for backbone in [conv4, resnet12, wrn]:
64 | for lr in [0.2, 0.1]:
65 | EXP_GROUPS['pretrain'] += [{"model": backbone,
66 |
67 | # Hardware
68 | "ngpu": 4,
69 | "random_seed": 42,
70 |
71 | # Optimization
72 | "batch_size": 128,
73 | "target_loss": "val_accuracy",
74 | "lr": lr,
75 | "min_lr_decay": 0.0001,
76 | "weight_decay": 0.0005,
77 | "patience": 10,
78 | "max_epoch": 200,
79 | "train_iters": 600,
80 | "val_iters": 600,
81 | "test_iters": 600,
82 | "tasks_per_batch": 1,
83 |
84 | # Model
85 | "dropout": 0.1,
86 | "avgpool": True,
87 |
88 | # Data
89 | 'n_classes': dataset["n_classes"],
90 | "collate_fn": "default",
91 | "transform_train": backbone["transform_train"],
92 | "transform_val": backbone["transform_val"],
93 | "transform_test": backbone["transform_test"],
94 |
95 | "dataset_train": dataset["dataset_train"],
96 | "dataset_train_root": dataset["data_root"],
97 | "classes_train": 5,
98 | "support_size_train": 5,
99 | "query_size_train": 15,
100 | "unlabeled_size_train": 0,
101 |
102 | "dataset_val": dataset["dataset_val"],
103 | "dataset_val_root": dataset["data_root"],
104 | "classes_val": 5,
105 | "support_size_val": 5,
106 | "query_size_val": 15,
107 | "unlabeled_size_val": 0,
108 |
109 | "dataset_test": dataset["dataset_test"],
110 | "dataset_test_root": dataset["data_root"],
111 | "classes_test": 5,
112 | "support_size_test": 5,
113 | "query_size_test": 15,
114 | "unlabeled_size_test": 0,
115 |
116 |
117 | # Hparams
118 | "embedding_prop": True,
119 | "cross_entropy_weight": 1,
120 | "few_shot_weight": 0,
121 | "rotation_weight": 1,
122 | "active_size": 0,
123 | "distance_type": "labelprop",
124 | "kernel_bound": "",
125 | "rotation_labels": [0, 1, 2, 3]
126 | }]
127 |
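128 | # In total, the grid above yields 18 configs: 3 datasets x 3 backbones x 2 learning rates.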
--------------------------------------------------------------------------------
/exp_configs/ssl_exps.py:
--------------------------------------------------------------------------------
1 | import os
2 | from haven import haven_utils as hu
3 |
4 |
5 |
6 |
7 | conv4 = {
8 | "name": "ssl",
9 | 'backbone':'conv4',
10 | "depth": 4,
11 | "width": 1,
12 | "transform_train": "basic",
13 | "transform_val": "basic",
14 | "transform_test": "basic"
15 | }
16 |
17 | wrn = {
18 | "name": "ssl",
19 | "backbone": 'wrn',
20 | "depth": 28,
21 | "width": 10,
22 | "transform_train": "wrn_finetune_train",
23 | "transform_val": "wrn_val",
24 | "transform_test": "wrn_val"
25 | }
26 |
27 | resnet12 = {
28 | "name": "ssl",
29 | "backbone": 'resnet12',
30 | "depth": 12,
31 | "width": 1,
32 | "transform_train": "basic",
33 | "transform_val": "basic",
34 | "transform_test": "basic"
35 | }
36 |
37 | miniimagenet = {
38 | "dataset": "miniimagenet",
39 | "dataset_train": "episodic_miniimagenet",
40 | "dataset_val": "episodic_miniimagenet",
41 | "dataset_test": "episodic_miniimagenet",
42 | "n_classes": 64,
43 | 'data_root':'mini-imagenet/'
44 | }
45 |
46 | tiered_imagenet = {
47 | "dataset": "tiered-imagenet",
48 | "n_classes": 351,
49 | "dataset_train": "episodic_tiered-imagenet",
50 | "dataset_val": "episodic_tiered-imagenet",
51 | "dataset_test": "episodic_tiered-imagenet",
52 | 'data_root':'tiered-imagenet'
53 | }
54 |
55 | cub = {
56 | "dataset": "cub",
57 | "n_classes": 100,
58 | "dataset_train": "episodic_cub",
59 | "dataset_val": "episodic_cub",
60 | "dataset_test": "episodic_cub",
61 | 'data_root':'CUB_200_2011'
62 | }
63 |
64 | EXP_GROUPS = {}
65 | EXP_GROUPS['ssl_large'] = []
66 | # 12 exps
67 | for dataset in [miniimagenet, tiered_imagenet]:
68 | for backbone in [resnet12, conv4, wrn]:
69 | for embedding_prop in [True]:
70 | for shot in [1, 5]:
71 | EXP_GROUPS['ssl_large'] += [{
72 | 'dataset_train_root': dataset["data_root"],
73 | 'dataset_val_root': dataset["data_root"],
74 | 'dataset_test_root': dataset["data_root"],
75 | "model": backbone,
76 |
77 | # Hardware
78 | "ngpu": 1,
79 | "random_seed": 42,
80 |
81 | # Optimization
82 | "batch_size": 1,
83 | "train_iters": 10,
84 | "test_iters": 600,
85 | "tasks_per_batch": 1,
86 |
87 | # Model
88 | "dropout": 0.1,
89 | "avgpool": True,
90 |
91 | # Data
92 | 'n_classes': dataset["n_classes"],
93 | "collate_fn": "identity",
94 | "transform_train": backbone["transform_train"],
95 | "transform_val": backbone["transform_val"],
96 | "transform_test": backbone["transform_test"],
97 |
98 | "dataset_train": dataset["dataset_train"],
99 | "classes_train": 5,
100 | "support_size_train": shot,
101 | "query_size_train": 15,
102 | "unlabeled_size_train": 0,
103 |
104 | "dataset_val": dataset["dataset_val"],
105 | "classes_val": 5,
106 | "support_size_val": shot,
107 | "query_size_val": 15,
108 | "unlabeled_size_val": 0,
109 |
110 | "dataset_test": dataset["dataset_test"],
111 | "classes_test": 5,
112 | "support_size_test": shot,
113 | "query_size_test": 15,
114 | "unlabeled_size_test": 100,
115 | "predict_method": "labelprop",
116 | "finetuned_weights_root": "./logs/finetuning",
117 |
118 | # Hparams
119 | "embedding_prop" : embedding_prop,
120 | }]
121 |
122 | # 18 exps: 2 datasets x 3 backbones x 3 (shot, unlabeled-size) settings
123 | EXP_GROUPS['ssl_small'] = []
124 | for dataset in [tiered_imagenet, miniimagenet]:
125 | for backbone in [conv4, resnet12, wrn]:
126 | for embedding_prop in [True]:
127 | for shot, ust in zip([1,2,3],
128 | [4,3,2]):
129 | EXP_GROUPS['ssl_small'] += [{
130 | 'dataset_train_root': dataset["data_root"],
131 | 'dataset_val_root': dataset["data_root"],
132 | 'dataset_test_root': dataset["data_root"],
133 |
134 | "model": backbone,
135 |
136 | # Hardware
137 | "ngpu": 1,
138 | "random_seed": 42,
139 |
140 | # Optimization
141 | "batch_size": 1,
142 | "train_iters": 10,
143 | "val_iters": 600,
144 | "test_iters": 600,
145 | "tasks_per_batch": 1,
146 |
147 | # Model
148 | "dropout": 0.1,
149 | "avgpool": True,
150 |
151 | # Data
152 | 'n_classes': dataset["n_classes"],
153 | "collate_fn": "identity",
154 | "transform_train": backbone["transform_train"],
155 | "transform_val": backbone["transform_val"],
156 | "transform_test": backbone["transform_test"],
157 |
158 | "dataset_train": dataset["dataset_train"],
159 | "classes_train": 5,
160 | "support_size_train": shot,
161 | "query_size_train": 15,
162 | "unlabeled_size_train": 0,
163 |
164 | "dataset_val": dataset["dataset_val"],
165 | "classes_val": 5,
166 | "support_size_val": shot,
167 | "query_size_val": 15,
168 | "unlabeled_size_val": 0,
169 |
170 | "dataset_test": dataset["dataset_test"],
171 | "classes_test": 5,
172 | "support_size_test": shot,
173 | "query_size_test": 15,
174 | "unlabeled_size_test": ust,
175 | "predict_method":"labelprop",
176 | "finetuned_weights_root": "./logs/finetuning",
177 |
178 | # Hparams
179 | "embedding_prop" : embedding_prop,
180 | }]
181 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | haven-ai>=0.0
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='embedding_propagation',
4 | version='0.6.0',
5 | description='Manifold Regularization',
6 | url='https://github.com/ElementAI/embedding-propagation',
7 | maintainer='Pau Rodríguez',
8 | maintainer_email='issam.laradji@elementai.com',
9 |       license='Apache 2.0',
10 | packages=['embedding_propagation'],
11 | zip_safe=False,
12 | install_requires=[
13 | 'tqdm>=0.0',
14 | 'matplotlib>=0.0',
15 | 'numpy>=0.0',
16 | 'pandas>=0.0',
17 | 'Pillow>=0.0',
18 | 'scikit-image>=0.0',
19 | 'scikit-learn>=0.0',
20 |         'torch>=0.0',
21 |         'torchvision>=0.0'
22 |     ]
23 | )
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/__init__.py
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # from . import trancos, fish_reg
3 | from torchvision import transforms
4 | import torchvision
5 | import cv2, os
6 | from src import utils as ut
7 | import pandas as pd
8 | import numpy as np
9 |
10 | from torchvision.datasets import CIFAR10, CIFAR100
11 | from .episodic_dataset import FewShotSampler
12 | from .episodic_miniimagenet import EpisodicMiniImagenet, EpisodicMiniImagenetPkl
13 | from .miniimagenet import NonEpisodicMiniImagenet, RotatedNonEpisodicMiniImagenet, RotatedNonEpisodicMiniImagenetPkl
14 | from .episodic_tiered_imagenet import EpisodicTieredImagenet
15 | from .tiered_imagenet import RotatedNonEpisodicTieredImagenet, NonEpisodicTieredImagenet
16 | from .cub import RotatedNonEpisodicCUB, NonEpisodicCUB
17 | from .episodic_cub import EpisodicCUB
18 |
19 |
20 |
21 |
22 |
23 |
24 | def get_dataset(dataset_name,
25 | data_root,
26 | split,
27 | transform,
28 | classes,
29 | support_size,
30 | query_size,
31 | unlabeled_size,
32 | n_iters):
33 |
34 | transform_func = get_transformer(transform, split)
35 | if dataset_name == "rotated_miniimagenet":
36 | dataset = RotatedNonEpisodicMiniImagenet(data_root,
37 | split,
38 | transform_func)
39 |
40 | elif dataset_name == "miniimagenet":
41 | dataset = NonEpisodicMiniImagenet(data_root,
42 | split,
43 | transform_func)
44 | elif dataset_name == "episodic_miniimagenet":
45 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size)
46 | dataset = EpisodicMiniImagenet(data_root=data_root,
47 | split=split,
48 | sampler=few_shot_sampler,
49 | size=n_iters,
50 | transforms=transform_func)
51 | elif dataset_name == "rotated_episodic_miniimagenet_pkl":
52 | dataset = RotatedNonEpisodicMiniImagenetPkl(data_root=data_root,
53 | split=split,
54 | transforms=transform_func)
55 | elif dataset_name == "episodic_miniimagenet_pkl":
56 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size)
57 | dataset = EpisodicMiniImagenetPkl(data_root=data_root,
58 | split=split,
59 | sampler=few_shot_sampler,
60 | size=n_iters,
61 | transforms=transform_func)
62 | elif dataset_name == "cub":
63 | dataset = NonEpisodicCUB(data_root, split, transform_func)
64 | elif dataset_name == "rotated_cub":
65 | dataset = RotatedNonEpisodicCUB(data_root, split, transform_func)
66 | elif dataset_name == "episodic_cub":
67 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size)
68 | dataset = EpisodicCUB(data_root=data_root,
69 | split=split,
70 | sampler=few_shot_sampler,
71 | size=n_iters,
72 | transforms=transform_func)
73 | elif dataset_name == "tiered-imagenet":
74 | dataset = NonEpisodicTieredImagenet(data_root, split, transform_func)
75 | elif dataset_name == "rotated_tiered-imagenet":
76 | dataset = RotatedNonEpisodicTieredImagenet(data_root, split, transform_func)
77 | elif dataset_name == "episodic_tiered-imagenet":
78 | few_shot_sampler = FewShotSampler(classes, support_size, query_size, unlabeled_size)
79 | dataset = EpisodicTieredImagenet(data_root,
80 | split=split,
81 | sampler=few_shot_sampler,
82 | size=n_iters,
83 | transforms=transform_func)
84 |     else:
85 |         raise ValueError("Unknown dataset: %s" % dataset_name)
86 |     return dataset
87 | # ===================================================
88 | # helpers
89 | def get_transformer(transform, split):
90 | if transform == "data_augmentation":
91 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
92 | torchvision.transforms.Resize((84,84)),
93 | torchvision.transforms.ToTensor()])
94 |
95 | return transform
96 |
97 | if "{}_{}".format(transform, split) == "cifar_train":
98 | transform = torchvision.transforms.transforms.Compose([
99 | torchvision.transforms.RandomCrop(32, padding=4),
100 | torchvision.transforms.RandomHorizontalFlip(),
101 | torchvision.transforms.ToTensor(),
102 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
103 | ])
104 | return transform
105 |
106 | if "{}_{}".format(transform, split) == "cifar_test" or "{}_{}".format(transform, split) == "cifar_val":
107 | transform = torchvision.transforms.Compose([
108 | torchvision.transforms.ToTensor(),
109 | torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
110 | ])
111 | return transform
112 |
113 | if transform == "basic":
114 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
115 | torchvision.transforms.Resize((84,84)),
116 | torchvision.transforms.ToTensor()])
117 |
118 | return transform
119 |
120 | if transform == "wrn_pretrain_train":
121 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
122 | torchvision.transforms.RandomResizedCrop((80, 80), scale=(0.08, 1)),
123 | torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
124 | torchvision.transforms.ToTensor()
125 | ])
126 | return transform
127 | elif transform == "wrn_finetune_train":
128 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
129 | torchvision.transforms.Resize((92, 92)),
130 | torchvision.transforms.CenterCrop(80),
131 | torchvision.transforms.RandomHorizontalFlip(),
132 | torchvision.transforms.ToTensor()
133 | ])
134 | return transform
135 |
136 | elif "wrn" in transform:
137 | transform = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
138 | torchvision.transforms.Resize((92, 92)),
139 | torchvision.transforms.CenterCrop(80),
140 | torchvision.transforms.ToTensor()])
141 | return transform
142 |
143 | raise NotImplementedError
144 |
145 |
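146 | if __name__ == '__main__':
147 |     # Minimal sketch (assumes the pre-processed mini-imagenet npz files sit
148 |     # under ./data/mini-imagenet; see the README for the download links):
149 |     dataset = get_dataset("episodic_miniimagenet",
150 |                           data_root="./data/mini-imagenet",
151 |                           split="val",
152 |                           transform="basic",
153 |                           classes=5,
154 |                           support_size=1,
155 |                           query_size=15,
156 |                           unlabeled_size=0,
157 |                           n_iters=10)
158 |     print(len(dataset))  # 10 episodes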
--------------------------------------------------------------------------------
/src/datasets/cub.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torchvision
3 | import torch
4 | from torch.utils.data import Dataset
5 | import json
6 | import os
7 | import numpy as np
8 | from PIL import Image
9 |
10 | class NonEpisodicCUB(Dataset):
11 | name="CUB"
12 | task="cls"
13 | split_paths = {"train":"train", "test":"test", "valid": "val"}
14 | c = 3
15 | h = 84
16 | w = 84
17 |
18 | def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs):
19 |         """ Constructor
20 |
21 |         Args:
22 |             data_root: dataset root directory
23 |             split: data split ("train", "val"/"valid", or "test")
24 |             transforms: torchvision transforms applied to each image
25 |             rotation_labels: rotation class labels used by the rotated
26 |                 variant (default: [0, 1, 2, 3])
27 |         """
28 | self.data_root = data_root
29 | self.split = {"train":"base", "val":"val", "valid":"val", "test":"novel"}[split]
30 | with open(os.path.join(self.data_root, "few_shot_lists", "%s.json" %self.split), 'r') as infile:
31 | self.metadata = json.load(infile)
32 | self.transforms = transforms
33 | self.rotation_labels = rotation_labels
34 | self.labels = np.array(self.metadata['image_labels'])
35 | label_map = {l: i for i, l in enumerate(sorted(np.unique(self.labels)))}
36 | self.labels = np.array([label_map[l] for l in self.labels])
37 | self.size = len(self.metadata["image_labels"])
38 |
39 | def next_run(self):
40 | pass
41 |
42 | def rotate_img(self, img, rot):
43 | if rot == 0: # 0 degrees rotation
44 | return img
45 | elif rot == 90: # 90 degrees rotation
46 | return np.flipud(np.transpose(img, (1, 0, 2)))
47 |         elif rot == 180:  # 180 degrees rotation
48 | return np.fliplr(np.flipud(img))
49 | elif rot == 270: # 270 degrees rotation / or -90
50 | return np.transpose(np.flipud(img), (1, 0, 2))
51 | else:
52 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
53 |
54 | def __getitem__(self, item):
55 | image = np.array(Image.open(self.metadata["image_names"][item]).convert("RGB"))
56 | images = self.transforms(image) * 2 - 1
57 | return images, int(self.labels[item])
58 |
59 | def __len__(self):
60 | return len(self.labels)
61 |
62 | class RotatedNonEpisodicCUB(NonEpisodicCUB):
63 | name="CUB"
64 | task="cls"
65 | split_paths = {"train":"train", "test":"test", "valid": "val"}
66 | c = 3
67 | h = 84
68 | w = 84
69 |
70 | def __init__(self, *args, **kwargs):
71 |         """ Constructor
72 |
73 |         Args:
74 |             *args: positional arguments forwarded to NonEpisodicCUB
75 |                 (data_root, split, transforms)
76 |             **kwargs: keyword arguments forwarded to NonEpisodicCUB
77 |                 (e.g. rotation_labels)
78 |         """
79 |
80 | super().__init__(*args, **kwargs)
81 |
82 | def rotate_img(self, img, rot):
83 | if rot == 0: # 0 degrees rotation
84 | return img
85 | elif rot == 90: # 90 degrees rotation
86 | return np.flipud(np.transpose(img, (1, 0, 2)))
87 |         elif rot == 180:  # 180 degrees rotation
88 | return np.fliplr(np.flipud(img))
89 | elif rot == 270: # 270 degrees rotation / or -90
90 | return np.transpose(np.flipud(img), (1, 0, 2))
91 | else:
92 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
93 |
94 | def __getitem__(self, item):
95 | image = np.array(Image.open(self.metadata["image_names"][item]).convert("RGB"))
96 | if np.random.randint(2):
97 | image = np.fliplr(image)
98 | image_90 = self.transforms(self.rotate_img(image, 90))
99 | image_180 = self.transforms(self.rotate_img(image, 180))
100 | image_270 = self.transforms(self.rotate_img(image, 270))
101 | images = torch.stack([self.transforms(image), image_90, image_180, image_270]) * 2 - 1
102 | return images, torch.ones(4, dtype=torch.long)*int(self.labels[item]), torch.LongTensor(self.rotation_labels)
--------------------------------------------------------------------------------
/src/datasets/episodic_cub.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torchvision
3 | import torch
4 | from torch.utils.data import Dataset
5 | from .episodic_dataset import EpisodicDataset, FewShotSampler
6 | import json
7 | import os
8 | import numpy as np
9 | from PIL import Image
10 | import numpy
11 |
12 | class EpisodicCUB(EpisodicDataset):
13 | h = 84
14 | w = 84
15 | c = 3
16 | name="CUB"
17 | task="cls"
18 | split_paths = {"train":"base", "val":"val", "valid":"val", "test":"novel"}
19 | def __init__(self, data_root, split, sampler, size, transforms):
20 | self.data_root = data_root
21 | self.split = split
22 | with open(os.path.join(self.data_root, "few_shot_lists", "%s.json" %self.split_paths[split]), 'r') as infile:
23 | self.metadata = json.load(infile)
24 | labels = np.array(self.metadata['image_labels'])
25 | label_map = {l: i for i, l in enumerate(sorted(np.unique(labels)))}
26 | labels = np.array([label_map[l] for l in labels])
27 | super().__init__(labels, sampler, size, transforms)
28 |
29 | def sample_images(self, indices):
30 | return [np.array(Image.open(self.metadata['image_names'][i]).convert("RGB")) for i in indices]
31 |
32 | def __iter__(self):
33 | return super().__iter__()
34 |
35 | if __name__ == '__main__':
36 | from torch.utils.data import DataLoader
37 |     from src.tools.plot_episode import plot_episode
38 | sampler = FewShotSampler(5, 5, 15, 0)
39 | transforms = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
40 | torchvision.transforms.Resize((84,84)),
41 | torchvision.transforms.ToTensor(),
42 | ])
43 |     dataset = EpisodicCUB("CUB_200_2011", "train", sampler, 10, transforms)  # data_root as in exp_configs
44 | loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x)
45 | for batch in loader:
46 | plot_episode(batch[0], classes_first=False)
47 |
48 |
--------------------------------------------------------------------------------
/src/datasets/episodic_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 | from torch.utils.data import DataLoader
5 | import collections
6 | from copy import deepcopy
7 |
8 | _DataLoader = DataLoader
9 | class EpisodicDataLoader(_DataLoader):
10 | def __iter__(self):
11 | if isinstance(self.dataset, EpisodicDataset):
12 | self.dataset.__iter__()
13 | else:
14 | pass
15 | return super().__iter__()
16 | torch.utils.data.DataLoader = EpisodicDataLoader  # patch DataLoader globally so episodes are re-sampled every epoch
17 |
18 | class FewShotSampler():
19 | FewShotTask = collections.namedtuple("FewShotTask", ["nclasses", "support_size", "query_size", "unlabeled_size"])
20 | def __init__(self, nclasses, support_size, query_size, unlabeled_size):
21 | self.task = self.FewShotTask(nclasses, support_size, query_size, unlabeled_size)
22 |
23 | def sample(self):
24 | return deepcopy(self.task)
25 |
26 | class EpisodicDataset(Dataset):
27 | def __init__(self, labels, sampler, size, transforms):
28 | self.labels = labels
29 | self.sampler = sampler
30 | self.labelset = np.unique(labels)
31 | self.indices = np.arange(len(labels))
32 | self.transforms = transforms
33 | self.reshuffle()
34 | self.size = size
35 |
36 | def reshuffle(self):
37 | """
38 | Helper method to randomize tasks again
39 | """
40 | self.clss_idx = [np.random.permutation(self.indices[self.labels == label]) for label in self.labelset]
41 | self.starts = np.zeros(len(self.clss_idx), dtype=int)
42 | self.lengths = np.array([len(x) for x in self.clss_idx])
43 |
44 | def gen_few_shot_task(self, nclasses, size):
45 |         """ Samples the indices for one few-shot task
46 |
47 |         Args:
48 |             nclasses: number of classes in the task
49 |             size: number of samples to draw per class
50 |                 (support_size + query_size + unlabeled_size)
51 |
52 |         Returns: Flattened array of sample indices, or None if the dataset has been exhausted.
53 |
54 |         """
55 | classes = np.random.choice(self.labelset, nclasses, replace=False)
56 | starts = self.starts[classes]
57 |         remainders = self.lengths[classes] - starts
58 |         if np.min(remainders) < size:
59 | return None
60 | sample_indices = np.array(
61 | [self.clss_idx[classes[i]][starts[i]:(starts[i] + size)] for i in range(len(classes))])
62 | sample_indices = np.reshape(sample_indices, [nclasses, size]).transpose()
63 | self.starts[classes] += size
64 | return sample_indices.flatten()
65 |
66 | def sample_task_list(self):
67 | """ Generates a list of tasks (until the dataset is exhausted)
68 |
69 | Returns: the list of tasks [(FewShotTask object, task_indices), ...]
70 |
71 | """
72 | task_list = []
73 | task_info = self.sampler.sample()
74 | nclasses, support_size, query_size, unlabeled_size = task_info
75 | unlabeled_size = min(unlabeled_size, self.lengths.min() - support_size - query_size)
76 | task_info = FewShotSampler.FewShotTask(nclasses=nclasses,
77 | support_size=support_size,
78 | query_size=query_size,
79 | unlabeled_size=unlabeled_size)
80 | k = support_size + query_size + unlabeled_size
81 | if np.any(k > self.lengths):
82 |             raise RuntimeError("Requested more samples per class than are available")
83 | few_shot_task = self.gen_few_shot_task(nclasses, k)
84 |
85 | while few_shot_task is not None:
86 | task_list.append((task_info, few_shot_task))
87 | task_info = self.sampler.sample()
88 | nclasses, support_size, query_size, unlabeled_size = task_info
89 | k = support_size + query_size + unlabeled_size
90 | few_shot_task = self.gen_few_shot_task(nclasses, k)
91 | return task_list
92 |
93 | def sample_images(self, indices):
94 | raise NotImplementedError
95 |
96 | def __getitem__(self, idx):
97 | """ Reads the idx th task (episode) from disk
98 |
99 | Args:
100 | idx: task index
101 |
102 | Returns: task dictionary with (dataset (char), task (char), dim (tuple), episode (Tensor))
103 |
104 | """
105 | fs_task_info, indices = self.task_list[idx]
106 | ordered_argindices = np.argsort(indices)
107 | ordered_indices = np.sort(indices)
108 | nclasses, support_size, query_size, unlabeled_size = fs_task_info
109 | k = support_size + query_size + unlabeled_size
110 | _images = self.sample_images(ordered_indices)
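111 |         # np.argsort(ordered_argindices) is the inverse permutation of the sort:
112 |         # it maps the images, loaded in sorted (disk) order, back to the order in
113 |         # which they were sampled for the episode.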
111 | images = torch.stack([self.transforms(_images[i]) for i in np.argsort(ordered_argindices)])
112 | total, c, h, w = images.size()
113 | assert(total == (k * nclasses))
114 | images = images.view(k, nclasses, c, h, w)
115 | del(_images)
116 | images = images * 2 - 1
117 | targets = np.zeros([nclasses * k], dtype=int)
118 | targets[ordered_argindices] = self.labels[ordered_indices, ...].ravel()
119 | sample = {"dataset": self.name,
120 | "channels": c,
121 | "height": h,
122 | "width": w,
123 | "nclasses": nclasses,
124 | "support_size": support_size,
125 | "query_size": query_size,
126 | "unlabeled_size": unlabeled_size,
127 | "targets": torch.from_numpy(targets),
128 | "support_set": images[:support_size, ...],
129 | "query_set": images[support_size:(support_size +
130 | query_size), ...],
131 | "unlabeled_set": None if unlabeled_size == 0 else images[(support_size + query_size):, ...]}
132 | return sample
133 |
134 |
135 |     def __iter__(self):
136 |         # Prefetch a new epoch's worth of episodes. EpisodicDataLoader calls this
137 |         # for its side effect before building the actual batch iterator.
138 |         self.task_list = []
139 |         while len(self.task_list) < self.size:
140 |             self.reshuffle()
141 |             self.task_list += self.sample_task_list()
142 |         return iter([])
143 |
144 | def __len__(self):
145 | return self.size
--------------------------------------------------------------------------------
/src/datasets/episodic_miniimagenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision
4 | import os
5 | from src.datasets.episodic_dataset import EpisodicDataset, FewShotSampler
6 | import pickle as pkl
7 |
8 | # Inheritance order matters: the EpisodicDataset constructor takes priority.
9 | class EpisodicMiniImagenet(EpisodicDataset):
10 | tasks_type = "clss"
11 | name = "miniimagenet"
12 | episodic=True
13 | split_paths = {"train":"train", "valid":"val", "val":"val", "test": "test"}
14 | # c = 3
15 | # h = 84
16 | # w = 84
17 |
18 | def __init__(self, data_root, split, sampler, size, transforms):
19 | """ Constructor
20 |
21 |         Args:
22 |             data_root: path to the folder containing the mini-imagenet .npz files
23 |             split: data split ("train", "val", or "test")
24 |             sampler: FewShotSampler instance
25 |             size: number of episodes (tasks) to generate (int)
26 |             transforms: torchvision transforms applied to each image
27 |         """
28 | self.data_root = os.path.join(data_root, "mini-imagenet-%s.npz")
29 | self.split = split
30 | data = np.load(self.data_root % self.split_paths[split])
31 | self.features = data["features"]
32 | labels = data["targets"]
33 | del(data)
34 | super().__init__(labels, sampler, size, transforms)
35 |
36 | def sample_images(self, indices):
37 | return self.features[indices]
38 |
39 | def __iter__(self):
40 | return super().__iter__()
41 |
42 | # Inheritance order matters: the EpisodicDataset constructor takes priority.
43 | class EpisodicMiniImagenetPkl(EpisodicDataset):
44 | tasks_type = "clss"
45 | name = "miniimagenet"
46 | episodic=True
47 | split_paths = {"train":"train", "valid":"val", "val":"val", "test": "test"}
48 | # c = 3
49 | # h = 84
50 | # w = 84
51 |
52 | def __init__(self, data_root, split, sampler, size, transforms):
53 | """ Constructor
54 |
55 |         Args:
56 |             data_root: path to the folder containing the mini-imagenet cache .pkl files
57 |             split: data split ("train", "val", or "test")
58 |             sampler: FewShotSampler instance
59 |             size: number of episodes (tasks) to generate (int)
60 |             transforms: torchvision transforms applied to each image
61 |         """
62 | self.data_root = os.path.join(data_root, "mini-imagenet-cache-%s.pkl")
63 | self.split = split
64 | with open(self.data_root % self.split_paths[split], 'rb') as infile:
65 | data = pkl.load(infile)
66 | self.features = data["image_data"]
67 | label_names = data["class_dict"].keys()
68 | labels = np.zeros((self.features.shape[0],), dtype=int)
69 | for i, name in enumerate(sorted(label_names)):
70 | labels[np.array(data['class_dict'][name])] = i
71 | del(data)
72 | super().__init__(labels, sampler, size, transforms)
73 |
74 | def sample_images(self, indices):
75 | return self.features[indices]
76 |
77 | def __iter__(self):
78 | return super().__iter__()
79 |
80 | if __name__ == '__main__':
81 | from torch.utils.data import DataLoader
82 | from src.tools.plot_episode import plot_episode
83 | import time
84 | sampler = FewShotSampler(5, 5, 15, 0)
85 | transforms = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
86 | torchvision.transforms.ToTensor(),
87 | ])
88 | dataset = EpisodicMiniImagenetPkl('./miniimagenet', 'train', sampler, 1000, transforms)
89 | loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x)
90 | for batch in loader:
91 | print(np.unique(batch[0]["targets"].view(20, 5).numpy()))
92 | # plot_episode(batch[0], classes_first=False)
93 | # time.sleep(1)
94 |
95 |
--------------------------------------------------------------------------------
/src/datasets/episodic_tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torchvision
3 | import torch
4 | from torch.utils.data import Dataset
5 | from .episodic_dataset import EpisodicDataset, FewShotSampler
6 | import json
7 | import os
8 | import numpy as np
10 | import cv2
11 | import pickle as pkl
12 |
13 | # Inheritance order matters: the EpisodicDataset constructor takes priority.
14 | class EpisodicTieredImagenet(EpisodicDataset):
15 | tasks_type = "clss"
16 | name = "tiered-imagenet"
17 | split_paths = {"train":"train", "test":"test", "valid": "val"}
18 | c = 3
19 | h = 84
20 | w = 84
21 | def __init__(self, data_root, split, sampler, size, transforms):
22 | self.data_root = data_root
23 | self.split = split
24 | img_path = os.path.join(self.data_root, "%s_images_png.pkl" %(split))
25 | label_path = os.path.join(self.data_root, "%s_labels.pkl" %(split))
26 | with open(img_path, 'rb') as infile:
27 | self.features = pkl.load(infile, encoding="bytes")
28 | with open(label_path, 'rb') as infile:
29 | labels = pkl.load(infile, encoding="bytes")[b'label_specific']
30 | super().__init__(labels, sampler, size, transforms)
31 |
32 | def sample_images(self, indices):
33 | return [cv2.imdecode(self.features[i], cv2.IMREAD_COLOR)[:,:,::-1] for i in indices]
34 |
35 | def __iter__(self):
36 | return super().__iter__()
37 |
38 | if __name__ == '__main__':
40 |     from torch.utils.data import DataLoader
41 |     from src.tools.plot_episode import plot_episode
42 | sampler = FewShotSampler(5, 5, 15, 0)
43 | transforms = torchvision.transforms.Compose([torchvision.transforms.ToPILImage(),
44 | torchvision.transforms.Resize((84,84)),
45 | torchvision.transforms.ToTensor(),
46 | ])
47 |     # "./tiered-imagenet" is a placeholder data_root; the constructor expects it first.
48 |     dataset = EpisodicTieredImagenet("./tiered-imagenet", "train", sampler, 10, transforms)
48 | loader = DataLoader(dataset, batch_size=1, collate_fn=lambda x: x)
49 | for batch in loader:
50 | plot_episode(batch[0], classes_first=False)
51 |
--------------------------------------------------------------------------------
/src/datasets/miniimagenet.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import numpy as np
4 | import os
5 | import pickle as pkl
6 |
7 | class NonEpisodicMiniImagenet(Dataset):
8 | tasks_type = "clss"
9 | name = "miniimagenet"
10 | split_paths = {"train": "train", "val":"val", "valid": "val", "test": "test"}
11 | episodic=False
12 | c = 3
13 | h = 84
14 | w = 84
15 |
16 | def __init__(self, data_root, split, transforms, **kwargs):
17 | """ Constructor
18 |
19 |         Args:
20 |             data_root: path to the folder containing the mini-imagenet .npz files
21 |             split: data split ("train", "val", or "test")
22 |             transforms: torchvision transforms applied to each image
23 |         """
26 | self.data_root = os.path.join(data_root, "mini-imagenet-%s.npz")
27 | data = np.load(self.data_root % self.split_paths[split])
28 | self.features = data["features"]
29 | self.labels = data["targets"]
30 | self.transforms = transforms
31 |
32 | def next_run(self):
33 | pass
34 |
35 | def __getitem__(self, item):
36 | image = self.transforms(self.features[item])
37 | image = image * 2 - 1
38 | return image, self.labels[item]
39 |
40 | def __len__(self):
41 | return len(self.features)
42 |
43 | class RotatedNonEpisodicMiniImagenet(Dataset):
44 | tasks_type = "clss"
45 | name = "miniimagenet"
46 | split_paths = {"train": "train", "val":"val", "valid": "val", "test": "test"}
47 | episodic=False
48 | c = 3
49 | h = 84
50 | w = 84
51 |
52 | def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs):
53 | """ Constructor
54 |
55 |         Args:
56 |             data_root: path to the folder containing the mini-imagenet .npz files
57 |             split: data split ("train", "val", or "test")
58 |             transforms: torchvision transforms applied to each image
59 |             rotation_labels: rotation class ids; [0, 1, 2, 3] maps to 0/90/180/270 degrees
60 |         """
62 | self.data_root = os.path.join(data_root, "mini-imagenet-%s.npz")
63 | data = np.load(self.data_root % self.split_paths[split])
64 | self.features = data["features"]
65 | self.labels = data["targets"]
66 | self.transforms = transforms
67 | self.size = len(self.features)
68 | self.rotation_labels = rotation_labels
69 |
70 | def next_run(self):
71 | pass
72 |
73 | def rotate_img(self, img, rot):
74 | if rot == 0: # 0 degrees rotation
75 | return img
76 | elif rot == 90: # 90 degrees rotation
77 | return np.flipud(np.transpose(img, (1, 0, 2)))
78 |         elif rot == 180: # 180 degrees rotation
79 | return np.fliplr(np.flipud(img))
80 | elif rot == 270: # 270 degrees rotation / or -90
81 | return np.transpose(np.flipud(img), (1, 0, 2))
82 | else:
83 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
84 |
85 | def __getitem__(self, item):
86 | image = self.features[item]
87 | if np.random.randint(2):
88 | image = np.fliplr(image).copy()
89 | cat = [self.transforms(image)]
90 | if len(self.rotation_labels) > 1:
91 | image_90 = self.transforms(self.rotate_img(image, 90))
92 | image_180 = self.transforms(self.rotate_img(image, 180))
93 | image_270 = self.transforms(self.rotate_img(image, 270))
94 | cat.extend([image_90, image_180, image_270])
95 | images = torch.stack(cat) * 2 - 1
96 | return images, torch.ones(len(self.rotation_labels), dtype=torch.long)*int(self.labels[item]), torch.LongTensor(self.rotation_labels)
97 |
98 | def __len__(self):
99 | return self.size
100 |
101 | class RotatedNonEpisodicMiniImagenetPkl(Dataset):
102 | tasks_type = "clss"
103 | name = "miniimagenet"
104 | split_paths = {"train": "train", "val":"val", "valid": "val", "test": "test"}
105 | episodic=False
106 | c = 3
107 | h = 84
108 | w = 84
109 |
110 | def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs):
111 | """ Constructor
112 |
113 |         Args:
114 |             data_root: path to the folder containing the mini-imagenet cache .pkl files
115 |             split: data split ("train", "val", or "test")
116 |             transforms: torchvision transforms applied to each image
117 |             rotation_labels: rotation class ids; [0, 1, 2, 3] maps to 0/90/180/270 degrees
118 |         """
120 | self.data_root = os.path.join(data_root, "mini-imagenet-cache-%s.pkl")
121 | with open(self.data_root % self.split_paths[split], 'rb') as infile:
122 | data = pkl.load(infile)
123 | self.features = data["image_data"]
124 | label_names = data["class_dict"].keys()
125 | self.labels = np.zeros((self.features.shape[0],), dtype=int)
126 | for i, name in enumerate(sorted(label_names)):
127 | self.labels[np.array(data['class_dict'][name])] = i
128 | del(data)
129 | self.transforms = transforms
130 | self.size = len(self.features)
131 | self.rotation_labels = rotation_labels
132 |
133 | def next_run(self):
134 | pass
135 |
136 | def rotate_img(self, img, rot):
137 | if rot == 0: # 0 degrees rotation
138 | return img
139 | elif rot == 90: # 90 degrees rotation
140 | return np.flipud(np.transpose(img, (1, 0, 2)))
141 |         elif rot == 180: # 180 degrees rotation
142 | return np.fliplr(np.flipud(img))
143 | elif rot == 270: # 270 degrees rotation / or -90
144 | return np.transpose(np.flipud(img), (1, 0, 2))
145 | else:
146 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
147 |
148 | def __getitem__(self, item):
149 | image = self.features[item]
150 | if np.random.randint(2):
151 | image = np.fliplr(image).copy()
152 | cat = [self.transforms(image)]
153 | if len(self.rotation_labels) > 1:
154 | image_90 = self.transforms(self.rotate_img(image, 90))
155 | image_180 = self.transforms(self.rotate_img(image, 180))
156 | image_270 = self.transforms(self.rotate_img(image, 270))
157 | cat.extend([image_90, image_180, image_270])
158 | images = torch.stack(cat) * 2 - 1
159 | return images, torch.ones(len(self.rotation_labels), dtype=torch.long)*int(self.labels[item]), torch.LongTensor(self.rotation_labels)
160 |
161 | def __len__(self):
162 | return self.size
163 |
--------------------------------------------------------------------------------
/src/datasets/tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import PIL
4 | import pickle as pkl
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | import torch
8 |
9 | class NonEpisodicTieredImagenet(Dataset):
10 | tasks_type = "clss"
11 | name = "tiered-imagenet"
12 | split_paths = {"train":"train", "test":"test", "valid": "val"}
13 | c = 3
14 | h = 84
15 | w = 84
16 |
17 | def __init__(self, data_root, split, transforms, rotation_labels=[0, 1, 2, 3], **kwargs):
18 | """ Constructor
19 |
20 |         Args:
21 |             data_root: path to the folder containing the tiered-imagenet pickles
22 |             split: data split ("train", "valid", or "test")
23 |             transforms: torchvision transforms applied to each image
24 |             rotation_labels: rotation class ids; [0, 1, 2, 3] maps to 0/90/180/270 degrees
25 |         """
27 | split = self.split_paths[split]
28 | self.data_root = data_root
29 | img_path = os.path.join(self.data_root, "%s_images_png.pkl" %(split))
30 | label_path = os.path.join(self.data_root, "%s_labels.pkl" %(split))
31 | self.transforms = transforms
32 | self.rotation_labels = rotation_labels
33 | with open(img_path, 'rb') as infile:
34 | self.features = pkl.load(infile, encoding="bytes")
35 | with open(label_path, 'rb') as infile:
36 | self.labels = pkl.load(infile, encoding="bytes")[b'label_specific']
37 |
38 | def next_run(self):
39 | pass
40 |
41 | def __getitem__(self, item):
42 | image = cv2.imdecode(self.features[item], cv2.IMREAD_COLOR)[..., ::-1]
43 | images = self.transforms(image) * 2 - 1
44 | return images, int(self.labels[item])
45 |
46 | def __len__(self):
47 | return len(self.labels)
48 |
49 | class RotatedNonEpisodicTieredImagenet(NonEpisodicTieredImagenet):
50 | tasks_type = "clss"
51 | name = "tiered-imagenet"
52 | split_paths = {"train":"train", "test":"test", "valid": "val"}
53 | c = 3
54 | h = 84
55 | w = 84
56 |
57 | def __init__(self, *args, **kwargs):
58 | """ Constructor
59 |
60 |         Args:
61 |             see NonEpisodicTieredImagenet; all arguments are forwarded unchanged
62 |         """
67 | super().__init__(*args, **kwargs)
68 |
69 | def rotate_img(self, img, rot):
70 | if rot == 0: # 0 degrees rotation
71 | return img
72 | elif rot == 90: # 90 degrees rotation
73 | return np.flipud(np.transpose(img, (1, 0, 2)))
74 |         elif rot == 180: # 180 degrees rotation
75 | return np.fliplr(np.flipud(img))
76 | elif rot == 270: # 270 degrees rotation / or -90
77 | return np.transpose(np.flipud(img), (1, 0, 2))
78 | else:
79 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
80 |
81 | def __getitem__(self, item):
82 | image = cv2.imdecode(self.features[item], cv2.IMREAD_COLOR)[..., ::-1]
83 | if np.random.randint(2):
84 |             image = np.fliplr(image).copy()  # .copy() avoids negative strides that break the transforms
85 | image_90 = self.transforms(self.rotate_img(image, 90))
86 | image_180 = self.transforms(self.rotate_img(image, 180))
87 | image_270 = self.transforms(self.rotate_img(image, 270))
88 | images = torch.stack([self.transforms(image), image_90, image_180, image_270]) * 2 - 1
89 | return images, torch.ones(4, dtype=torch.long)*int(self.labels[item]), torch.LongTensor(self.rotation_labels)
90 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | import argparse
4 | import pandas as pd
5 | import pickle, os
6 | import numpy as np
7 | from haven import haven_utils as hu
8 | from . import pretraining, finetuning, ssl_wrapper
8 |
9 | def get_model(model_name, backbone, n_classes, exp_dict, pretrained_weights_dir=None, savedir_base=None):
10 | if model_name == "pretraining":
11 | model = pretraining.PretrainWrapper(backbone, n_classes, exp_dict)
12 |
13 | elif model_name == "finetuning":
14 | model = finetuning.FinetuneWrapper(backbone, n_classes, exp_dict)
15 |
16 | elif model_name == "ssl":
17 | model = ssl_wrapper.SSLWrapper(backbone, n_classes, exp_dict, savedir_base=savedir_base)
18 |
19 | else:
20 | raise ValueError('model does not exist...')
21 |
22 | # load pretrained model
23 | if pretrained_weights_dir:
24 | s_path = os.path.join(os.path.dirname(pretrained_weights_dir), 'score_list_best.pkl')
25 | if not os.path.exists(s_path):
26 | s_path = os.path.join(exp_dict['checkpoint_exp_id'],
27 | 'score_list.pkl')
28 | print('Loaded checkpoint from exp_id: %s' %
29 | os.path.split(os.path.dirname(pretrained_weights_dir))[-1]
30 | )
31 | print('Fine-tuned accuracy: %.3f' % hu.load_pkl(s_path)[-1]['test_accuracy'])
32 | model.model.load_state_dict(torch.load(pretrained_weights_dir)['model'])
33 |
34 | return model
35 |
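36 | # Usage sketch with hypothetical config values (see src/models/backbones/__init__.py
37 | # for the backbone factory); not taken from the repo's entry points:
38 | #   from src.models.backbones import get_backbone
39 | #   exp_dict = {"dropout": 0.1}
40 | #   backbone = get_backbone("resnet12", exp_dict)
41 | #   model = get_model("pretraining", backbone, n_classes=64, exp_dict=exp_dict)
42 |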
36 | # ===============================================
37 | # Trainers
38 | def train_on_loader(model, train_loader):
39 | model.train()
40 |
41 | n_batches = len(train_loader)
42 | train_monitor = TrainMonitor()
43 |     for i, batch in enumerate(train_loader):
44 |         score_dict = model.train_on_batch(batch)
45 |
46 |         train_monitor.add(score_dict)
47 |         if i % 10 == 0:
48 |             msg = "%d/%d %s" % (i, n_batches, train_monitor.get_avg_score())
49 |             print(msg)
52 |
53 | return train_monitor.get_avg_score()
54 |
55 | def val_on_loader(model, val_loader, val_monitor):
56 | model.eval()
57 |
58 | n_batches = len(val_loader)
59 |
60 | for i, batch in enumerate(val_loader):
61 | score = model.val_on_batch(batch)
62 |
63 | val_monitor.add(score)
64 | if i % 10 == 0:
65 | msg = "%d/%d %s" % (i, n_batches, val_monitor.get_avg_score())
66 |
67 | print(msg)
68 |
69 |
70 | return val_monitor.get_avg_score()
71 |
72 |
73 | @torch.no_grad()
74 | def vis_on_loader(model, vis_loader, savedir):
75 | model.eval()
76 |
77 | n_batches = len(vis_loader)
78 | split = vis_loader.dataset.split
79 | for i, batch in enumerate(vis_loader):
80 | print("%d - visualizing %s image - savedir:%s" % (i, batch["meta"]["split"][0], savedir.split("/")[-2]))
81 | model.vis_on_batch(batch, savedir=savedir)
82 |
83 |
84 | def test_on_loader(model, test_loader):
85 | model.eval()
86 | ae = 0.
87 | n_samples = 0.
88 |
89 | n_batches = len(test_loader)
90 | pbar = tqdm.tqdm(total=n_batches)
91 | for i, batch in enumerate(test_loader):
92 | pred_count = model.predict(batch, method="counts")
93 |
94 | ae += abs(batch["counts"].cpu().numpy().ravel() - pred_count.ravel()).sum()
95 | n_samples += batch["counts"].shape[0]
96 |
97 | pbar.set_description("TEST mae: %.4f" % (ae / n_samples))
98 | pbar.update(1)
99 |
100 | pbar.close()
101 | score = ae / n_samples
102 | print({"test_score": score, "test_mae":score})
103 |
104 | return {"test_score": score, "test_mae":score}
105 |
106 |
107 | class TrainMonitor:
108 |     def __init__(self):
109 |         self.score_dict_sum = {}
110 |         self.n = 0
111 |
112 |     def add(self, score_dict):
113 |         # Count once per batch (counting once per key, as before, skewed the average).
114 |         self.n += 1
115 |         for k, v in score_dict.items():
116 |             if k not in self.score_dict_sum:
117 |                 self.score_dict_sum[k] = v
118 |             else:
119 |                 self.score_dict_sum[k] += v
120 |
121 |     def get_avg_score(self):
122 |         return {k: v / max(self.n, 1) for k, v in self.score_dict_sum.items()}
122 |
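123 | # Usage sketch with hypothetical scores:
124 | #   m = TrainMonitor(); m.add({"loss": 0.7}); m.add({"loss": 0.5})
125 | #   m.get_avg_score()  # -> {"loss": 0.6}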
123 |
--------------------------------------------------------------------------------
/src/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from . import resnet12, conv4, wrn
3 |
4 | def get_backbone(backbone_name, exp_dict):
5 | if backbone_name == "resnet12":
6 | backbone = resnet12.Resnet12(width=1, dropout=exp_dict["dropout"])
7 | elif backbone_name == "conv4":
8 | backbone = conv4.Conv4(exp_dict)
9 | elif backbone_name == "wrn":
10 |         backbone = wrn.WideResNet(depth=exp_dict["model"]["depth"], width=exp_dict["model"]["width"], exp_dict=exp_dict)
11 |     else:
12 |         raise ValueError("Unknown backbone: %s" % backbone_name)
13 |
14 |     return backbone
13 |
--------------------------------------------------------------------------------
/src/models/backbones/conv4.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | class Conv4(torch.nn.Module):
7 | def __init__(self, exp_dict):
8 | super().__init__()
9 | self.conv0 = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)
10 | self.bn0 = torch.nn.BatchNorm2d(64)
11 | self.conv1 = torch.nn.Conv2d(64, 64, 3, 1, 1, bias=False)
12 | self.bn1 = torch.nn.BatchNorm2d(64)
13 | self.conv2 = torch.nn.Conv2d(64, 64, 3, 1, 1, bias=False)
14 | self.bn2 = torch.nn.BatchNorm2d(64)
15 | self.conv3 = torch.nn.Conv2d(64, 64, 3, 1, 1, bias=False)
16 | self.bn3 = torch.nn.BatchNorm2d(64)
17 | self.exp_dict = exp_dict
18 | if self.exp_dict["avgpool"] == True:
19 | self.output_size = 64
20 | else:
21 | self.output_size = 1600
22 |
23 | def add_classifier(self, no, name="classifier", modalities=None):
24 | setattr(self, name, torch.nn.Linear(self.output_size, no))
25 |
26 | def forward(self, x, *args, **kwargs):
27 | *dim, c, h, w = x.size()
28 | x = x.view(-1, c, h, w)
29 | x = self.conv0(x) # 84
30 | x = F.relu(self.bn0(x), True)
31 | x = F.max_pool2d(x, 2, 2, 0) # 84 -> 42
32 | x = self.conv1(x)
33 | x = F.relu(self.bn1(x), True)
34 | x = F.max_pool2d(x, 2, 2, 0) # 42 -> 21
35 | x = self.conv2(x)
36 | x = F.relu(self.bn2(x), True)
37 | x = F.max_pool2d(x, 2, 2, 0) # 21 -> 10
38 | x = self.conv3(x)
39 | x = F.relu(self.bn3(x), True)
40 |         x = F.max_pool2d(x, 2, 2, 0) # 10 -> 5
41 | if self.exp_dict["avgpool"] == True:
42 | x = x.mean(3, keepdim=True).mean(2, keepdim=True)
43 | return x.view(*dim, self.output_size)
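44 |
45 | # Shape sketch: an input of shape (B, N, K, 3, 84, 84) is flattened to
46 | # (B*N*K, 3, 84, 84), pooled down to 5x5 feature maps, and returned as
47 | # (B, N, K, 64) when avgpool=True, or (B, N, K, 1600) otherwise.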
--------------------------------------------------------------------------------
/src/models/backbones/resnet12.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | class Block(torch.nn.Module):
7 | def __init__(self, ni, no, stride, dropout=0, groups=1):
8 | super().__init__()
9 | self.dropout = torch.nn.Dropout2d(dropout) if dropout > 0 else lambda x: x
10 | self.conv0 = torch.nn.Conv2d(ni, no, 3, stride, padding=1, bias=False)
11 | self.bn0 = torch.nn.BatchNorm2d(no)
12 | self.conv1 = torch.nn.Conv2d(no, no, 3, 1, padding=1, bias=False)
13 | self.bn1 = torch.nn.BatchNorm2d(no)
14 | self.conv2 = torch.nn.Conv2d(no, no, 3, 1, padding=1, bias=False)
15 | self.bn2 = torch.nn.BatchNorm2d(no)
16 |         if stride == 2 or ni != no:
17 |             self.shortcut = torch.nn.Conv2d(ni, no, 1, stride=1, padding=0)
18 |         else:
19 |             self.shortcut = torch.nn.Identity()  # forward() always calls self.shortcut
18 |
19 | def get_parameters(self):
20 | return self.parameters()
21 |
22 | def forward(self, x, is_support=True):
23 | y = F.relu(self.bn0(self.conv0(x)), True)
24 | y = self.dropout(y)
25 | y = F.relu(self.bn1(self.conv1(y)), True)
26 | y = self.dropout(y)
27 | y = self.bn2(self.conv2(y))
28 | return F.relu(y + self.shortcut(x), True)
29 |
30 |
31 | class Resnet12(torch.nn.Module):
32 | def __init__(self, width, dropout):
33 | super().__init__()
34 | self.output_size = 512
35 |         assert(width == 1)  # comment this out to allow other width variants of this model
36 | self.widths = [x * int(width) for x in [64, 128, 256]]
37 | self.widths.append(self.output_size * width)
38 | self.bn_out = torch.nn.BatchNorm1d(self.output_size)
39 |
40 | start_width = 3
41 | for i in range(len(self.widths)):
42 | setattr(self, "group_%d" %i, Block(start_width, self.widths[i], 1, dropout))
43 | start_width = self.widths[i]
44 |
45 | def add_classifier(self, nclasses, name="classifier", modalities=None):
46 | setattr(self, name, torch.nn.Linear(self.output_size, nclasses))
47 |
48 | def up_to_embedding(self, x, is_support):
49 | """ Applies the four residual groups
50 | Args:
51 | x: input images
54 | is_support: whether the input is the support set (for non-transductive)
55 | """
56 | for i in range(len(self.widths)):
57 | x = getattr(self, "group_%d" % i)(x, is_support)
58 | x = F.max_pool2d(x, 3, 2, 1)
59 | return x
60 |
61 | def forward(self, x, is_support):
62 | """Main Pytorch forward function
63 |
64 | Returns: class logits
65 |
66 | Args:
67 |             x: input images
68 | is_support: whether the input is the sample set
69 | """
70 | *args, c, h, w = x.size()
71 | x = x.view(-1, c, h, w)
72 | x = self.up_to_embedding(x, is_support)
73 | return F.relu(self.bn_out(x.mean(3).mean(2)), True)
--------------------------------------------------------------------------------
/src/models/backbones/wrn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | class Block(torch.nn.Module):
8 | def __init__(self, ni, no, stride, dropout=0):
9 | super().__init__()
10 | self.conv0 = torch.nn.Conv2d(ni, no, 3, stride=stride, padding=1, bias=False)
11 | self.bn0 = torch.nn.BatchNorm2d(no)
12 | torch.nn.init.kaiming_normal_(self.conv0.weight.data)
13 | self.bn1 = torch.nn.BatchNorm2d(no)
14 | if dropout == 0:
15 | self.dropout = lambda x: x
16 | else:
17 | self.dropout = torch.nn.Dropout2d(dropout)
18 | self.conv1 = torch.nn.Conv2d(no, no, 3, stride=1, padding=1, bias=False)
19 | torch.nn.init.kaiming_normal_(self.conv1.weight.data)
20 | self.reduce = ni != no
21 | if self.reduce:
22 | self.conv_reduce = torch.nn.Conv2d(ni, no, 1, stride=stride, bias=False)
23 | torch.nn.init.kaiming_normal_(self.conv_reduce.weight.data)
24 |
25 | def forward(self, x):
26 | y = self.conv0(x)
27 | y = F.relu(self.bn0(y), inplace=True)
28 | y = self.dropout(y)
29 | y = self.conv1(y)
30 | y = self.bn1(y)
31 | if self.reduce:
32 | return F.relu(y + self.conv_reduce(x), True)
33 | else:
34 | return F.relu(y + x, True)
35 |
36 |
37 | class Group(torch.nn.Module):
38 | def __init__(self, ni, no, n, stride, dropout=0):
39 | super().__init__()
40 | self.n = n
41 | for i in range(n):
42 | self.__setattr__("block_%d" % i, Block(ni if i == 0 else no, no, stride if i == 0 else 1, dropout=dropout))
43 |
44 | def forward(self, x):
45 | for i in range(self.n):
46 | x = self.__getattr__("block_%d" % i)(x)
47 | return x
48 |
49 |
50 | class WideResNet(torch.nn.Module):
51 | def __init__(self, depth, width, exp_dict):
52 | super(WideResNet, self).__init__()
53 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
54 | self.n = (depth - 4) // 6
55 | self.output_size = 640
56 | self.widths = torch.Tensor([16, 32, 64]).mul(width).int().numpy().tolist()
57 | self.conv0 = torch.nn.Conv2d(3, self.widths[0] // 2, 3, padding=1, bias=False)
58 | self.bn_0 = torch.nn.BatchNorm2d(self.widths[0] // 2)
59 | self.dropout_prob = exp_dict["dropout"]
60 | self.group_0 = Group(self.widths[0] // 2, self.widths[0], self.n, 2, dropout=self.dropout_prob)
61 | self.group_1 = Group(self.widths[0], self.widths[1], self.n, 2, dropout=self.dropout_prob)
62 | self.group_2 = Group(self.widths[1], self.widths[2], self.n, 2, dropout=self.dropout_prob)
63 | self.bn_out = torch.nn.BatchNorm1d(self.output_size)
64 |
65 |     def get_base_parameters(self):
66 |         parameters = []
67 |         parameters += list(self.conv0.parameters())
68 |         parameters += list(self.bn_0.parameters())
69 |         parameters += list(self.group_0.parameters())
70 |         parameters += list(self.group_1.parameters())
71 |         parameters += list(self.group_2.parameters())
72 |         parameters += list(self.bn_out.parameters())
73 |         return parameters
75 |
76 | def get_classifier_parameters(self):
77 | return self.classifier.parameters()
78 |
79 | def add_classifier(self, nclasses, name="classifier", modalities=None):
80 | setattr(self, name, torch.nn.Linear(self.output_size, nclasses))
81 |
82 | def forward(self, x, **kwargs):
83 | o = F.relu(self.bn_0(self.conv0(x)), True)
84 | o = self.group_0(o)
85 | o = self.group_1(o)
86 | o = self.group_2(o)
87 | o = o.mean(3).mean(2)
88 | o = F.relu(self.bn_out(o.view(o.size(0), -1)))
89 | return o
--------------------------------------------------------------------------------
/src/models/base_ssl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/models/base_ssl/__init__.py
--------------------------------------------------------------------------------
/src/models/base_ssl/distances.py:
--------------------------------------------------------------------------------
1 |
2 | import matplotlib.pyplot as plt
3 | # import seaborn as sns
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import warnings
8 | from functools import partial
9 | # from tools.meters import BasicMeter
10 | # from train import TensorLogger
11 | import sys
12 |
13 | def checknan(**kwargs):
14 | for item, value in kwargs.items():
15 | v = value.data.sum()
16 | if torch.isnan(v) or torch.isinf(torch.abs(v)):
17 | return "%s is NaN or inf" %item
18 |
24 |
25 | def matching_nets(support_set, query_set, *sets, distance_type="euclidean", **kwargs):
26 | """ Computes the logits using the method described in [1]
27 |
28 | [1] Vinyals, Oriol, et al. "Matching networks for one shot learning." NeurIPS. 2016.
29 |
30 |     Args:
31 |         support_set: tuple of (Tensor, labels). The Tensor has shape (batch, n_classes, n_support_per_class,
32 |             z_dim) and contains the representation z of each image.
33 |         query_set: tuple of (Tensor, labels). The Tensor has shape (batch, n_classes, n_query_per_class,
34 |             z_dim) and contains the representation z of each image.
35 |         distance_type: "euclidean" to use the euclidean distance, anything else for the cosine distance
36 |
37 |     Returns:
38 |         Class logits (warning: this function returns log-probabilities, so NLL loss is recommended)
41 | """
42 | euclidean = distance_type == "euclidean"
43 | _support_set, _support_labels = support_set
44 | _query_set, _query_labels = query_set
45 | b, n, sk, z = _support_set.size()
46 | b, n, qk, z = _query_set.size()
47 | if isinstance(_support_labels, bool):
48 | labels = _make_aligned_labels(_support_set)
49 | else:
50 | labels = _support_labels
51 | if euclidean:
52 | _support_set = _support_set.view(b, 1, n * sk, z)
53 | _query_set = _query_set.view(b, n * qk, 1, z)
54 | att = - ((_support_set - _query_set) ** 2).sum(3) / np.sqrt(z)
55 | else:
56 | _support_set = F.normalize(_support_set, dim=3)
57 | _query_set = F.normalize(_query_set, dim=3)
58 | _support_set = _support_set.view(b, n * sk, z).transpose(2, 1)
59 | _query_set = _query_set.view(b, n * qk, z)
60 | att = torch.matmul(_query_set, _support_set)
61 | att = F.softmax(att, dim=2).view(b, n * qk, 1, n * sk)
62 | labels = labels.view(b, 1, n * sk, n)
63 | return torch.log(torch.matmul(att, labels).view(b * n * qk, n))
64 |
65 |
66 | def prototype_distance(support_set, query_set, *args, **kwargs):
67 |     """Computes the distance from each element of the query set to the class prototypes of the support set.
68 |
69 |     Args:
70 |         support_set: tuple of (Tensor, is_labeled=True). The Tensor has shape
71 |             (batch, n_classes, n_support_per_class, z_dim) containing the representation z of each image.
72 |         query_set: tuple of (Tensor, is_labeled=False). The Tensor has shape
73 |             (batch, n_classes, n_query_per_class, z_dim) containing the representation z of each image.
74 |
75 | Returns:
76 | Tensor of shape (batch, n_total_query, n_classes) containing the similarity between each pair of query,
77 | prototypes, for each task.
78 | """
79 | _support_set, _support_labels = support_set
80 | _query_set, _query_labels = query_set
81 | b, n, query_size, c = _query_set.size()
82 | _support_set = _support_set.mean(2).view(b, 1, n, c)
83 | _query_set = _query_set.view(b, n * query_size, 1, c)
84 | d = _query_set - _support_set
85 | return -torch.sum(d ** 2, 3) / np.sqrt(c)
86 |
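87 | # Shape sketch for a single 5-way 5-shot task with 15 queries and z_dim=512:
88 | # support (1, 5, 5, 512) -> prototypes (1, 1, 5, 512), queries (1, 75, 1, 512),
89 | # output logits (1, 75, 5), scaled by 1/sqrt(512).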
87 |
88 | def gauss_distance(sample_set, query_set, unlabeled_set=None):
89 | """ (experimental) function to try different approaches to model prototypes as gaussians
90 | Args:
91 | sample_set: features extracted from the sample set
92 | query_set: features extracted from the query set
93 |         unlabeled_set: features extracted from the unlabeled set (currently unused)
94 |
95 | """
96 | b, n, k, c = sample_set.size()
97 | sample_set_std = sample_set.std(2).view(b, 1, n, c)
98 | sample_set_mean = sample_set.mean(2).view(b, 1, n, c)
99 | query_set = query_set.view(b, n * k, 1, c)
100 | d = (query_set - sample_set_mean) / sample_set_std
101 | return -torch.sum(d ** 2, 3) / np.sqrt(c)
102 |
103 |
104 | def _make_aligned_labels(inputs):
105 |     """Uses the shape of inputs to infer batch_size, n_classes, and n_sample_per_class. From this, we build the
106 |     one-hot encoding label tensor aligned with inputs. This is used to keep the label information when tensors
107 |     are flattened across the n_class and n_sample_per_class dimensions.
108 |
109 | Args:
110 | inputs: tensor of shape (batch, n_classes, n_sample_per_class, z_dim) containing encoded examples for a task.
111 | Returns:
112 | tensor of shape (batch, n_classes, n_sample_pc, n_classes) containing the one-hot encoding label of each example
113 | """
114 | batch, n_classes, n_sample_pc, z_dim = inputs.shape
115 | identity = torch.eye(n_classes, dtype=inputs.dtype, device=inputs.device)
116 | return identity[None, :, None, :].expand(batch, -1, n_sample_pc, -1).contiguous()
117 |
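118 | # e.g. inputs of shape (2, 5, 3, 64) yield labels of shape (2, 5, 3, 5), where
119 | # labels[b, c, s] is the one-hot vector for class c, for every sample s.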
118 |
119 | def generalized_pw_sq_dist(data, d_type="euclidean"):
120 | batch, n_classes, n_samples, z_dim = data.size()
121 | data = data.view(batch, -1, z_dim)
122 | if d_type == "euclidean":
123 | return torch.sum((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim)
124 | elif d_type == "l1":
125 | return torch.mean(torch.abs(data[:, :, None, :] - data[:, None, :, :]), dim=3)
126 | elif d_type == "stable_euclidean":
127 | return torch.sqrt(1e-6 + torch.mean((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim))
128 | elif d_type == "cosine":
129 | data = F.normalize(data, dim=2)
130 | return torch.bmm(data, data.transpose(2, 1))
131 | else:
132 | raise ValueError("Distance type not recognized")
133 |
134 |
135 | def pw_sq_dist(sample_set, query_set, unlabeled_set=None, label_offset=0):
136 | """Computes distance from each element of the query set to prototypes in the sample set.
137 |
138 | Args:
139 | sample_set: Tensor of shape (batch, n_classes, n_sample_per_class, z_dim) containing the representation z of
140 | each images.
141 | query_set: Tensor of shape (batch, n_classes, n_query_per_class, z_dim) containing the representation z of
142 | each images.
143 |
144 | Returns:
145 | dist: Tensor of shape (batch, n_total, n_total) containing the squared distance between each pair.
146 | labels: Tensor of shape (batch, n_total, n_classes) Containing the one hot vector for the sample set and zeros
147 | for the query set.
148 | """
149 | batch, n_classes, _, z_dim = sample_set.shape
150 | sample_labels = _make_aligned_labels(sample_set)
151 | query_labels = _make_aligned_labels(query_set) * 0. + label_offset # XXX: Set labels to a constant
152 | # TODO: it's a bit sketchy that this function is used to extract the labels. Should be done externally.
153 | labels = torch.cat((sample_labels, query_labels), dim=2).view(batch, -1, n_classes)
154 | samples = torch.cat((sample_set, query_set), dim=2).view(batch, -1, z_dim)
155 | return torch.sum((samples[:, :, None, :] - samples[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim), labels
156 |
157 |
158 | def _ravel_index(index, shape):
159 | shape0 = np.prod(shape[:-1])
160 | shape1 = shape[-1]
161 | new_shape = list(shape[:-1]) + [1]
162 | offsets = (torch.arange(shape0, dtype=index.dtype, device=index.device) * shape1).view(new_shape)
163 | return (index + offsets).view(-1)
164 |
165 |
166 | def _learned_scale_adjustment(scale, mlp, sq_dist, batch, n_sample_set, n_query_set, n_classes):
167 | """
168 | Learn a data-dependent adjustment of the scale factor used in the distance computation
169 |
170 | scale: float
171 | The original scale factor
172 | mlp: nn.Module
173 | The model used to learn the scale factor adjustment
174 | sq_dist: torch.Tensor, shape=(bs, n_total, n_total)
175 | A matrix of squared distances between all the examples
176 | n_sample_set: int
177 | Number of examples in the sample set
178 | n_query_set: int
179 | Number of examples in the query set
180 | n_classes: int
181 | Number of classes
182 |
183 | """
184 |     global_mean = sq_dist.view(batch, -1).mean(1, keepdim=True).repeat(1, sq_dist.size(1))
185 |     global_std = sq_dist.view(batch, -1).std(1, keepdim=True).repeat(1, sq_dist.size(1))
186 |     avg_mean = sq_dist.mean(2).mean(1, keepdim=True).repeat(1, sq_dist.size(1))
187 |     avg_std = sq_dist.std(2).mean(1, keepdim=True).repeat(1, sq_dist.size(1))
188 |     sample_distances = sq_dist[:, :, :(n_classes * n_sample_set)].clone().view(batch, -1, n_sample_set)
189 |     class_mean = sample_distances.mean(2).mean(1, keepdim=True).repeat(1, sq_dist.size(1))
190 |     class_std = sample_distances.std(2).mean(1, keepdim=True).repeat(1, sq_dist.size(1))
191 | sample_mean = sq_dist.mean(2)
192 | sample_std = sq_dist.std(2)
193 | n = torch.ones_like(sample_mean) * n_classes
194 | k = torch.ones_like(sample_mean) * n_sample_set
195 | q = torch.ones_like(sample_mean) * n_query_set
196 | vin = torch.stack(
197 | [global_mean, global_std, sample_mean, sample_std, avg_mean, avg_std, class_mean, class_std, n, k, q], 2)
198 | scales = mlp(vin.view(-1, 11)).view(sq_dist.size(0), sq_dist.size(1), 1)
199 | return scale + F.softplus(scales)
200 |
201 |
202 | def label_prop(sample_set, query_set, unlabeled_set=None, alpha=0.1, scale_factor=1, label_offset=0, apply_log=False,
203 | topk=0, method="global_consistancy", epsilon=1e-8, mlp=None, return_all=False, normalize_weights=False,
204 | debug_plot_path=None, propagator=None, labels=None, weights=None):
205 | """Uses the laplacian graph to smooth the labels matrix by "propagating" labels.
206 |
207 | Args:
208 | sample_set: Tensor of shape (batch, n_classes, n_sample_per_classes, z_dim) containing the representation z of
209 | each images.
210 | query_set: Tensor of shape (batch, n_classes, n_query_per_classes, z_dim) containing the representation z of
211 | each images.
212 | unlabeled_set: Tensor of shape (batch, n_classes, n_unlabeled_per_classes, z_dim) containing the representation
213 | z of each images.
214 | alpha: Smoothing factor in the laplacian graph
215 | scale_factor: scale modifying the euclidean distance before exponential kernel.
216 | label_offset: Applies an offset to the labels before propagating. Has an effect on the degree of uncertainty
217 | when apply_log is True.
218 | apply_log: if True, it is assumed that the label propagation methods returns un-normalized probabilities. Hence
219 | to return logits, applying logarithm is necessary.
220 |         topk: limit the weight matrix to the topk most similar examples
221 | method: "regularized_laplacian" or "global_consistancy".
222 | epsilon: small value used when apply_log is True.
223 | mlp: If 1, it trains an MLP to predict the scaling factor from weight matrix stats. If 2, it does it without
224 | passing backprop to the main architecture.
225 | return_all: For debugging purpose.
226 |
227 | Returns:
228 | Tensor of shape (batch, n_total_query, n_classes) representing the logits of each classes
229 |
230 | """
231 | init_query_size = query_set.size()[2]
232 | if unlabeled_set is not None:
233 | init_unlabel_size = unlabeled_set.size()[2]
234 | query_set = torch.cat((unlabeled_set, query_set), dim=2) # XXX: Concat order is important to logit extraction
235 |
236 | # Get data shape
237 | batch, n_classes, n_query_pc, z_dim = query_set.size()
238 | batch, n_classes, n_sample_pc, z_dim = sample_set.size()
239 |
240 | if propagator is None:
241 | # Compute the pairwise distance between the examples of the sample and query sets
242 | # XXX: labels are set to a constant for the query set
243 |         sq_dist, labels = pw_sq_dist(sample_set, query_set, label_offset=label_offset)
244 |
245 | # Learn to adjust the scale
246 | if mlp is not None:
247 | scale_factor = _learned_scale_adjustment(scale_factor, mlp, sq_dist, batch, n_sample_pc, n_query_pc,
248 | n_classes)
249 |
250 | # Compute similarity between the examples -- inversely proportional to distance
251 | weights = torch.exp(-0.5 * sq_dist / scale_factor ** 2)
252 |
253 | # Discard similarity for examples other than the top k most similar
254 | if topk > 0:
255 | weights, ind = torch.sort(weights, dim=2) # ascending order
256 | weights[:, :, :-int(topk + 1)] *= 0
257 | weights = weights.scatter(2, ind, weights)
258 |
259 | # Normalize the weights
260 | if normalize_weights:
261 | weights = weights / torch.sum(weights, dim=2, keepdim=True)
262 |
263 | if (method == "regularized_laplacian") or (method is None):
264 | logits, propagator = regularized_laplacian(weights, labels, alpha=alpha)
265 | elif method == "global_consistancy":
266 | logits, propagator = global_consistency(weights, labels, alpha=alpha)
267 | else:
268 |         raise Exception("Unknown method %s." % method)
269 | else:
270 | logits = _propagate(labels, propagator)
271 |
272 | if debug_plot_path is not None:
273 |         print("Sample set:", n_sample_pc, " Query set:", init_query_size,
274 |               " unlabeled set:", 0 if unlabeled_set is None else init_unlabel_size)
275 | # XXX: Only saves the first batch elements
276 | np.save(debug_plot_path + "_weights.npy", weights[0].detach().cpu().numpy())
277 | np.save(debug_plot_path + "_propagator.npy", propagator[0].detach().cpu().numpy())
278 | plt_labels = \
279 | np.argmax(torch.cat((_make_aligned_labels(sample_set),
280 | _make_aligned_labels(query_set)), dim=2).view(batch, -1, n_classes).cpu().numpy(),
281 | axis=-1)[0]
282 | np.save(debug_plot_path + "_labels.npy", plt_labels)
283 |
284 | if apply_log:
285 | logits = torch.log(logits + epsilon)
286 |
287 | logits = logits.reshape(batch, n_classes, -1, n_classes)
288 | if return_all:
289 | # Extracts only the logits for the query set
290 | # TODO: verify that this works as expected. <<<------- *****
291 | query_labels = _make_aligned_labels(query_set)
292 | query_logits = logits[:, :, -init_query_size:, :].reshape(batch, -1, n_classes)
293 | if unlabeled_set is not None:
294 | unlabeled_labels = query_labels[:, :, :init_unlabel_size, :]
295 | query_labels = query_labels[:, :, init_unlabel_size:, :]
296 | unlabel_logits = logits[:, :, :init_unlabel_size, :].reshape(batch, -1, n_classes)
297 | else:
298 | unlabel_logits = None
299 | unlabeled_labels = None
300 | return query_logits, unlabel_logits, labels, weights, _make_aligned_labels(
301 | sample_set), query_labels, unlabeled_labels, propagator
302 | else:
303 | # Extracts only the logits for the query set
304 | # TODO: verify that this works as expected. <<<------- *****
305 |         # logits already has shape (batch, n_classes, -1, n_classes) from above
306 |         logits = logits[:, :, -init_query_size:, :].reshape(batch, -1, n_classes)
307 | return logits
308 |
309 | def make_one_hot_labels(labels, label_offset=0.):
310 | n_classes = labels.max() + 1
311 | assert (n_classes.item() == 5)
312 | to_one_hot = torch.eye(n_classes, device=labels.device, dtype=torch.float)
313 | one_hot_labels = to_one_hot[labels]
314 | mask = labels == -1
315 | one_hot_labels[mask, :] = label_offset
316 | return one_hot_labels
317 |
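318 | # e.g. labels = torch.tensor([4, 2, -1]) gives n_classes = 5 and one-hot rows
319 | # [0,0,0,0,1], [0,0,1,0,0], and a constant label_offset row for the -1
320 | # (unlabeled) entry.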
318 |
433 | def standarized_label_prop(*sets, gaussian_scale=1, alpha=1,
434 | propagator=None, weights=None, return_all=False,
435 | apply_log=False, scale_bound="", standarize="", kernel="", square_root=False,
436 | offset=0, epsilon=1e-6, dropout=0, n_classes=5):
437 | if scale_bound == "softplus":
438 | gaussian_scale = 0.01 + F.softplus(gaussian_scale)
439 | alpha = 0.1 + F.softplus(alpha)
440 | elif scale_bound == "square":
441 | gaussian_scale = 1e-4 + gaussian_scale ** 2
442 | alpha = 0.1 + alpha ** 2
443 | elif scale_bound == "convex_relu":
444 | #gaussian_scale = gaussian_scale ** 2
445 | alpha = F.relu(alpha) + 0.1
446 | elif scale_bound == "convex_square":
447 | # gaussian_scale = gaussian_scale ** 2
448 | alpha = 0.1 + alpha ** 2
449 | elif scale_bound == "relu":
450 | gaussian_scale = F.relu(gaussian_scale) + 0.01
451 | alpha = F.relu(alpha) + 0.1
452 | elif scale_bound == "constant":
453 | gaussian_scale = 1
454 | alpha = 1
455 | elif scale_bound == "alpha_square":
456 | alpha = 0.1 + F.relu(alpha)
457 |
458 | samples = []
459 | labels = []
460 | offsets = []
461 | is_labeled = []
462 | for data, _labels in sets:
463 | b, _, sample_size, c = data.size()
464 | n = n_classes
465 | samples.append(data)
466 | offsets.append(sample_size)
467 | if isinstance(_labels, bool):
468 | labels.append(_make_aligned_labels(data))
469 | if not (_labels):
470 | labels[-1].data[:] = labels[-1].data * 0 + offset
471 | is_labeled.append(_labels)
472 | else:
473 | labels.append(_labels)
474 | is_labeled.append((_labels.cpu().numpy() > 0).any())
475 | samples = torch.cat(samples, dim=2)
476 | labels = torch.cat(labels, dim=2).view(b, -1, n)
477 |
478 | if propagator is None:
479 | # Compute the pairwise distance between the examples of the sample and query sets
480 | # XXX: labels are set to a constant for the query set
481 | sq_dist = generalized_pw_sq_dist(samples, "euclidean")
482 | if square_root:
483 | sq_dist = (sq_dist + epsilon).sqrt()
484 | if standarize == "all":
485 | mask = sq_dist != 0
486 | # sq_dist = sq_dist - sq_dist[mask].mean()
487 | sq_dist = sq_dist / sq_dist[mask].std()
488 | elif standarize == "median":
489 | mask = sq_dist != 0
490 | gaussian_scale = torch.sqrt(
491 | 0.5 * torch.median(sq_dist[mask]) / torch.log(torch.ones(1, device=sq_dist.device) + sq_dist.size(1)))
492 | elif standarize == "frobenius":
493 | mask = sq_dist != 0
494 | sq_dist = sq_dist / (sq_dist[mask] ** 2).sum().sqrt()
495 | elif standarize == "percentile":
496 |         mask = sq_dist != 0  # exclude self-distances, as in the other branches
497 | sorted, indices = torch.sort(sq_dist.data[mask])
498 | total = sorted.size(0)
499 | gaussian_scale = sorted[int(total * 0.1)].detach()
500 | if kernel == "rbf":
501 | weights = torch.exp(-sq_dist * gaussian_scale)
502 | elif kernel == "convex_rbf":
503 | scales = torch.linspace(0.1, 10, gaussian_scale.size(0), device=sq_dist.device, dtype=sq_dist.dtype)
504 | weights = torch.exp(-sq_dist.unsqueeze(1) * scales.view(1, -1, 1, 1))
505 | weights = (weights * F.softmax(gaussian_scale.view(1, -1, 1, 1), dim=1)).sum(1)
506 | # checknan(timessoftmax=weights)
507 | elif kernel == "euclidean":
508 | # Compute similarity between the examples -- inversely proportional to distance
509 | weights = 1 / (gaussian_scale + sq_dist)
510 | elif kernel == "softmax":
511 | weights = F.softmax(-sq_dist / gaussian_scale, -1)
512 |
513 | mask = (torch.eye(weights.size(1), dtype=weights.dtype, device=weights.device)[None, :, :]).view(1, weights.size(1), weights.size(2)).expand(weights.size(0), -1, -1)
514 | weights = weights * (1 - mask)
515 | # checknan(masking=weights)
516 |
517 |
518 | logits, propagator = global_consistency(weights, labels, alpha=alpha)
519 | else:
520 | logits = _propagate(labels, propagator)
521 |
522 | if apply_log:
523 | logits = torch.log(logits + epsilon)
524 |
525 | logits = logits.view(b, n, -1, n)
526 | logits_ret = []
527 | start = 0
528 | for i, offset in enumerate(offsets):
529 | logits_ret.append(logits[:, :, start:(start + offset), :].contiguous().view(b, -1, n))
530 | start += offset
531 | if return_all:
532 | # Extracts only the logits for the query set
533 | # TODO: verify that this works as expected. <<<------- *****
534 | labels = labels.view(b, n, -1, n)
535 | labels_ret = []
536 | start = 0
537 | for i, offset in enumerate(offsets):
538 | labels_ret.append(labels[:, :, start:(start + offset), :].contiguous().view(b, -1, n))
539 | start += offset
540 | return tuple(logits_ret + labels_ret + [weights, propagator, labels])
541 | else:
542 | # Extracts only the logits for the query set
543 | # TODO: verify that this works as expected. <<<------- *****
544 | return tuple(logits_ret)
545 |
546 |
547 | def regularized_laplacian(weights, labels, alpha):
548 | """Uses the laplacian graph to smooth the labels matrix by "propagating" labels
549 |
550 | Args:
551 | weights: Tensor of shape (batch, n, n)
552 | labels: Tensor of shape (batch, n, n_classes)
553 |         alpha: Scalar, acts as a smoothing factor
554 |     Returns:
555 |         Tensor of shape (batch, n, n_classes) representing the logits of each class
559 | """
560 | n = weights.shape[1]
561 | diag = torch.diag_embed(torch.sum(weights, dim=2))
562 | laplacian = diag - weights
563 | identity = torch.eye(n, dtype=laplacian.dtype, device=laplacian.device)[None, :, :]
564 | propagator = torch.inverse(identity + alpha * laplacian)
565 |
566 | return _propagate(labels, propagator), propagator
567 |
568 |
569 | def global_consistency(weights, labels, alpha=0.1):
570 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in TPN paper but without bug)
571 |
572 | Args:
573 | weights: Tensor of shape (batch, n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and
574 | s the scale parameter.
575 | labels: Tensor of shape (batch, n, n_classes)
576 |         alpha: Scalar, acts as a smoothing factor
577 |     Returns:
578 |         Tensor of shape (batch, n, n_classes) representing the logits of each class
579 | """
580 | alpha_ = 1 / (1 + alpha)
581 | beta_ = alpha / (1 + alpha)
582 | n = weights.shape[1]
583 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device)[None, :, :]
584 | #weights = weights * (1. - identity) # zero out diagonal
585 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=2))
586 | # checknan(laplacian=isqrt_diag)
587 | S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, :, None]
588 | # checknan(normalizedlaplacian=S)
589 | propagator = identity - alpha_ * S
590 | propagator = torch.inverse(propagator) * beta_
591 | # checknan(propagator=propagator)
592 |
593 | return _propagate(labels, propagator, scaling=1), propagator
594 |
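# A shape-level sketch (illustrative, not part of the original source): run the
# Zhou et al. closed form on an RBF affinity built from random embeddings;
# torch.cdist is used here only to build the toy affinity matrix.
def _global_consistency_example(n=10, n_classes=5):
    z = torch.randn(1, n, 8)                                       # (batch, n, z_dim)
    weights = torch.exp(-torch.cdist(z, z) ** 2)                   # exp(-d^2 / s^2) with s = 1
    labels = torch.eye(n_classes).repeat(n // n_classes, 1)[None]  # (1, n, n_classes) one-hot
    logits, propagator = global_consistency(weights, labels, alpha=0.2)
    assert logits.shape == (1, n, n_classes) and propagator.shape == (1, n, n)
    return logits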
595 |
596 | def _propagate(labels, propagator, scaling=1.):
597 | return torch.matmul(propagator, labels) * scaling
598 |
599 |
600 | class MLP(torch.nn.Module):
601 | def __init__(self, detach=False):
602 | super().__init__()
603 | self.__detach = detach
604 | self.linear1 = torch.nn.Linear(11, 128)
605 | self.linear2 = torch.nn.Linear(128, 16)
606 | self.linear3 = torch.nn.Linear(16, 1)
607 |
608 | def forward(self, x):
609 | if self.__detach:
610 | x = x.detach()
611 | x = F.relu(self.linear1(x), True)
612 | x = F.relu(self.linear2(x), True)
613 | return self.linear3(x)
614 |
615 | class Distance(torch.nn.Module):
616 | def __init__(self, exp_params):
617 | """ Helper to obtain a distance function from a string
618 | Args:
619 | distance_type: string indicating the distance type
620 | """
621 | super().__init__()
622 | self.exp_params = exp_params
623 | self.distance_type, *args = exp_params["distance_type"].split(',')
624 | if self.distance_type in ["euclidean", "prototypical"]:
625 | self.d = prototype_distance
626 | elif self.distance_type == "labelprop":
627 | self.register_buffer("moving_alpha", torch.ones(1) * self.exp_params["labelprop_alpha_prior"])
628 | if self.exp_params["kernel_type"] == "convex_rbf":
629 | scale_size = 10
630 | else:
631 | scale_size = 1
632 | self.register_buffer("moving_gaussian_scale", torch.ones(scale_size) * exp_params["labelprop_scale_prior"])
633 | self.d = partial(standarized_label_prop, scale_bound=self.exp_params["kernel_bound"],
634 | kernel=self.exp_params["kernel_type"],
635 | standarize=self.exp_params["kernel_standarization"],
636 | square_root=self.exp_params["kernel_square_root"],
637 | apply_log=True)
638 | elif self.distance_type == "matching":
639 | matching_distance, = args
640 | self.d = partial(matching_nets, distance_type=matching_distance)
641 | elif self.distance_type == "labelprop_boris":
642 | """
643 | TODO(@boris) define here the relation network, and define a new "label_prop" above, alike the original
644 | one parameters inside Distance are taken into account by the optimizer
645 | """
646 | self.d = partial(label_prop, apply_log=True)
647 |
648 | def forward(self, *sets, **kwargs):
649 | return self.d(*sets, **kwargs)
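# A construction sketch (illustrative, not part of the original source): the
# keys below are exactly the ones Distance reads for the "labelprop" distance;
# the values are plausible placeholders, not the authors' settings.
def _distance_example():
    exp_params = {
        "distance_type": "labelprop",
        "labelprop_alpha_prior": 1.0,
        "labelprop_scale_prior": 1.0,
        "kernel_type": "rbf",            # one of: rbf, convex_rbf, euclidean, softmax
        "kernel_bound": "square",        # how gaussian_scale and alpha are constrained
        "kernel_standarization": "all",  # how the pairwise distances are standarized
        "kernel_square_root": False,
    }
    return Distance(exp_params)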
--------------------------------------------------------------------------------
/src/models/base_ssl/oracle.py:
--------------------------------------------------------------------------------
1 | from . import utils as ut
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import sys
6 | import os
7 | import copy
8 | import json
9 | import pprint
10 | import h5py
11 | import pylab
12 | import pandas
13 | import tqdm
14 | from functools import partial
15 | from scipy.stats import pearsonr
16 | from sklearn.cluster import KMeans
17 | # from ipywidgets import interact, interactive, fixed, interact_manual
18 | # import ipywidgets as widgets
19 | from skimage.io import imsave
20 | 
21 | sys.path.insert(0, os.path.abspath('..'))
22 | from .distances import prototype_distance  # here we import the label propagation algorithm inside the "Distance" class
23 | 
24 | # loading data
25 | # from trainers.few_shot_parallel_alpha_scale import inner_loop_lbfgs2, inner_loop_lbfgs_bootstrap
26 | from torch.utils.data import DataLoader
27 | 
28 | 
29 | 
30 | 
31 | 
32 | 
33 |
34 | class Sampler(object):
35 | """
36 | Samples few shot tasks from precomputed embeddings
37 | """
38 | def __init__(self, embeddings_fname, n_classes, distract_flag):
39 | self.h5fp = h5py.File(embeddings_fname, 'r')
40 | self.labels = self.h5fp["test_targets"][...]
41 | indices = np.arange(self.labels.shape[0])
42 | self.label_indices = {i: indices[self.labels == i] for i in set(self.labels)}
43 | self.nclasses = len(self.label_indices.keys())
44 | self.n_classes = n_classes
45 | self.distract_flag = distract_flag
46 |
47 | def sample_episode_indices(self, support_size,
48 | query_size, unlabeled_size, ways):
49 | """
50 |         Returns the indices of the images of a random episode with predefined support, query and unlabeled sizes.
51 |         The sizes are given per class; the number of classes is given by "ways".
52 | """
53 | label_indices = {k: np.random.permutation(v) for k,v in self.label_indices.items()}
54 | #label_indices = self.label_indices
55 |
56 | if self.distract_flag:
57 | classes = np.random.permutation(self.nclasses)
58 | distract_classes = classes[ways:(ways+ways)]
59 | classes = classes[:ways]
60 | else:
61 | classes = np.random.permutation(self.nclasses)[:ways]
62 |
63 | support_indices = []
64 | query_indices = []
65 | unlabel_indices = []
66 |
67 | for cls in classes:
68 | start = 0
69 | end = support_size
70 | support_indices.append(label_indices[cls][start:end])
71 | start = end
72 | end += query_size
73 | query_indices.append(label_indices[cls][start:end])
74 | start = end
75 | end += unlabeled_size
76 |             assert end <= len(label_indices[cls])  # each class must have enough examples for all three sets
77 | unlabel_indices.append(label_indices[cls][start:end])
78 |
79 | if self.distract_flag:
80 | for cls in distract_classes:
81 | unlabel_indices.append(label_indices[cls][:unlabeled_size])
82 |
83 | return np.vstack(support_indices), np.vstack(query_indices), np.vstack(unlabel_indices)
84 |
85 | def _sample_field(self, field, *indices):
86 | features = self.h5fp["test_{}".format(field)]
87 | ret = []
88 | for _indices in indices:
89 | _indices = _indices.ravel()
90 | argind = np.argsort(_indices)
91 | if len(argind) == 0:
92 | ret.append(None)
93 | else:
94 | ind = _indices[argind]
95 | dset = features[ind.tolist()]
96 | dset[argind] = dset.copy()
97 | ret.append(dset)
98 |
99 | return tuple(ret)
100 |
101 | def sample_features(self, support_indices, query_indices, unlabel_indices):
102 | return self._sample_field("features", support_indices, query_indices, unlabel_indices)
103 |
104 | def sample_labels(self, support_indices, query_indices, unlabel_indices):
105 | return self._sample_field("targets", support_indices, query_indices, unlabel_indices)
106 |
107 | def sample_episode(self, support_size, query_size, unlabeled_size, apply_ten_flag=False):
108 | """
109 | Randomly samples an episode (features and labels) given the size of each set and the number of classes
110 |
111 |         Returns: tuple(numpy array). Sets are of size (set_size * nclasses, z_dim), where z_dim is the
112 |         number of channels of the embeddings.
113 | """
114 | ways = self.n_classes
115 | support_indices, query_indices, unlabel_indices = self.sample_episode_indices(support_size, query_size, unlabeled_size, ways)
116 | support_set, query_set, unlabel_set = self.sample_features(support_indices,
117 | query_indices,
118 | unlabel_indices)
119 | support_labels = ut.make_labels(support_size, ways)
120 | query_labels = ut.make_labels(query_size, ways)
121 | unlabel_labels = ut.make_labels(unlabeled_size, ways)
122 |
123 | episode_dict = episode2dict(support_set, query_set, unlabel_set, support_labels, query_labels, unlabel_labels)
124 |
125 | if apply_ten_flag:
126 | episode_dict = ut.apply_ten_on_episode(episode_dict)
127 |
128 | return episode_dict
129 |
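# An end-to-end sketch (illustrative, not part of the original source): write a
# tiny fake embedding file with the "test_features"/"test_targets" datasets the
# Sampler expects, then draw one 5-way episode. The file name and sizes are
# placeholders.
def _sampler_example(fname="/tmp/fake_embeddings.h5"):
    rng = np.random.RandomState(0)
    with h5py.File(fname, "w") as fp:
        fp["test_features"] = rng.randn(200, 16).astype("float32")  # 10 classes x 20 examples
        fp["test_targets"] = np.repeat(np.arange(10), 20)
    sampler = Sampler(fname, n_classes=5, distract_flag=False)
    episode = sampler.sample_episode(support_size=1, query_size=5, unlabeled_size=10)
    assert episode["support"]["samples"].shape == (5, 16)            # (support_size * ways, z_dim)
    return episode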
130 | def episode2dict(support_set, query_set, unlabel_set, support_labels, query_labels, unlabel_labels):
131 | n_classes = support_set.shape[0]
132 |
133 | support_dict = {"samples": support_set, "labels":support_labels}
134 | query_dict = {"samples": query_set, "labels":query_labels}
135 | unlabeled_dict = {"samples": unlabel_set, "labels":unlabel_labels}
136 |
137 | return {"support":support_dict,
138 | "query":query_dict,
139 | "unlabeled":unlabeled_dict}
140 |
141 |
142 |
143 | def compute_acc(pred_labels, true_labels):
144 | acc = (true_labels.flatten() == pred_labels.ravel()).astype(float).mean()
145 |
146 | return acc
147 |
148 |
149 |
150 |
151 |
152 |
--------------------------------------------------------------------------------
/src/models/base_ssl/predict_methods/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from . import label_prop as lp
3 | from . import prototypical
4 | from . import adaptive
5 |
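# Expected episode_dict schema, as read by the predict methods in this package
# (inferred from label_prop.py / prototypical.py; all arrays are numpy and are
# moved to CUDA inside each method):
#   episode_dict["support_so_far"]["samples"]: (n_support, z_dim) float32
#   episode_dict["support_so_far"]["labels"]:  (n_support,) int64 in [0, nclasses)
#   episode_dict["query"]["samples"]:          (n_query, z_dim) float32
#   episode_dict["unlabeled"]["samples"]:      (n_unlabeled, z_dim) float32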
6 | def get_predictions(predict_method, episode_dict):
7 | if predict_method == "labelprop":
8 | return lp.label_prop_predict(episode_dict)
9 |
10 | elif predict_method == "prototypical":
11 | return prototypical.prototypical_predict(episode_dict)
12 |     elif predict_method == "adaptive":
13 |         return adaptive.adaptive_predict(episode_dict)
14 |     else:
15 |         raise ValueError("Prediction method not found")
--------------------------------------------------------------------------------
/src/models/base_ssl/predict_methods/adaptive.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | # from models.wide_resnet_imagenet import WideResnetImagenet, Group
7 | # from modules.dynamic_residual_groups import get_group
8 | # from modules.ten import TEN
9 | # from modules.distances import Distance, _propagate, standarized_label_prop
10 | # from modules.output_heads import get_output_head
11 | # from modules.activations import get_activation
12 | # from modules.layers import MetricLinear
13 | from . import label_prop as lp
14 |
15 |
16 | def adaptive_predict(episode_dict, double_flag=False):
17 | # Get variables
18 | S = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda()
19 | S_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda()
20 | Q = torch.from_numpy(episode_dict["query"]["samples"]).cuda()
21 | U = torch.from_numpy(episode_dict["unlabeled"]["samples"]).cuda()
22 |
23 | SUQ = torch.cat([S, U, Q])[None]
24 | Q_labels = -1*torch.ones(Q.shape[0], device=S_labels.device, dtype=S_labels.dtype).cuda()
25 | U_labels = -1*torch.ones(U.shape[0], device=S_labels.device, dtype=S_labels.dtype).cuda()
26 |
27 | SUQ_labels = torch.cat([S_labels, U_labels, Q_labels])
28 | # Init Adaptive
29 | adaptive = AdaptiveTenInner(output_size=S.shape[1],
30 | nclasses=5,
31 | double_flag=double_flag).cuda()
32 |
33 | # Apply Adaptive
34 | # adaptive.train_step()
35 | UQ_logits = adaptive.forward(x=SUQ,
36 | support_size=S.shape[0],
37 | query_size=U.shape[0] + Q.shape[0],
38 | labels=SUQ_labels)
39 |
40 | return UQ_logits[-Q.shape[0]:].argmax(dim=1)
41 |
42 |
43 | class AdaptiveTenInner(torch.nn.Module):
44 | def __init__(self, output_size, nclasses, double_flag, mu_init=0, scale_init=0, precision_init=0, **exp_dict):
45 | super().__init__()
46 | self.nclasses = nclasses
47 | self.precision = torch.nn.Parameter(torch.randn(1, 1, output_size) * 0.1)
48 | self.classifier = torch.nn.Linear(output_size, nclasses)
49 | self.exp_dict = exp_dict
50 | self.optimizer = torch.optim.LBFGS([self.precision] + list(self.classifier.parameters()), tolerance_grad=1e-5, tolerance_change=1e-5, lr=0.1)
51 | if double_flag:
52 | self.label_prop = lp.LabelpropDouble()
53 | else:
54 | self.label_prop = lp.Labelprop()
55 |
56 | def train_step(self, _x, support_size, query_size, labels):
57 | x = _x.clone()
58 | self.optimizer.zero_grad()
59 | b, k, c = x.size()
60 | x = x * torch.sigmoid(1 + self.precision)
61 |
62 | zeros = torch.zeros(1, k, self.nclasses, device=x.device)
63 |
64 | logits, propagator = standarized_label_prop(x, zeros,
65 | 1, 1,
66 | apply_log=True, scale_bound="",
67 | standarize="all", kernel="rbf")
68 | x = _propagate(x, propagator)
69 | # support_set = x.view(-1, self.nclasses, c)[:support_size, ...].view(-1, c)
70 | support_set = x.view(-1, c)[:support_size, ...]
71 | support_labels = labels.view(-1, self.nclasses)[:support_size, ...].view(-1)
72 | logits = self.classifier(support_set)
73 | loss = F.cross_entropy(logits, support_labels) + 0.0001 * (self.precision ** 2).mean()
74 | loss.backward()
75 | return loss
76 |
77 | def forward(self, x, support_size, query_size, labels):
78 | self.train()
79 | self.optimizer.step(lambda: self.train_step(x, support_size, query_size, labels))
80 | # self.train_step(x, support_size, query_size, labels)
81 | with torch.no_grad():
82 | # mu, scale, precision = (0.5 + self.mu **2).detach(), (0.5 + self.scale ** 2).detach(), (1 + self.precision.detach())
83 | mu, scale, precision = 1, 1, self.precision.detach()
84 | # mu, scale, precision = 1, 1, 1
85 | x = x * torch.sigmoid(1 + precision)
86 |             one_hot_labels = F.one_hot(labels.view(-1).clamp(min=0), self.nclasses).float()  # query/unlabeled labels are -1, which one_hot rejects; those rows are zeroed just below
87 | one_hot_labels = one_hot_labels.view(1,
88 | support_size + query_size,
89 | self.nclasses)
90 | one_hot_labels[:, support_size:, ...] = 0
91 |
92 | logits, propagator = standarized_label_prop(x, one_hot_labels,
93 | scale, mu, apply_log=True,
94 | scale_bound="", standarize="all",
95 | kernel="rbf")
96 | x = _propagate(x, propagator)
97 | logits, propagator = standarized_label_prop(x, one_hot_labels,
98 | scale, mu, apply_log=True,
99 | scale_bound="", standarize="all",
100 | kernel="rbf")
101 | logits = logits.view((support_size + query_size), self.nclasses)
102 | return logits[support_size:].view(-1, self.nclasses)
103 |
104 |
105 | def standarized_label_prop(embeddings,
106 | labels,
107 | gaussian_scale=1, alpha=1,
108 | weights=None,
109 | apply_log=False, scale_bound="", standarize="", kernel="", square_root=False,
110 | epsilon=1e-6):
111 | if scale_bound == "softplus":
112 | gaussian_scale = 0.01 + F.softplus(gaussian_scale)
113 | alpha = 0.1 + F.softplus(alpha)
114 | elif scale_bound == "square":
115 | gaussian_scale = 1e-4 + gaussian_scale ** 2
116 | alpha = 0.1 + alpha ** 2
117 | elif scale_bound == "convex_relu":
118 | #gaussian_scale = gaussian_scale ** 2
119 | alpha = F.relu(alpha) + 0.1
120 | elif scale_bound == "convex_square":
121 | # gaussian_scale = gaussian_scale ** 2
122 | alpha = 0.1 + alpha ** 2
123 | elif scale_bound == "relu":
124 | gaussian_scale = F.relu(gaussian_scale) + 0.01
125 | alpha = F.relu(alpha) + 0.1
126 | elif scale_bound == "constant":
127 | gaussian_scale = 1
128 | alpha = 1
129 | elif scale_bound == "alpha_square":
130 | alpha = 0.1 + F.relu(alpha)
131 |
132 | # Compute the pairwise distance between the examples of the sample and query sets
133 | # XXX: labels are set to a constant for the query set
134 | sq_dist = generalized_pw_sq_dist(embeddings, "euclidean")
135 | if square_root:
136 | sq_dist = (sq_dist + epsilon).sqrt()
137 | if standarize == "all":
138 | mask = sq_dist != 0
139 | # sq_dist = sq_dist - sq_dist[mask].mean()
140 | sq_dist = sq_dist / sq_dist[mask].std()
141 | elif standarize == "median":
142 | mask = sq_dist != 0
143 | gaussian_scale = torch.sqrt(
144 | 0.5 * torch.median(sq_dist[mask]) / torch.log(torch.ones(1, device=sq_dist.device) + sq_dist.size(1)))
145 | elif standarize == "frobenius":
146 | mask = sq_dist != 0
147 | sq_dist = sq_dist / (sq_dist[mask] ** 2).sum().sqrt()
148 | elif standarize == "percentile":
149 |         mask = sq_dist != 0  # match the other standarization branches (the original "!= 2" looks like a typo)
150 | sorted, indices = torch.sort(sq_dist.data[mask])
151 | total = sorted.size(0)
152 | gaussian_scale = sorted[int(total * 0.1)].detach()
153 |
154 | if kernel == "rbf":
155 | weights = torch.exp(-sq_dist * gaussian_scale)
156 | elif kernel == "convex_rbf":
157 | scales = torch.linspace(0.1, 10, gaussian_scale.size(0), device=sq_dist.device, dtype=sq_dist.dtype)
158 | weights = torch.exp(-sq_dist.unsqueeze(1) * scales.view(1, -1, 1, 1))
159 | weights = (weights * F.softmax(gaussian_scale.view(1, -1, 1, 1), dim=1)).sum(1)
160 | # checknan(timessoftmax=weights)
161 | elif kernel == "euclidean":
162 | # Compute similarity between the examples -- inversely proportional to distance
163 | weights = 1 / (gaussian_scale + sq_dist)
164 | elif kernel == "softmax":
165 | weights = F.softmax(-sq_dist / gaussian_scale, -1)
166 |
167 | mask = torch.eye(weights.size(1), dtype=torch.bool, device=weights.device)[None, :, :]
168 | weights = weights * (~mask).float()
169 |
170 | logits, propagator = global_consistency(weights, labels, alpha=alpha)
171 |
172 | if apply_log:
173 | logits = torch.log(logits + epsilon)
174 |
175 | return logits, propagator
176 |
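# A shape-level sketch (illustrative, not part of the original source): push
# random embeddings and one-hot labels through the rbf kernel +
# global-consistency pipeline above. Random data, shape checking only.
def _standarized_label_prop_example(n=10, n_classes=5, z_dim=8):
    embeddings = torch.randn(1, n, z_dim)
    labels = torch.eye(n_classes).repeat(n // n_classes, 1)[None]  # (1, n, n_classes)
    logits, propagator = standarized_label_prop(
        embeddings, labels, gaussian_scale=1, alpha=1,
        apply_log=True, standarize="all", kernel="rbf")
    assert logits.shape == (1, n, n_classes)
    return logits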
177 | def generalized_pw_sq_dist(data, d_type="euclidean"):
178 | batch, samples, z_dim = data.size()
179 | if d_type == "euclidean":
180 | return torch.sum((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim)
181 | elif d_type == "l1":
182 | return torch.mean(torch.abs(data[:, :, None, :] - data[:, None, :, :]), dim=3)
183 | elif d_type == "stable_euclidean":
184 | return torch.sqrt(1e-6 + torch.mean((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim))
185 | elif d_type == "cosine":
186 | data = F.normalize(data, dim=2)
187 | return torch.bmm(data, data.transpose(2, 1))
188 | else:
189 | raise ValueError("Distance type not recognized")
190 |
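# A consistency sketch (illustrative, not part of the original source): the
# "euclidean" branch above equals the squared pairwise distance scaled by
# 1/sqrt(z_dim); torch.cdist serves as an independent reference here.
def _pw_sq_dist_check(batch=2, n=6, z_dim=4):
    data = torch.randn(batch, n, z_dim)
    ours = generalized_pw_sq_dist(data, "euclidean")
    ref = torch.cdist(data, data) ** 2 / np.sqrt(z_dim)
    assert torch.allclose(ours, ref, atol=1e-5)
    return ours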
191 | def global_consistency(weights, labels, alpha=0.1):
192 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in TPN paper but without bug)
193 | Args:
194 | weights: Tensor of shape (batch, n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and
195 | s the scale parameter.
196 | labels: Tensor of shape (batch, n, n_classes)
197 |         alpha: Scalar, acts as a smoothing factor
198 |     Returns:
199 |         Tuple with a Tensor of shape (batch, n, n_classes) containing the label scores of each class, and the propagator matrix
200 | """
201 | n = weights.shape[1]
202 | _alpha = 1 / (1 + alpha)
203 | _beta = alpha / (1 + alpha)
204 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device)[None, :, :]
205 | #weights = weights * (1. - identity) # zero out diagonal
206 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=2))
207 | # checknan(laplacian=isqrt_diag)
208 | S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, :, None]
209 | # checknan(normalizedlaplacian=S)
210 | propagator = identity - _alpha * S
211 | propagator = torch.inverse(propagator) * _beta
212 | # checknan(propagator=propagator)
213 |
214 | return _propagate(labels, propagator, scaling=1), propagator
215 |
216 |
217 | def _propagate(labels, propagator, scaling=1.):
218 | return torch.matmul(propagator, labels) * scaling
--------------------------------------------------------------------------------
/src/models/base_ssl/predict_methods/label_prop.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from embedding_propagation import LabelPropagation
4 | 
5 | # Unused in this module, kept for reference:
6 | # import matplotlib.pyplot as plt
7 | # import seaborn as sns
8 | # import torch.nn.functional as F, numpy as np, warnings, sys
9 | # from functools import partial
10 | # from tools.meters import BasicMeter
11 | # from train import TensorLogger
12 | 
13 |
14 | def label_prop_predict(episode_dict):
15 | S = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda()
16 | S_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda()
17 | nclasses = int(S_labels.max() + 1)
18 | Q_labels = torch.zeros(episode_dict["query"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses
19 | U_labels = torch.zeros(episode_dict["unlabeled"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses
20 | A_labels = torch.cat([S_labels, Q_labels, U_labels], 0)
21 | Q = torch.from_numpy(episode_dict["query"]["samples"]).cuda()
22 |
23 | U = torch.from_numpy(episode_dict["unlabeled"]["samples"]).cuda()
24 |
25 | lp = LabelPropagation(balanced=True)
26 |
27 |     SUQ = torch.cat([S, U, Q], dim=0)  # note: A_labels lists Q before U, which is harmless here since both carry the sentinel label "nclasses"
28 | logits = lp(SUQ, A_labels, nclasses)
29 | logits_query = logits[-Q.shape[0]:]
30 |
31 | return logits_query.argmax(dim=1).cpu().numpy()
--------------------------------------------------------------------------------
/src/models/base_ssl/predict_methods/prototypical.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def prototypical_predict(episode_dict):
6 | support_samples = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda()
7 | support_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda()
8 | query_samples = torch.from_numpy(episode_dict["query"]["samples"]).cuda()
9 |
10 | logits = prototype_distance(support_set=support_samples,
11 | query_set=query_samples,
12 | labels=support_labels,
13 | unlabeled_set=None)
14 | return logits.argmax(dim=1)
15 |
16 |
17 |
18 |
19 | def prototype_distance(support_set, query_set, labels, unlabeled_set=None):
20 | """Computes distance from each element of the query set to prototypes in the sample set.
21 | Args:
22 | sample_set: Tensor of shape (batch, n_classes, n_sample_per_classes, z_dim) containing the representation z of
23 | each images.
24 | query_set: Tensor of shape (batch, n_classes, n_query_per_classes, z_dim) containing the representation z of
25 | each images.
26 | labels: Tensor of Long of shape(support_set_size)
27 | unlabeled_set: Tensor of shape (batch, n_classes, n_unlabeled_per_classes, z_dim) containing the representation
28 | z of each images.
29 | Returns:
30 | Tensor of shape (batch, n_total_query, n_classes) containing the similarity between each pair of query,
31 | prototypes, for each task.
32 | """
33 | n_queries, channels = query_set.size()
34 | n_support, channels = support_set.size()
35 |
36 | support_set = support_set.view(n_support, 1, channels)
37 |
38 | way = int(labels.data.max()) + 1
39 | one_hot_labels = torch.zeros(n_support, way, 1, dtype=support_set.dtype, device=support_set.device)
40 | one_hot_labels.scatter_(1, labels.view(n_support, 1, 1), 1)
41 |
42 | total_per_class = one_hot_labels.sum(0, keepdim=True)
43 | prototypes = (support_set * one_hot_labels).sum(0) / total_per_class
44 | prototypes = prototypes.view(1, way, channels)
45 |
46 | query_set = query_set.view(n_queries, 1, channels)
47 | d = query_set - prototypes
48 | return -torch.sum(d ** 2, 2) / np.sqrt(channels)
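# A shape-level sketch (illustrative, not part of the original source): a
# 5-way, 1-shot support set scored against 15 random queries; logits are
# negative squared distances to the class prototypes.
def _prototype_distance_example(way=5, z_dim=8):
    support = torch.randn(way, z_dim)        # one example per class
    labels = torch.arange(way)
    queries = torch.randn(15, z_dim)
    logits = prototype_distance(support, queries, labels)
    assert logits.shape == (15, way)
    return logits.argmax(dim=1)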
--------------------------------------------------------------------------------
/src/models/base_ssl/selection_methods/__init__.py:
--------------------------------------------------------------------------------
1 | from . import ssl
2 | import numpy as np
3 |
4 | def get_indices(selection_method, episode_dict, support_size_max=None):
5 | # random
6 | if selection_method == "random":
7 | ind = np.random.choice(episode_dict["unlabeled"]["samples"].shape[0], 1, replace=False)
8 |
9 | # random imbalanced
10 | if selection_method == "random_imbalanced":
11 | ind = np.random.choice(episode_dict["unlabeled"]["samples"].shape[0], 1, replace=False)
12 |
13 | # ssl
14 | if selection_method == "ssl":
15 | ind = ssl.ssl_get_next_best_indices(episode_dict)
16 |
17 |
18 | # episode_dict["selected_indices"] = ind
19 | return ind
20 |
--------------------------------------------------------------------------------
/src/models/base_ssl/selection_methods/ssl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from ..predict_methods import label_prop as lp
4 | import numpy as np
5 | from .. import utils as ut
6 | from scipy import stats
7 | from embedding_propagation import LabelPropagation
8 |
9 | def ssl_get_next_best_indices(episode_dict, support_size_max=None):
10 | S = torch.from_numpy(episode_dict["support_so_far"]["samples"]).cuda()
11 | Q = torch.from_numpy(episode_dict["query"]["samples"]).cuda()
12 | U = torch.from_numpy(episode_dict["unlabeled"]["samples"]).cuda()
13 | S_labels = torch.from_numpy(episode_dict["support_so_far"]["labels"]).cuda()
14 | nclasses = int(S_labels.max() + 1)
15 | Q_labels = torch.zeros(episode_dict["query"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses
16 | U_labels = torch.zeros(episode_dict["unlabeled"]["samples"].shape[0], dtype=S_labels.dtype).cuda() + nclasses
17 | A_labels = torch.cat([S_labels, Q_labels, U_labels], 0)
18 |
19 | SQU = torch.cat([S, Q, U], dim=0) # Information gain is measured in the whole system
20 |
21 | # train label_prop
22 | lp = LabelPropagation(balanced=True)
23 | logits = lp(SQU, A_labels, nclasses)
24 |
25 | U_logits = logits[-U.shape[0]:]
26 | # modify the labels of the unlabeled
27 | episode_dict["unlabeled"]["labels"] = U_logits.argmax(dim=1).cpu().numpy()
28 |
29 | # unlabeled_scores = U_logits.max(dim=1).cpu().numpy()
30 |
31 | if support_size_max is None:
32 | # choose all the unlabeled examples
33 | return np.arange(U.shape[0])
34 | else:
35 | # score each
36 | score_list = U_logits.max(dim=1)[0].cpu().numpy()
37 | return score_list.argsort()[-support_size_max:]
38 |
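# A small scoring sketch (illustrative, not part of the original source): the
# top-k branch above keeps the unlabeled examples whose best class logit is
# largest, i.e. the most confidently pseudo-labeled ones.
def _topk_selection_example():
    U_logits = torch.tensor([[2.0, 0.1], [0.3, 0.2], [0.9, 3.0]])
    score_list = U_logits.max(dim=1)[0].cpu().numpy()  # [2.0, 0.3, 3.0]
    picked = score_list.argsort()[-2:]                 # indices of the 2 most confident rows
    assert set(picked.tolist()) == {0, 2}
    return picked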
39 | def predict(S, S_labels, UQ, U_shape):
40 | label_prop = lp.Labelprop()
41 | label_prop.fit(support_set=S, unlabeled_set=UQ)
42 |
43 | logits = label_prop.predict(support_labels=S_labels,
44 | unlabeled_pseudolabels=None,
45 | balanced_flag=True)
46 |
47 | U_logits = logits[S.shape[0]:S.shape[0]+U_shape[0]]
48 | return U_logits.argmax(dim=1).cpu().numpy()
--------------------------------------------------------------------------------
/src/models/base_ssl/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | # %% Import libraries & loading data & Helper Sampling Class & Pytorch helpers and imports
3 | import torch
4 | import torch.nn.functional as F
5 | import sys
6 | import os
7 | import copy
8 | import json
9 | import pprint
10 | import h5py
11 | import pylab
12 | import pandas
13 | import tqdm
14 | import numpy as np
15 | from functools import partial
16 | from scipy.stats import pearsonr
17 | from sklearn.linear_model import LogisticRegression
18 | from sklearn.cluster import KMeans
19 | # from ipywidgets import interact, interactive, fixed, interact_manual
20 | # import ipywidgets as widgets
21 | from skimage.io import imsave
22 | from .distances import prototype_distance  # here we import the label propagation algorithm inside the "Distance" class
23 | 
24 | # loading data
25 | # from trainers.few_shot_parallel_alpha_scale import inner_loop_lbfgs2, inner_loop_lbfgs_bootstrap
26 | from torch.utils.data import DataLoader
27 | 
28 | 
29 | 
30 | 
31 |
32 |
33 |
34 | def get_unlabeled_set(method, unlabeled_set,unlabel_labels, unlabeled_size,
35 | support_set_reshaped):
36 | support_labels = get_support_labels(support_set_reshaped).ravel().astype(int)
37 | support_size = support_set_reshaped.shape[1]
38 | n_classes = support_set_reshaped.shape[0]
39 | if method == "cheating":
40 | unlabeled_set_new = unlabeled_set
41 | unlabel_labels_new = unlabel_labels
42 | unlabeled_size_new = unlabeled_size
43 |
44 |
45 | elif method == "prototypical":
46 | support_labels_reshaped = support_labels.reshape((n_classes, support_size))
47 | unlabeled_set_torch = torch.FloatTensor(unlabeled_set).cuda()
48 |         unlabeled_set_view = unlabeled_set_torch.view(1, n_classes, unlabeled_size, -1)  # infer the channel dim; the original "a512" was an undefined name
49 |
50 | if True:
51 | unlabeled_set_view, tmp_unlabel_labels = predict_sort(support_set_reshaped,
52 | unlabeled_set_view,
53 | n_classes=n_classes)
54 | else:
55 | tmp_unlabel_labels = predict(support_set_reshaped,
56 | unlabeled_set_view,
57 | n_classes=n_classes)
58 |
59 | unlabeled_set_list = []
60 | # label
61 | unlabeled_size_new = np.inf
62 | for c in range(n_classes):
63 | set = unlabeled_set_torch[tmp_unlabel_labels == c]
64 | unlabeled_set_list += [set]
65 | unlabeled_size_new = min(len(set), unlabeled_size_new)
66 |
67 | # cut to shortest
68 | unlabel_labels_new = []
69 | for c in range(n_classes):
70 | unlabeled_set_list[c] = unlabeled_set_list[c][:unlabeled_size_new]
71 | unlabel_labels_new += [np.ones(unlabeled_size_new) * c]
72 |
73 | unlabeled_set_new = torch.cat(unlabeled_set_list).cpu().numpy().astype(unlabeled_set.dtype)
74 | unlabel_labels_new = np.vstack(unlabel_labels_new).ravel().astype("int64")
75 |
76 | return unlabeled_set_new, unlabel_labels_new, unlabeled_size_new
77 |
78 |
79 |
80 |
81 | def xlogy(x, y=None):
82 |     z = torch.zeros((), dtype=x.dtype, device=x.device)  # keep the zero on the same device/dtype as x
83 | if y is None:
84 | y = x
85 | assert y.min() >= 0
86 | return x * torch.where(x == 0., z, torch.log(y))
87 |
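# A tiny check sketch (illustrative, not part of the original source): xlogy
# follows the convention 0 * log(0) = 0, keeping entropy sums finite.
def _xlogy_example():
    out = xlogy(torch.tensor([0.0, 0.5, 1.0]))  # elementwise x * log(x)
    assert out[0] == 0.0 and torch.isfinite(out).all()
    return out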
88 |
89 | def get_support_labels(support_set_features):
90 | support_labels=[]
91 | for c in range(5):
92 | support_labels += [np.ones(support_set_features.shape[1])*c]
93 |
94 | support_labels = np.vstack(support_labels)
95 | return support_labels
96 |
97 |
98 | def get_entropy_support_set(monitor, support_size):
99 | support_set_ind = [[],[],[],[],[]]
100 | for s in range(1, support_size+1):
101 | # Get best next support
102 | entropy_best = 0
103 | for i in range(monitor.unlabeled_size):
104 | support_set_tmp = copy.deepcopy(support_set_ind)
105 | if i in support_set_tmp[0]:
106 | continue
107 | for c in range(monitor.n_classes):
108 | s_ind = monitor.unlabeled_size*c
109 | ind = s_ind + i
110 |
111 | support_set_tmp[c] += [ind]
112 |
113 | entropy_tmp = monitor.compute_entropy(support_set_tmp)
114 | if entropy_tmp >= entropy_best:
115 | support_set_best = support_set_tmp
116 | entropy_best = entropy_tmp
117 |
118 | support_set_ind = support_set_best
119 |
120 | for c in range(monitor.n_classes):
121 | ind_c = np.arange(monitor.unlabeled_size*c, monitor.unlabeled_size*(c+1))
122 | # 1. Within class ind sanity check
123 | assert False not in [sc in ind_c for sc in support_set_ind[c]]
124 | # 2. Uniqueness sanity check
125 | assert np.unique(support_set_ind[c]).size == np.array(support_set_ind[c]).size
126 |
127 | support_set_list = [monitor.unlabeled_set[i_list] for i_list in support_set_ind]
128 | support_set = np.vstack(support_set_list)
129 | return support_set
130 |
131 | def get_kmeans_support_set(monitor, support_size):
132 | # unlabeled_set = unlabeled_set
133 | # unlabel_labels = unlabel_labels
134 |
135 | # greedy
136 | support_set_list = [[],[],[],[],[]]
137 | for c in range(monitor.n_classes):
138 | s_ind = monitor.unlabeled_size*c
139 |
140 | X = monitor.unlabeled_set[s_ind:s_ind+monitor.unlabeled_size]
141 | k_means = KMeans(n_clusters=support_size, random_state=0).fit(X)
142 | support_set_list[c] = k_means.cluster_centers_
143 | support_set = np.vstack(support_set_list)
144 | return support_set
145 |
146 |
147 |
148 | def get_greedy_support_set(monitor, support_size):
149 | # unlabeled_set = unlabeled_set
150 | # unlabel_labels = unlabel_labels
151 |
152 | # greedy
153 | support_set_ind = [[],[],[],[],[]]
154 | for s in range(1, support_size+1):
155 | # Get best next support
156 | acc_best = 0.
157 | for i in range(monitor.unlabeled_size):
158 | support_set_tmp = copy.deepcopy(support_set_ind)
159 | if i in support_set_tmp[0]:
160 | continue
161 | for c in range(monitor.n_classes):
162 | s_ind = monitor.unlabeled_size*c
163 | ind = s_ind + i
164 |
165 | support_set_tmp[c] += [ind]
166 |
167 | acc_tmp = monitor.compute_acc(support_set_tmp)
168 | if acc_tmp >= acc_best:
169 | support_set_best = support_set_tmp
170 | acc_best = acc_tmp
171 |
172 | support_set_ind = support_set_best
173 |
174 | for c in range(monitor.n_classes):
175 | ind_c = np.arange(monitor.unlabeled_size*c, monitor.unlabeled_size*(c+1))
176 | # 1. Within class ind sanity check
177 | assert False not in [sc in ind_c for sc in support_set_ind[c]]
178 | # 2. Uniqueness sanity check
179 | assert np.unique(support_set_ind[c]).size == np.array(support_set_ind[c]).size
180 |
181 | support_set_list = [monitor.unlabeled_set[i_list] for i_list in support_set_ind]
182 | support_set = np.vstack(support_set_list)
183 | return support_set
184 |
185 |
186 |
187 |
188 | def get_random_support_set(monitor, support_size):
189 | support_set_list = []
190 | for c in range(monitor.n_classes):
191 | ind = np.arange(monitor.unlabeled_size*c, monitor.unlabeled_size*(c+1))
192 | ind_c = np.random.choice(ind, support_size, replace=False)
193 | support_set_list += [monitor.unlabeled_set[ind_c]]
194 |
195 | support_set = np.vstack(support_set_list)
196 | return support_set
197 |
198 | @torch.no_grad()
199 | def calc_accuracy(split, iters, distance_fn, support_size,
200 | query_size, unlabeled_size, method=None, model=None, n_classes=5):
201 | sampler = Sampler(split)
202 | accuracies = []
203 |
204 | if method == "pairs":
205 | check_pairs(split, iters, distance_fn, support_size,
206 | query_size, unlabeled_size, n_classes=5)
207 |
208 |
209 |
210 |
211 | # Helper Sampling Class
212 | def make_labels(size, classes):
213 | """
214 | Helper function. Generates the labels of a set: e.g 0000 1111 2222 for size=4, classes=3
215 | """
216 | return (np.arange(classes).reshape((classes, 1)) + np.zeros((classes, size), dtype=np.int32)).flatten()
217 |
218 | # Pytorch helpers and imports
219 | def to_pytorch(datum, n, k, c):
220 | if k > 0:
221 | return torch.from_numpy(datum).view(-1, n, k, c)
222 | else:
223 | return None
224 |
--------------------------------------------------------------------------------
/src/models/base_wrapper.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | class BaseWrapper(torch.nn.Module):
5 | def get_state_dict(self):
6 | raise NotImplementedError
7 | def load_state_dict(self, state_dict):
8 | raise NotImplementedError
9 |
--------------------------------------------------------------------------------
/src/models/finetuning.py:
--------------------------------------------------------------------------------
1 | """
2 | Few-Shot Parallel: trains a model as a series of tasks computed in parallel on multiple GPUs
3 |
4 | """
5 | import numpy as np
6 | import os
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from src.tools.meters import BasicMeter
11 | from src.modules.distances import prototype_distance
12 | from embedding_propagation import EmbeddingPropagation, LabelPropagation
13 | from .base_wrapper import BaseWrapper
14 | from haven import haven_utils as haven
15 | from scipy.stats import sem, t
16 | import shutil as sh
17 | import sys
18 |
19 |
20 | class FinetuneWrapper(BaseWrapper):
21 | """Finetunes a model using an episodic scheme on multiple GPUs"""
22 |
23 | def __init__(self, model, nclasses, exp_dict):
24 | """ Constructor
25 | Args:
26 | model: architecture to train
27 | nclasses: number of output classes
28 | exp_dict: reference to dictionary with the hyperparameters
29 | """
30 | super().__init__()
31 | self.model = model
32 | self.exp_dict = exp_dict
33 | self.ngpu = self.exp_dict["ngpu"]
34 |
35 | self.embedding_propagation = EmbeddingPropagation()
36 | self.label_propagation = LabelPropagation()
37 | self.model.add_classifier(nclasses, modalities=0)
38 | self.nclasses = nclasses
39 |
40 | if self.exp_dict["rotation_weight"] > 0:
41 | self.model.add_classifier(4, "classifier_rot")
42 |
43 | best_accuracy = -1
44 | if self.exp_dict["pretrained_weights_root"] is not None:
45 | for exp_hash in os.listdir(self.exp_dict['pretrained_weights_root']):
46 | base_path = os.path.join(self.exp_dict['pretrained_weights_root'], exp_hash)
47 | exp_dict_path = os.path.join(base_path, 'exp_dict.json')
48 | if not os.path.exists(exp_dict_path):
49 | continue
50 | loaded_exp_dict = haven.load_json(exp_dict_path)
51 | pkl_path = os.path.join(base_path, 'score_list_best.pkl')
52 | if (loaded_exp_dict["model"]["name"] == 'pretraining' and
53 | loaded_exp_dict["dataset_train"].split('_')[-1] == exp_dict["dataset_train"].split('_')[-1] and
54 | loaded_exp_dict["model"]["backbone"] == exp_dict['model']["backbone"] and
55 | # loaded_exp_dict["labelprop_alpha"] == exp_dict["labelprop_alpha"] and
56 | # loaded_exp_dict["labelprop_scale"] == exp_dict["labelprop_scale"] and
57 | os.path.exists(pkl_path)):
58 | accuracy = haven.load_pkl(pkl_path)[-1]["val_accuracy"]
59 | try:
60 | self.model.load_state_dict(torch.load(os.path.join(base_path, 'checkpoint_best.pth'))['model'], strict=False)
61 | if accuracy > best_accuracy:
62 | best_path = os.path.join(base_path, 'checkpoint_best.pth')
63 | best_accuracy = accuracy
64 |                     except Exception:
65 | continue
66 | assert(best_accuracy > 0.1)
67 | print("Finetuning %s with original accuracy : %f" %(base_path, best_accuracy))
68 | self.model.load_state_dict(torch.load(best_path)['model'], strict=False)
69 |
70 | # Add optimizers here
71 | self.optimizer = torch.optim.SGD(self.model.parameters(),
72 | lr=self.exp_dict["lr"],
73 | momentum=0.9,
74 | weight_decay=self.exp_dict["weight_decay"],
75 | nesterov=True)
76 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
77 | mode="min" if "loss" in self.exp_dict["target_loss"] else "max",
78 | patience=self.exp_dict["patience"])
79 | self.model.cuda()
80 | if self.ngpu > 1:
81 | self.parallel_model = torch.nn.DataParallel(self.model, device_ids=list(range(self.ngpu)))
82 |
83 | def get_logits(self, embeddings, support_size, query_size, nclasses):
84 | """Computes the logits from the queries of an episode
85 |
86 | Args:
87 | embeddings (torch.Tensor): episode embeddings
88 | support_size (int): size of the support set
89 | query_size (int): size of the query set
90 | nclasses (int): number of classes
91 |
92 | Returns:
93 | torch.Tensor: logits
94 | """
95 | b, c = embeddings.size()
96 |
97 | propagator = None
98 | if self.exp_dict["embedding_prop"] == True:
99 | embeddings = self.embedding_propagation(embeddings)
100 |
101 | if self.exp_dict["distance_type"] == "labelprop":
102 | support_labels = torch.arange(nclasses, device=embeddings.device).view(1, nclasses).repeat(support_size, 1).view(support_size, nclasses)
103 | unlabeled_labels = nclasses * torch.ones(query_size * nclasses, dtype=support_labels.dtype, device=support_labels.device).view(query_size, nclasses)
104 | labels = torch.cat([support_labels, unlabeled_labels], 0).view(-1)
105 | logits = self.label_propagation(embeddings, labels, nclasses)
106 | logits = logits.view(-1, nclasses, nclasses)[support_size:(support_size + query_size), ...].view(-1, nclasses)
107 |
108 | elif self.exp_dict["distance_tpe"] == "prototypical":
109 | embeddings = embeddings.view(-1, nclasses, c)
110 | support_embeddings = embeddings[:support_size]
111 | query_embeddings = embeddings[support_size:]
112 | logits = prototype_distance((support_embeddings.view(1, support_size, nclasses, c), False),
113 | (query_embeddings.view(1, query_size, nclasses, c), False)).view(query_size * nclasses, nclasses)
114 | return logits
115 |
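    # Label layout sketch for the "labelprop" branch of get_logits above
    # (illustrative note): with support_size=2 and nclasses=3 the flattened
    # label vector passed to label_propagation is
    #   [0, 1, 2,  0, 1, 2,  3, 3, 3,  3, 3, 3, ...]
    # one row of class ids per support example, then the sentinel value
    # `nclasses` marking every query embedding as unlabeled.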
116 | def train_on_batch(self, batch):
117 | """Computes the loss on an episode
118 |
119 | Args:
120 | batch (dict): Episode dict
121 |
122 | Returns:
123 | tuple: loss and accuracy of the episode
124 | """
125 | episode = batch[0]
126 | nclasses = episode["nclasses"]
127 | support_size = episode["support_size"]
128 | query_size = episode["query_size"]
129 | labels = episode["targets"].view(support_size + query_size, nclasses, -1).cuda(non_blocking=True).long()
130 | k = (support_size + query_size)
131 | c = episode["channels"]
132 | h = episode["height"]
133 | w = episode["width"]
134 |
135 | tx = episode["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True)
136 | vx = episode["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True)
137 | x = torch.cat([tx, vx], 0)
138 | x = x.view(-1, c, h, w).cuda(non_blocking=True)
139 | if self.ngpu > 1:
140 | embeddings = self.parallel_model(x, is_support=True)
141 | else:
142 | embeddings = self.model(x, is_support=True)
143 | b, c = embeddings.size()
144 |
145 | logits = self.get_logits(embeddings, support_size, query_size, nclasses)
146 |
147 | loss = 0
148 | if self.exp_dict["classification_weight"] > 0:
149 | loss += F.cross_entropy(self.model.classifier(embeddings.view(b, c)), labels.view(-1)) * self.exp_dict["classification_weight"]
150 |
151 | query_labels = torch.arange(nclasses, device=logits.device).view(1, nclasses).repeat(query_size, 1).view(-1)
152 | loss += F.cross_entropy(logits, query_labels) * self.exp_dict["few_shot_weight"]
153 | return loss
154 |
155 | def predict_on_batch(self, batch):
156 | """Computes the logits of an episode
157 |
158 | Args:
159 | batch (dict): episode dict
160 |
161 | Returns:
162 | tensor: logits for the queries of the current episode
163 | """
164 | nclasses = batch["nclasses"]
165 | support_size = batch["support_size"]
166 | query_size = batch["query_size"]
167 | k = (support_size + query_size)
168 | c = batch["channels"]
169 | h = batch["height"]
170 | w = batch["width"]
171 |
172 | tx = batch["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True)
173 | vx = batch["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True)
174 | x = torch.cat([tx, vx], 0)
175 | x = x.view(-1, c, h, w).cuda(non_blocking=True)
176 |
177 | if self.ngpu > 1:
178 | embeddings = self.parallel_model(x, is_support=True)
179 | else:
180 | embeddings = self.model(x, is_support=True)
181 |
182 | return self.get_logits(embeddings, support_size, query_size, nclasses)
183 |
184 | def val_on_batch(self, batch):
185 | """Computes the loss and accuracy on a validation batch
186 |
187 | Args:
188 | batch (dict): Episode dict
189 |
190 | Returns:
191 | tuple: loss and accuracy of the episode
192 | """
193 | nclasses = batch["nclasses"]
194 | query_size = batch["query_size"]
195 |
196 | logits = self.predict_on_batch(batch)
197 |
198 | query_labels = torch.arange(nclasses, device=logits.device).view(1, nclasses).repeat(query_size, 1).view(-1)
199 | loss = F.cross_entropy(logits, query_labels)
200 | accuracy = float(logits.max(-1)[1].eq(query_labels).float().mean())
201 |
202 | return loss, accuracy
203 |
204 | def train_on_loader(self, data_loader, max_iter=None, debug_plot_path=None):
205 | """Iterate over the training set
206 |
207 | Args:
208 | data_loader: iterable training data loader
209 | max_iter: max number of iterations to perform if the end of the dataset is not reached
210 | """
211 | self.model.train()
212 | train_loss_meter = BasicMeter.get("train_loss").reset()
213 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
214 | self.optimizer.zero_grad()
215 | for batch_idx, batch in enumerate(data_loader):
216 | loss = self.train_on_batch(batch) / self.exp_dict["tasks_per_batch"]
217 | train_loss_meter.update(float(loss), 1)
218 | loss.backward()
219 | if ((batch_idx + 1) % self.exp_dict["tasks_per_batch"]) == 0:
220 | self.optimizer.step()
221 | self.optimizer.zero_grad()
222 | if batch_idx + 1 == max_iter:
223 | break
224 | return {"train_loss": train_loss_meter.mean()}
225 |
226 |
227 | @torch.no_grad()
228 | def val_on_loader(self, data_loader, max_iter=None):
229 | """Iterate over the validation set
230 |
231 | Args:
232 | data_loader: iterable validation data loader
233 | max_iter: max number of iterations to perform if the end of the dataset is not reached
234 | """
235 | self.model.eval()
236 | val_loss_meter = BasicMeter.get("val_loss").reset()
237 | val_accuracy_meter = BasicMeter.get("val_accuracy").reset()
238 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
239 | for batch_idx, _data in enumerate(data_loader):
240 | batch = _data[0]
241 | loss, accuracy = self.val_on_batch(batch)
242 | val_loss_meter.update(float(loss), 1)
243 | val_accuracy_meter.update(float(accuracy), 1)
244 | loss = BasicMeter.get(self.exp_dict["target_loss"], recursive=True, force=False).mean()
245 | self.scheduler.step(loss) # update the learning rate monitor
246 | return {"val_loss": val_loss_meter.mean(), "val_accuracy": val_accuracy_meter.mean()}
247 |
248 | @torch.no_grad()
249 | def test_on_loader(self, data_loader, max_iter=None):
250 | """Iterate over the validation set
251 |
252 | Args:
253 | data_loader: iterable validation data loader
254 | max_iter: max number of iterations to perform if the end of the dataset is not reached
255 | """
256 | self.model.eval()
257 | test_loss_meter = BasicMeter.get("test_loss").reset()
258 | test_accuracy_meter = BasicMeter.get("test_accuracy").reset()
259 | test_accuracy = []
260 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
261 | for batch_idx, _data in enumerate(data_loader):
262 | batch = _data[0]
263 | loss, accuracy = self.val_on_batch(batch)
264 | test_loss_meter.update(float(loss), 1)
265 | test_accuracy_meter.update(float(accuracy), 1)
266 | test_accuracy.append(float(accuracy))
267 | from scipy.stats import sem, t
268 | confidence = 0.95
269 | n = len(test_accuracy)
270 | std_err = sem(np.array(test_accuracy))
271 | h = std_err * t.ppf((1 + confidence) / 2, n - 1)
272 | return {"test_loss": test_loss_meter.mean(), "test_accuracy": test_accuracy_meter.mean(), "test_confidence": h}
273 |
274 | def get_state_dict(self):
275 | """Obtains the state dict of this model including optimizer, scheduler, etc
276 |
277 | Returns:
278 | dict: state dict
279 | """
280 | ret = {}
281 | ret["optimizer"] = self.optimizer.state_dict()
282 | ret["model"] = self.model.state_dict()
283 | ret["scheduler"] = self.scheduler.state_dict()
284 | return ret
285 |
286 | def load_state_dict(self, state_dict):
287 | """Loads the state of the model
288 |
289 | Args:
290 | state_dict (dict): The state to load
291 | """
292 | self.optimizer.load_state_dict(state_dict["optimizer"])
293 | self.model.load_state_dict(state_dict["model"])
294 | self.scheduler.load_state_dict(state_dict["scheduler"])
295 |
296 | def get_lr(self):
297 | ret = {}
298 | for i, param_group in enumerate(self.optimizer.param_groups):
299 | ret["current_lr_%d" % i] = float(param_group["lr"])
300 | return ret
301 |
302 | def is_end_of_training(self):
303 | lr = self.get_lr()["current_lr_0"]
304 | return lr <= (self.exp_dict["lr"] * self.exp_dict["min_lr_decay"])
--------------------------------------------------------------------------------
/src/models/pretraining.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn.functional as F
5 | from src.tools.meters import BasicMeter
6 | from embedding_propagation import EmbeddingPropagation, LabelPropagation
7 | from .base_wrapper import BaseWrapper
8 | from src.modules.distances import prototype_distance
9 |
10 | class PretrainWrapper(BaseWrapper):
11 | """Trains a model using an episodic scheme on multiple GPUs"""
12 |
13 | def __init__(self, model, nclasses, exp_dict):
14 | """ Constructor
15 | Args:
16 | model: architecture to train
17 | nclasses: number of output classes
18 | exp_dict: reference to dictionary with the hyperparameters
19 | """
20 | super().__init__()
21 | self.model = model
22 | self.exp_dict = exp_dict
23 | self.ngpu = self.exp_dict["ngpu"]
24 | self.embedding_propagation = EmbeddingPropagation()
25 | self.label_propagation = LabelPropagation()
26 | self.model.add_classifier(nclasses, modalities=0)
27 | self.nclasses = nclasses
28 |
29 |
30 | if self.exp_dict["rotation_weight"] > 0:
31 | self.model.add_classifier(4, "classifier_rot")
32 |
33 | # Add optimizers here
34 | self.optimizer = torch.optim.SGD(self.model.parameters(),
35 | lr=self.exp_dict["lr"],
36 | momentum=0.9,
37 | weight_decay=self.exp_dict["weight_decay"],
38 | nesterov=True)
39 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
40 | mode="min" if "loss" in self.exp_dict["target_loss"] else "max",
41 | patience=self.exp_dict["patience"])
42 | self.model.cuda()
43 | if self.ngpu > 1:
44 | self.parallel_model = torch.nn.DataParallel(self.model, device_ids=list(range(self.ngpu)))
45 |
46 | def get_logits(self, embeddings, support_size, query_size, nclasses):
47 | """Computes the logits from the queries of an episode
48 |
49 | Args:
50 | embeddings (torch.Tensor): episode embeddings
51 | support_size (int): size of the support set
52 | query_size (int): size of the query set
53 | nclasses (int): number of classes
54 |
55 | Returns:
56 | torch.Tensor: logits
57 | """
58 | b, c = embeddings.size()
59 |
60 | if self.exp_dict["embedding_prop"] == True:
61 | embeddings = self.embedding_propagation(embeddings)
62 | if self.exp_dict["distance_type"] == "labelprop":
63 | support_labels = torch.arange(nclasses, device=embeddings.device).view(1, nclasses).repeat(support_size, 1).view(support_size, nclasses)
64 | unlabeled_labels = nclasses * torch.ones(query_size * nclasses, dtype=support_labels.dtype, device=support_labels.device).view(query_size, nclasses)
65 | labels = torch.cat([support_labels, unlabeled_labels], 0).view(-1)
66 | logits = self.label_propagation(embeddings, labels, nclasses)
67 | logits = logits.view(-1, nclasses, nclasses)[support_size:(support_size + query_size), ...].view(-1, nclasses)
68 |
69 | elif self.exp_dict["distance_type"] == "prototypical":
70 | embeddings = embeddings.view(-1, nclasses, c)
71 | support_embeddings = embeddings[:support_size]
72 | query_embeddings = embeddings[support_size:]
73 | logits = prototype_distance(support_embeddings.view(-1, c),
74 | query_embeddings.view(-1, c),
75 | support_labels.view(-1))
76 | return logits
77 |
78 | def train_on_batch(self, batch):
79 | """Computes the loss of a batch
80 |
81 | Args:
82 | batch (tuple): Inputs and labels
83 |
84 | Returns:
85 | loss: Loss on the batch
86 | """
87 | x, y, r = batch
88 | y = y.cuda(non_blocking=True).view(-1)
89 | r = r.cuda(non_blocking=True).view(-1)
90 | k, n, c, h, w = x.size()
91 | x = x.view(n*k, c, h, w).cuda(non_blocking=True)
92 | if self.ngpu > 1:
93 | embeddings = self.parallel_model(x, is_support=True)
94 | else:
95 | embeddings = self.model(x, is_support=True)
96 | b, c = embeddings.size()
97 |
98 | loss = 0
99 | if self.exp_dict["rotation_weight"] > 0:
100 | rot = self.model.classifier_rot(embeddings)
101 | loss += F.cross_entropy(rot, r) * self.exp_dict["rotation_weight"]
102 |
103 | if self.exp_dict["embedding_prop"] == True:
104 | embeddings = self.embedding_propagation(embeddings)
105 | logits = self.model.classifier(embeddings)
106 | loss += F.cross_entropy(logits, y) * self.exp_dict["cross_entropy_weight"]
107 | return loss
108 |
109 | def val_on_batch(self, batch):
110 | """Computes the loss and accuracy on a validation batch
111 |
112 | Args:
113 | batch (dict): Episode dict
114 |
115 | Returns:
116 | tuple: loss and accuracy of the episode
117 | """
118 | nclasses = batch["nclasses"]
119 | query_size = batch["query_size"]
120 |
121 | logits = self.predict_on_batch(batch)
122 |
123 | query_labels = torch.arange(nclasses, device=logits.device).view(1, nclasses).repeat(query_size, 1).view(-1)
124 | loss = F.cross_entropy(logits, query_labels)
125 | accuracy = float(logits.max(-1)[1].eq(query_labels).float().mean())
126 |
127 | return loss, accuracy
128 |
129 | def predict_on_batch(self, batch):
130 | """Computes the logits of an episode
131 |
132 | Args:
133 | batch (dict): episode dict
134 |
135 | Returns:
136 | tensor: logits for the queries of the current episode
137 | """
138 | nclasses = batch["nclasses"]
139 | support_size = batch["support_size"]
140 | query_size = batch["query_size"]
141 | k = (support_size + query_size)
142 | c = batch["channels"]
143 | h = batch["height"]
144 | w = batch["width"]
145 |
146 | tx = batch["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True)
147 | vx = batch["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True)
148 | x = torch.cat([tx, vx], 0)
149 | x = x.view(-1, c, h, w).cuda(non_blocking=True)
150 |
151 | if self.ngpu > 1:
152 | embeddings = self.parallel_model(x, is_support=True)
153 | else:
154 | embeddings = self.model(x, is_support=True)
155 | b, c = embeddings.size()
156 |
157 | return self.get_logits(embeddings, support_size, query_size, nclasses)
158 |
159 | def train_on_loader(self, data_loader, max_iter=None, debug_plot_path=None):
160 | """Iterate over the training set
161 |
162 | Args:
163 | data_loader (torch.utils.data.DataLoader): a pytorch dataloader
164 | max_iter (int, optional): Max number of iterations if the end of the dataset is not reached. Defaults to None.
165 |
166 | Returns:
167 | metrics: dictionary with metrics of the training set
168 | """
169 | self.model.train()
170 | train_loss_meter = BasicMeter.get("train_loss").reset()
171 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
172 | for batch_idx, batch in enumerate(data_loader):
173 | self.optimizer.zero_grad()
174 | loss = self.train_on_batch(batch)
175 | train_loss_meter.update(float(loss), 1)
176 | loss.backward()
177 | self.optimizer.step()
178 | if batch_idx + 1 == max_iter:
179 | break
180 | return {"train_loss": train_loss_meter.mean()}
181 |
182 |
183 | @torch.no_grad()
184 | def val_on_loader(self, data_loader, max_iter=None):
185 | """Iterate over the validation set
186 |
187 | Args:
188 | data_loader: iterable validation data loader
189 | max_iter: max number of iterations to perform if the end of the dataset is not reached
190 | """
191 | self.model.eval()
192 | val_loss_meter = BasicMeter.get("val_loss").reset()
193 | val_accuracy_meter = BasicMeter.get("val_accuracy").reset()
194 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
195 | for batch_idx, _data in enumerate(data_loader):
196 | batch = _data[0]
197 | loss, accuracy = self.val_on_batch(batch)
198 | val_loss_meter.update(float(loss), 1)
199 | val_accuracy_meter.update(float(accuracy), 1)
200 | loss = BasicMeter.get(self.exp_dict["target_loss"], recursive=True, force=False).mean()
201 | self.scheduler.step(loss) # update the learning rate monitor
202 | return {"val_loss": val_loss_meter.mean(), "val_accuracy": val_accuracy_meter.mean()}
203 |
204 | @torch.no_grad()
205 | def test_on_loader(self, data_loader, max_iter=None):
206 | """Iterate over the validation set
207 |
208 | Args:
209 | data_loader: iterable validation data loader
210 | max_iter: max number of iterations to perform if the end of the dataset is not reached
211 | """
212 | self.model.eval()
213 | test_loss_meter = BasicMeter.get("test_loss").reset()
214 | test_accuracy_meter = BasicMeter.get("test_accuracy").reset()
215 | # Iterate through tasks, each iteration loads n tasks, with n = number of GPU
216 | for batch_idx, _data in enumerate(data_loader):
217 | batch = _data[0]
218 | loss, accuracy = self.val_on_batch(batch)
219 | test_loss_meter.update(float(loss), 1)
220 | test_accuracy_meter.update(float(accuracy), 1)
221 | return {"test_loss": test_loss_meter.mean(), "test_accuracy": test_accuracy_meter.mean()}
222 |
223 | def get_state_dict(self):
224 | """Obtains the state dict of this model including optimizer, scheduler, etc
225 |
226 | Returns:
227 | dict: state dict
228 | """
229 | ret = {}
230 | ret["optimizer"] = self.optimizer.state_dict()
231 | ret["model"] = self.model.state_dict()
232 | ret["scheduler"] = self.scheduler.state_dict()
233 | return ret
234 |
235 | def load_state_dict(self, state_dict):
236 | """Loads the state of the model
237 |
238 | Args:
239 | state_dict (dict): The state to load
240 | """
241 | self.optimizer.load_state_dict(state_dict["optimizer"])
242 | self.model.load_state_dict(state_dict["model"])
243 | self.scheduler.load_state_dict(state_dict["scheduler"])
244 |
245 | def get_lr(self):
246 | ret = {}
247 | for i, param_group in enumerate(self.optimizer.param_groups):
248 | ret["current_lr_%d" % i] = float(param_group["lr"])
249 | return ret
250 |
251 | def is_end_of_training(self):
252 | lr = self.get_lr()["current_lr_0"]
253 | return lr <= (self.exp_dict["lr"] * self.exp_dict["min_lr_decay"])
--------------------------------------------------------------------------------
/src/models/ssl_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | Few-Shot Parallel: trains a model as a series of tasks computed in parallel on multiple GPUs
3 |
4 | """
5 | import copy
6 | import numpy as np
7 | import os
8 | from .base_ssl import oracle
9 | from scipy.stats import sem, t
10 | import torch
11 | import pandas as pd
12 | import torch.nn.functional as F
13 | import tqdm
14 | from src.tools.meters import BasicMeter
15 | from src.modules.distances import standarized_label_prop, _propagate, prototype_distance
16 | from .base_wrapper import BaseWrapper
17 | from haven import haven_utils as hu
18 | import glob
19 | from scipy.stats import sem, t
20 | import shutil as sh
21 | from .base_ssl import selection_methods as sm
22 | from .base_ssl import predict_methods as pm
23 | from embedding_propagation import EmbeddingPropagation
24 |
25 | class SSLWrapper(BaseWrapper):
26 | """Trains a model using an episodic scheme on multiple GPUs"""
27 |
28 | def __init__(self, model, n_classes, exp_dict, pretrained_savedir=None, savedir_base=None):
29 | """ Constructor
30 | Args:
31 | model: architecture to train
32 | exp_dict: reference to dictionary with the global state of the application
33 | """
34 | super().__init__()
35 | self.model = model
36 | self.exp_dict = exp_dict
37 | self.ngpu = self.exp_dict["ngpu"]
38 | self.predict_method = exp_dict['predict_method']
39 |
40 | self.model.add_classifier(n_classes, modalities=0)
41 | self.nclasses = n_classes
42 |
43 | best_accuracy = -1
44 | self.label = exp_dict['model']['backbone'] + "_" + exp_dict['dataset_test'].split('_')[1].replace('-imagenet','')
45 | print('=============')
46 | print('dataset:', exp_dict["dataset_train"].split('_')[-1])
47 | print('backbone:', exp_dict['model']["backbone"])
48 | print('n_classes:', exp_dict['n_classes'])
49 | print('support_size_train:', exp_dict['support_size_train'])
50 |
51 | if pretrained_savedir is None:
52 | # find the best checkpoint
53 | savedir_base = exp_dict["finetuned_weights_root"]
54 | if not os.path.exists(savedir_base):
55 | raise ValueError("Please set the variable named "
56 | "'finetuned_weights_root' to the path of the folder "
57 | "with the episodic finetuning experiments")
58 | for exp_hash in os.listdir(savedir_base):
59 | base_path = os.path.join(savedir_base, exp_hash)
60 | exp_dict_path = os.path.join(base_path, 'exp_dict.json')
61 | if not os.path.exists(exp_dict_path):
62 | continue
63 | loaded_exp_dict = hu.load_json(exp_dict_path)
64 | pkl_path = os.path.join(base_path, 'score_list_best.pkl')
65 |
66 | if exp_dict['support_size_train'] in [2,3,4]:
67 | support_size_needed = 1
68 | else:
69 | support_size_needed = exp_dict['support_size_train']
70 |
71 | if (loaded_exp_dict["model"]["name"] == 'finetuning' and
72 | loaded_exp_dict["dataset_train"].split('_')[-1] == exp_dict["dataset_train"].split('_')[-1] and
73 | loaded_exp_dict["model"]["backbone"] == exp_dict['model']["backbone"] and
74 | loaded_exp_dict['n_classes'] == exp_dict["n_classes"] and
75 | loaded_exp_dict['support_size_train'] == support_size_needed and
76 | loaded_exp_dict["embedding_prop"] == exp_dict["embedding_prop"]):
77 |
78 | model_path = os.path.join(base_path, 'checkpoint_best.pth')
79 |
80 | try:
81 | print("Attempting to load ", model_path)
82 | accuracy = hu.load_pkl(pkl_path)[-1]["val_accuracy"]
83 | self.model.load_state_dict(torch.load(model_path)['model'], strict=False)
84 | if accuracy > best_accuracy:
85 | best_path = os.path.join(base_path, 'checkpoint_best.pth')
86 | best_accuracy = accuracy
87 | except Exception as e:
88 | print(e)
89 |
90 | assert(best_accuracy > 0.1)
91 | print("Finetuning %s with original accuracy : %f" %(base_path, best_accuracy))
92 | self.model.load_state_dict(torch.load(best_path)['model'], strict=False)
93 | self.best_accuracy = best_accuracy
94 | self.acc_sum = 0.0
95 | self.n_count = 0
96 | self.model.cuda()
97 |
98 | def get_embeddings(self, embeddings, support_size, query_size, nclasses):
99 | b, c = embeddings.size()
100 |
101 | if self.exp_dict["embedding_prop"]:
102 | embeddings = EmbeddingPropagation()(embeddings)
103 | return embeddings.view(b, c)
104 |
105 | def get_episode_dict(self, batch):
106 | nclasses = batch["nclasses"]
107 | support_size = batch["support_size"]
108 | query_size = batch["query_size"]
109 | k = (support_size + query_size)
110 | c = batch["channels"]
111 | h = batch["height"]
112 | w = batch["width"]
113 |
114 | tx = batch["support_set"].view(support_size, nclasses, c, h, w).cuda(non_blocking=True)
115 | vx = batch["query_set"].view(query_size, nclasses, c, h, w).cuda(non_blocking=True)
116 | ux = batch["unlabeled_set"].view(batch["unlabeled_size"], nclasses, c, h, w).cuda(non_blocking=True)
117 | x = torch.cat([tx, vx, ux], 0)
118 | x = x.view(-1, c, h, w).cuda(non_blocking=True)
119 |
120 | if self.ngpu > 1:
121 | features = self.parallel_model(x, is_support=True)
122 | else:
123 | features = self.model(x, is_support=True)
124 |
125 | embeddings = self.get_embeddings(features,
126 | support_size,
127 | query_size+
128 | batch['unlabeled_size'],
129 | nclasses) # (b, channels)
130 |
131 | uniques = np.unique(batch['targets'])
132 | labels = torch.zeros(batch['targets'].shape[0])
133 | for i, u in enumerate(uniques):
134 | labels[batch['targets']==u] = i
135 |
136 | ## perform ssl
137 | # 1. indices
138 | episode_dict = {}
139 | ns = support_size*nclasses
140 | nq = query_size*nclasses
141 | episode_dict["support"] = {'samples':embeddings[:ns],
142 | 'labels':labels[:ns]}
143 | episode_dict["query"] = {'samples':embeddings[ns:ns+nq],
144 | 'labels':labels[ns:ns+nq]}
145 | episode_dict["unlabeled"] = {'samples':embeddings[ns+nq:]}
146 | # batch["support_so_far"] = {'samples':embeddings,
147 | # 'labels':labels}
148 |
149 |
150 | for k, v in episode_dict.items():
151 | episode_dict[k]['samples'] = episode_dict[k]['samples'].cpu().numpy()
152 | if 'labels' in episode_dict[k]:
153 | episode_dict[k]['labels'] = episode_dict[k]['labels'].cpu().numpy().astype(int)
154 | return episode_dict
155 |
156 | def predict_on_batch(self, episode_dict, support_size_max=None):
157 | ind_selected = sm.get_indices(selection_method="ssl",
158 | episode_dict=episode_dict,
159 | support_size_max=support_size_max)
160 | episode_dict = update_episode_dict(ind_selected, episode_dict)
161 | pred_labels = pm.get_predictions(predict_method=self.predict_method,
162 | episode_dict=episode_dict)
163 |
164 | return pred_labels
165 |
166 | def val_on_batch(self, batch):
167 |
168 | if self.exp_dict.get("pretrained_weights_root") == 'hdf5':
169 | episode_dict = self.sampler.sample_episode(int(self.exp_dict['support_size_test']),
170 | self.exp_dict['query_size_test'],
171 | self.exp_dict['unlabeled_size_test'],
172 | apply_ten_flag=self.exp_dict.get("apply_ten_flag"))
173 | else:
174 | episode_dict = self.get_episode_dict(batch)
175 | episode_dict["support_so_far"] = copy.deepcopy(episode_dict["support"])
176 | episode_dict["n_classes"] = 5
177 |
178 | pred_labels = self.predict_on_batch(episode_dict, support_size_max=self.exp_dict['unlabeled_size_test']*self.exp_dict['classes_test'])
179 | accuracy = oracle.compute_acc(pred_labels=pred_labels,
180 | true_labels=episode_dict["query"]["labels"])
181 |
182 | # query_labels = episode_dict["query"]["labels"]
183 | # accuracy = float((pred_labels == query_labels.cuda()).float().mean())
184 |
185 | self.acc_sum += accuracy
186 | self.n_count += 1
187 | return -1, accuracy
188 |
189 | @torch.no_grad()
190 | def test_on_loader(self, data_loader, max_iter=None):
191 | """Iterate over the validation set
192 |
193 | Args:
194 | data_loader: iterable validation data loader
195 | max_iter: max number of iterations to perform if the end of the dataset is not reached
196 | """
197 | self.model.eval()
198 |
199 | test_accuracy_meter = BasicMeter.get("test_accuracy").reset()
200 | test_accuracy = []
201 | # Iterate through tasks; each iteration loads n tasks, where n = number of GPUs
202 | # dirname = os.path.split(self.exp_dict["pretrained_weights_root"])[-1]
203 | with tqdm.tqdm(total=len(data_loader)) as pbar:
204 | for batch_all in data_loader:
205 | batch = batch_all[0]
206 | loss, accuracy = self.val_on_batch(batch)
207 |
208 | test_accuracy_meter.update(float(accuracy), 1)
209 | test_accuracy.append(float(accuracy))
210 |
211 | string = ("'%s' - ssl: %.3f" %
212 | (self.label,
213 | # dirname,
214 | test_accuracy_meter.mean()))
215 | # print(string)
216 | pbar.update(1)
217 | pbar.set_description(string)
218 |
219 | confidence = 0.95
220 | n = len(test_accuracy)
221 | std_err = sem(np.array(test_accuracy))
222 | h = std_err * t.ppf((1 + confidence) / 2, n - 1)  # half-width of the 95% confidence interval
223 | return {"test_loss": -1,
224 | "ssl_accuracy": test_accuracy_meter.mean(),
225 | "ssl_confidence": h,
226 | 'finetuned_accuracy': self.best_accuracy}
227 |
228 | def update_episode_dict(ind, episode_dict):
229 | # 1. update supports so far
230 | selected_samples = episode_dict["unlabeled"]["samples"][ind]
231 | selected_labels = episode_dict["unlabeled"]["labels"][ind]
232 |
233 | selected_support_dict = {"samples": selected_samples, "labels": selected_labels}
234 |
235 | for k, v in episode_dict["support_so_far"].items():
236 | episode_dict["support_so_far"][k] = np.concatenate([v, selected_support_dict[k]], axis=0)
237 |
238 | # 2. update unlabeled samples
239 | n_unlabeled = episode_dict["unlabeled"]["samples"].shape[0]
240 | ind_rest = np.setdiff1d(np.arange(n_unlabeled), ind)
241 |
242 | new_unlabeled_dict = {}
243 | for k, v in episode_dict["unlabeled"].items():
244 | new_unlabeled_dict[k] = v[ind_rest]
245 |
246 | episode_dict["unlabeled"] = new_unlabeled_dict
247 |
248 | return episode_dict
--------------------------------------------------------------------------------
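The SSL loop in predict_on_batch first selects unlabeled samples (sm.get_indices) and then moves them into the running support set through update_episode_dict before predicting. A minimal sketch of update_episode_dict on toy numpy data (the import path assumes this repository layout; the array contents are arbitrary):

    import numpy as np
    from src.models.ssl_wrapper import update_episode_dict

    episode_dict = {
        "support_so_far": {"samples": np.zeros((5, 4)), "labels": np.arange(5)},
        "unlabeled": {"samples": np.ones((6, 4)), "labels": np.repeat(np.arange(3), 2)},
    }
    episode_dict = update_episode_dict(np.array([0, 3]), episode_dict)
    print(episode_dict["support_so_far"]["samples"].shape)  # (7, 4): 5 original + 2 selected
    print(episode_dict["unlabeled"]["samples"].shape)       # (4, 4): 6 - 2 remaining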
/src/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/modules/__init__.py
--------------------------------------------------------------------------------
/src/modules/distances.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | def _make_aligned_labels(inputs):
5 | batch, n_sample_pc, n_classes, z_dim = inputs.shape
6 | identity = torch.eye(n_classes, dtype=inputs.dtype, device=inputs.device)
7 | return identity[None, None, :, :].expand(batch, n_sample_pc, -1, -1).contiguous()
8 | def generalized_pw_sq_dist(data, d_type="euclidean"):
9 | batch, samples, z_dim = data.size()
10 | if d_type == "euclidean":
11 | return torch.sum((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim)
12 | elif d_type == "l1":
13 | return torch.mean(torch.abs(data[:, :, None, :] - data[:, None, :, :]), dim=3)
14 | elif d_type == "stable_euclidean":
15 | return torch.sqrt(1e-6 + torch.mean((data[:, :, None, :] - data[:, None, :, :]) ** 2, dim=3) / np.sqrt(z_dim))
16 | elif d_type == "cosine":
17 | data = F.normalize(data, dim=2)
18 | return torch.bmm(data, data.transpose(2, 1))
19 | else:
20 | raise ValueError("Distance type not recognized")
21 | def standarized_label_prop(embeddings,
22 | labels,
23 | gaussian_scale=1, alpha=0.5,
24 | weights=None,
25 | apply_log=False,
26 | scale_bound="",
27 | standarize="all",
28 | kernel="rbf",
29 | square_root=False,
30 | norm_prop=0,
31 | epsilon=1e-6):
32 | propagator_scale = gaussian_scale
33 | gaussian_scale = 1
34 | if scale_bound == "softplus":
35 | gaussian_scale = 0.01 + F.softplus(gaussian_scale)
36 | alpha = 0.1 + F.softplus(alpha)
37 | elif scale_bound == "square":
38 | gaussian_scale = 1e-4 + gaussian_scale ** 2
39 | alpha = 0.1 + alpha ** 2
40 | elif scale_bound == "convex_relu":
41 | #gaussian_scale = gaussian_scale ** 2
42 | alpha = F.relu(alpha) + 0.1
43 | elif scale_bound == "convex_square":
44 | # gaussian_scale = gaussian_scale ** 2
45 | alpha = 0.1 + alpha ** 2
46 | elif scale_bound == "relu":
47 | gaussian_scale = F.relu(gaussian_scale) + 0.01
48 | alpha = F.relu(alpha) + 0.1
49 | elif scale_bound == "constant":
50 | gaussian_scale = 1
51 | alpha = 1
52 | elif scale_bound == "alpha_square":
53 | alpha = 0.1 + F.relu(alpha)
54 | # Compute the pairwise distance between the examples of the sample and query sets
55 | # XXX: labels are set to a constant for the query set
56 | sq_dist = generalized_pw_sq_dist(embeddings, "euclidean")
57 | if square_root:
58 | sq_dist = (sq_dist + epsilon).sqrt()
59 | if standarize == "all":
60 | mask = sq_dist != 0
61 | # sq_dist = sq_dist - sq_dist[mask].mean()
62 | sq_dist = sq_dist / sq_dist[mask].std()
63 | elif standarize == "median":
64 | mask = sq_dist != 0
65 | gaussian_scale = torch.sqrt(
66 | 0.5 * torch.median(sq_dist[mask]) / torch.log(torch.ones(1, device=sq_dist.device) + sq_dist.size(1)))
67 | elif standarize == "frobenius":
68 | mask = sq_dist != 0
69 | sq_dist = sq_dist / (sq_dist[mask] ** 2).sum().sqrt()
70 | elif standarize == "percentile":
71 | mask = sq_dist != 0  # exclude self-distances, consistent with the other standarize branches
72 | sorted, indices = torch.sort(sq_dist.data[mask])
73 | total = sorted.size(0)
74 | gaussian_scale = sorted[int(total * 0.1)].detach()
75 | if kernel == "rbf":
76 | weights = torch.exp(-sq_dist * gaussian_scale)
77 | elif kernel == "convex_rbf":
78 | scales = torch.linspace(0.1, 10, gaussian_scale.size(0), device=sq_dist.device, dtype=sq_dist.dtype)
79 | weights = torch.exp(-sq_dist.unsqueeze(1) * scales.view(1, -1, 1, 1))
80 | weights = (weights * F.softmax(gaussian_scale.view(1, -1, 1, 1), dim=1)).sum(1)
81 | # checknan(timessoftmax=weights)
82 | elif kernel == "euclidean":
83 | # Compute similarity between the examples -- inversely proportional to distance
84 | weights = 1 / (gaussian_scale + sq_dist)
85 | elif kernel == "softmax":
86 | weights = F.softmax(-sq_dist / gaussian_scale, -1)
87 | mask = torch.eye(weights.size(1), dtype=torch.bool, device=weights.device)[None, :, :]
88 | weights = weights * (~mask).float()
89 | logits, propagator = global_consistency(weights, labels, alpha=alpha, norm_prop=norm_prop, scale=propagator_scale)
90 | if apply_log:
91 | logits = torch.log(logits + epsilon)
92 | return logits, propagator
93 | def global_consistency(weights, labels, alpha=1, norm_prop=0, scale=1):
94 | """Implements D. Zhou et al. "Learning with local and global consistency". (Same as in TPN paper but without bug)
95 | Args:
96 | weights: Tensor of shape (batch, n, n). Expected to be exp( -d^2/s^2 ), where d is the euclidean distance and
97 | s the scale parameter.
98 | labels: Tensor of shape (batch, n, n_classes)
99 | alpha: Scalar, acts as a smoothing factor
100 | Returns:
101 | Tensor of shape (batch, n, n_classes) representing the logits of each class
102 | """
103 | n = weights.shape[1]
104 | identity = torch.eye(n, dtype=weights.dtype, device=weights.device)[None, :, :]
105 | #weights = weights * (1. - identity) # zero out diagonal
106 | isqrt_diag = 1. / torch.sqrt(1e-4 + torch.sum(weights, dim=2))
107 | # checknan(laplacian=isqrt_diag)
108 | S = weights * isqrt_diag[:, None, :] * isqrt_diag[:, :, None]
109 | # checknan(normalizedlaplacian=S)
110 | propagator = identity - alpha * S
111 | propagator = torch.inverse(propagator)
112 | # checknan(propagator=propagator)
113 | if norm_prop > 0:
114 | propagator = F.normalize(propagator, p=norm_prop, dim=-1)
115 | elif norm_prop < 0:
116 | propagator = F.softmax(propagator, dim=-1)
117 | propagator = propagator * scale
118 | return _propagate(labels, propagator, scaling=1), propagator
119 | def _propagate(labels, propagator, scaling=1.):
120 | return torch.matmul(propagator, labels) * scaling
121 | def prototype_distance(support_set, query_set, labels, unlabeled_set=None):
122 | """Computes the distance from each element of the query set to the class prototypes,
123 | where each prototype is the mean of the support examples of its class.
124 | Args:
125 | support_set: Tensor of shape (n_support, z_dim) containing the representation z of each support image.
126 | query_set: Tensor of shape (n_queries, z_dim) containing the representation z of each query image.
127 | labels: Tensor of shape (n_support,) containing the class index of each support image.
128 | unlabeled_set: unused; kept for interface compatibility.
129 | Returns:
130 | Tensor of shape (n_queries, n_classes) containing the negative squared euclidean
131 | distance between each query and each class prototype, scaled by 1/sqrt(z_dim),
132 | i.e. higher values mean higher similarity.
133 | """
134 | n_queries, channels = query_set.size()
135 | n_support, channels = support_set.size()
136 | support_set = support_set.view(n_support, 1, channels)
137 | way = int(labels.data.max()) + 1
138 | one_hot_labels = torch.zeros(n_support, way, 1, dtype=support_set.dtype, device=support_set.device)
139 | one_hot_labels.scatter_(1, labels.view(n_support, 1, 1), 1)
140 | total_per_class = one_hot_labels.sum(0, keepdim=True)
141 | prototypes = (support_set * one_hot_labels).sum(0) / total_per_class
142 | prototypes = prototypes.view(1, way, channels)
143 | query_set = query_set.view(n_queries, 1, channels)
144 | d = query_set - prototypes
145 | return -torch.sum(d ** 2, 2) / np.sqrt(channels)
--------------------------------------------------------------------------------
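For intuition, global_consistency above is the closed-form label propagation of Zhou et al.: F* = (I - alpha*S)^(-1) Y, where S is the symmetrically normalized affinity matrix and Y holds one-hot labels with all-zero rows for unlabeled points. A minimal sketch on a toy 1-D dataset (the import path assumes this repository layout):

    import torch
    from src.modules.distances import global_consistency

    # Four labeled points forming two clusters, plus one unlabeled point near cluster 0.
    x = torch.tensor([[0.0], [0.1], [1.0], [1.1], [0.05]])    # (n, z_dim)
    sq_dist = (x[:, None, :] - x[None, :, :]).pow(2).sum(-1)  # (n, n)
    weights = torch.exp(-sq_dist)[None]                       # (1, n, n) rbf affinities
    labels = torch.tensor([[1., 0.], [1., 0.],
                           [0., 1.], [0., 1.],
                           [0., 0.]])[None]                   # last row: unlabeled
    logits, propagator = global_consistency(weights, labels, alpha=0.5)
    print(logits[0, -1])  # class-0 score dominates for the unlabeled point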
/src/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ServiceNow/embedding-propagation/c51e7ac591459052b9c56b1fe1c8d450b3d90b3d/src/tools/__init__.py
--------------------------------------------------------------------------------
/src/tools/meters.py:
--------------------------------------------------------------------------------
1 | class BasicMeter(object):
2 | """
3 | Basic class to monitor scores
4 | """
5 | meters = {}
6 | submeters = {}
7 |
8 | @staticmethod
9 | def get(name, recursive=False, tag=None, force=True):
10 | """ Creates a new meter or returns an already existing one with the given name.
11 |
12 | Args:
13 | name: meter name
14 | recursive: also match "<meter>_<submeter>" submeter names
15 | tag: optional submeter tag; force: create the meter if missing, else raise
16 | Returns: BasicMeter instance
17 | """
18 |
19 | if name not in BasicMeter.meters:
20 | if recursive:
21 | for supername, meter in BasicMeter.meters.items():
22 | for subname in meter.submeters:
23 | if "%s_%s" %(supername, subname) == name:
24 | return meter.submeters[subname]
25 | if force:
26 | BasicMeter.meters[name] = BasicMeter(name)
27 | else:
28 | raise ModuleNotFoundError
29 |
30 | if tag is not None:
31 | if force:
32 | return BasicMeter.meters[name].get_submeter(tag)
33 | else:
34 | raise ModuleNotFoundError
35 |
36 | return BasicMeter.meters[name]
37 |
38 |
39 |
40 |
41 | @staticmethod
42 | def dict():
43 | """ Obtain meters in a dictionary
44 |
45 | Returns: dictionary of BasicMeter
46 |
47 | """
48 | return BasicMeter.meters
49 |
50 | def __init__(self, name=""):
51 | """
52 | Constructor
53 | """
54 | self.count = 0.
55 | self.total = 0.
56 | self.name = name
57 | self.submeters = {}
58 |
59 | def get_submeter(self, name):
60 | if name not in self.submeters:
61 | name_ = "%s_%s" %(self.name, name)
62 | self.submeters[name] = BasicMeter(name_)
63 | return self.submeters[name]
64 |
65 | def update(self, v, count, tag=None):
66 | """ Update meter values
67 |
68 | Args:
69 | v: current value
70 | count: N if value is the average of N values.
71 | tag: optional tag (or list of tags) of submeters that also receive this update
72 | Returns: self
73 |
74 | """
75 | self.count += count
76 | self.total += v
77 |
78 | if tag is not None:
79 | if not isinstance(tag, list):
80 | tag = [tag]
81 | for t in tag:
82 | self.get_submeter(t).update(v, count)
83 | return self
84 |
85 | def mean(self, tag=None, recursive=False):
86 | """ Computes the mean of the current values
87 |
88 | Returns: mean of the current values (float)
89 |
90 | """
91 | if recursive:
92 | try:
93 | ret = { self.name: self.total / self.count }
94 | except ZeroDivisionError:
95 | return { self.name: 0 }
96 | for submeter in self.submeters:
97 | ret.update(self.get_submeter(submeter).mean(recursive=True))
98 | return ret
99 | if tag is not None:
100 | return self.get_submeter(tag).mean()
101 | else:
102 | return self.total / self.count
103 |
104 | def reset(self):
105 | """ Resets the meter.
106 |
107 | Returns: self
108 |
109 | """
110 | for submeter in self.submeters:
111 | self.submeters[submeter].reset()
112 | self.count = 0
113 | self.total = 0
114 |
115 | return self
116 |
117 |
--------------------------------------------------------------------------------
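A minimal usage sketch of BasicMeter and its tagged submeters (the meter and tag names below are arbitrary):

    from src.tools.meters import BasicMeter

    meter = BasicMeter.get("train_loss").reset()
    meter.update(0.9, 1, tag="episode_0")
    meter.update(0.7, 1, tag="episode_1")
    print(meter.mean())                                          # 0.8
    print(BasicMeter.get("train_loss", tag="episode_0").mean())  # 0.9
    print(meter.mean(recursive=True))  # {'train_loss': 0.8, 'train_loss_episode_0': 0.9, 'train_loss_episode_1': 0.7}

Because meters live in a class-level registry, any module can fetch the same meter by name, which is how the wrappers above aggregate scores without passing meter objects around.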
/src/tools/plot_episode.py:
--------------------------------------------------------------------------------
1 | import pylab
2 |
3 | def plot_episode(episode, classes_first=True):
4 | sample_set = episode["support_set"].cpu()
5 | query_set = episode["query_set"].cpu()
6 | support_size = episode["support_size"]
7 | query_size = episode["query_size"]
8 | if not classes_first:
9 | sample_set = sample_set.permute(1, 0, 2, 3, 4)
10 | query_set = query_set.permute(1, 0, 2, 3, 4)
11 | n, support_size, c, h, w = sample_set.size()
12 | n, query_size, c, h, w = query_set.size()
13 | sample_set = ((sample_set / 2 + 0.5) * 255).numpy().astype('uint8').transpose((0, 3, 1, 4, 2)).reshape((n *h, support_size * w, c))
14 | pylab.imsave('support_set.png', sample_set)
15 | query_set = ((query_set / 2 + 0.5) * 255).numpy().astype('uint8').transpose((0, 3, 1, 4, 2)).reshape((n *h, query_size * w, c))
16 | pylab.imsave('query_set.png', query_set)
17 | # pylab.imshow(query_set)
18 | # pylab.title("query_set")
19 | # pylab.show()
20 | # pylab.savefig('query_set.png')
21 |
--------------------------------------------------------------------------------
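A sketch of how plot_episode might be called on a synthetic 5-way episode; the tensor values lie in [-1, 1], matching the de-normalization the function applies, and the shapes are illustrative:

    import torch
    from src.tools.plot_episode import plot_episode

    episode = {"support_set": torch.rand(5, 1, 3, 84, 84) * 2 - 1,  # (classes, shots, c, h, w)
               "query_set": torch.rand(5, 15, 3, 84, 84) * 2 - 1,
               "support_size": 1,
               "query_size": 15}
    plot_episode(episode, classes_first=True)  # writes support_set.png and query_set.png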
/src/utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataloader import default_collate
2 |
3 | def get_collate(name):
4 | if name == "identity":
5 | return lambda x: x
6 | else:
7 | return default_collate
--------------------------------------------------------------------------------
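The "identity" collate returns the batch untouched, which suits the episodic loaders: each dataset item is already a complete episode dict, and val_on_batch simply indexes batch[0]. For example:

    from src.utils import get_collate

    collate = get_collate("identity")
    print(collate([{"support_set": 1}, {"support_set": 2}]))  # list returned as-is, no stacking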
/trainval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import pandas as pd
4 | import sys
5 | import os
6 | from torch import nn
7 | from torch.nn import functional as F
8 | import tqdm
9 | import pprint
10 | from src import utils as ut
11 | import torchvision
12 | import numpy as np
13 |
14 | from src import datasets, models
15 | from src.models import backbones
16 | from torch.utils.data import DataLoader
17 | import exp_configs
18 | from torch.utils.data.sampler import RandomSampler
19 |
20 | from haven import haven_utils as hu
21 | from haven import haven_results as hr
22 | from haven import haven_chk as hc
23 | from haven import haven_jupyter as hj
24 |
25 |
26 | def trainval(exp_dict, savedir_base, datadir, reset=False,
27 | num_workers=0, pretrained_weights_dir=None):
28 | # bookkeeping
29 | # ---------------
30 |
31 | # get experiment directory
32 | exp_id = hu.hash_dict(exp_dict)
33 | savedir = os.path.join(savedir_base, exp_id)
34 |
35 | if reset:
36 | # delete and backup experiment
37 | hc.delete_experiment(savedir, backup_flag=True)
38 |
39 | # create folder and save the experiment dictionary
40 | os.makedirs(savedir, exist_ok=True)
41 | hu.save_json(os.path.join(savedir, 'exp_dict.json'), exp_dict)
42 | pprint.pprint(exp_dict)
43 | print('Experiment saved in %s' % savedir)
44 |
45 | # load datasets
46 | # ==========================
47 | train_set = datasets.get_dataset(dataset_name=exp_dict["dataset_train"],
48 | data_root=os.path.join(datadir, exp_dict["dataset_train_root"]),
49 | split="train",
50 | transform=exp_dict["transform_train"],
51 | classes=exp_dict["classes_train"],
52 | support_size=exp_dict["support_size_train"],
53 | query_size=exp_dict["query_size_train"],
54 | n_iters=exp_dict["train_iters"],
55 | unlabeled_size=exp_dict["unlabeled_size_train"])
56 |
57 | val_set = datasets.get_dataset(dataset_name=exp_dict["dataset_val"],
58 | data_root=os.path.join(datadir, exp_dict["dataset_val_root"]),
59 | split="val",
60 | transform=exp_dict["transform_val"],
61 | classes=exp_dict["classes_val"],
62 | support_size=exp_dict["support_size_val"],
63 | query_size=exp_dict["query_size_val"],
64 | n_iters=exp_dict.get("val_iters", None),
65 | unlabeled_size=exp_dict["unlabeled_size_val"])
66 |
67 | test_set = datasets.get_dataset(dataset_name=exp_dict["dataset_test"],
68 | data_root=os.path.join(datadir, exp_dict["dataset_test_root"]),
69 | split="test",
70 | transform=exp_dict["transform_val"],
71 | classes=exp_dict["classes_test"],
72 | support_size=exp_dict["support_size_test"],
73 | query_size=exp_dict["query_size_test"],
74 | n_iters=exp_dict["test_iters"],
75 | unlabeled_size=exp_dict["unlabeled_size_test"])
76 |
77 | # get dataloaders
78 | # ==========================
79 | train_loader = torch.utils.data.DataLoader(
80 | train_set,
81 | batch_size=exp_dict["batch_size"],
82 | shuffle=True,
83 | num_workers=num_workers,
84 | collate_fn=ut.get_collate(exp_dict["collate_fn"]),
85 | drop_last=True)
86 | val_loader = torch.utils.data.DataLoader(
87 | val_set,
88 | batch_size=1,
89 | shuffle=False,
90 | num_workers=num_workers,
91 | collate_fn=lambda x: x,
92 | drop_last=True)
93 | test_loader = torch.utils.data.DataLoader(
94 | test_set,
95 | batch_size=1,
96 | shuffle=False,
97 | num_workers=num_workers,
98 | collate_fn=lambda x: x,
99 | drop_last=True)
100 |
101 |
102 | # create model and trainer
103 | # ==========================
104 |
105 | # Create model, opt, wrapper
106 | backbone = backbones.get_backbone(backbone_name=exp_dict['model']["backbone"], exp_dict=exp_dict)
107 | model = models.get_model(model_name=exp_dict["model"]['name'], backbone=backbone,
108 | n_classes=exp_dict["n_classes"],
109 | exp_dict=exp_dict,
110 | pretrained_weights_dir=pretrained_weights_dir,
111 | savedir_base=savedir_base)
112 |
113 | # Pretrain or Fine-tune or run SSL
114 | if exp_dict["model"]['name'] == 'ssl':
115 | # runs the SSL experiments
116 | score_list_path = os.path.join(savedir, 'score_list.pkl')
117 | if not os.path.exists(score_list_path):
118 | test_dict = model.test_on_loader(test_loader, max_iter=None)
119 | hu.save_pkl(score_list_path, [test_dict])
120 | return
121 |
122 | # Checkpoint
123 | # -----------
124 | checkpoint_path = os.path.join(savedir, 'checkpoint.pth')
125 | score_list_path = os.path.join(savedir, 'score_list.pkl')
126 |
127 | if os.path.exists(score_list_path):
128 | # resume experiment
129 | model.load_state_dict(hu.torch_load(checkpoint_path))
130 | score_list = hu.load_pkl(score_list_path)
131 | s_epoch = score_list[-1]['epoch'] + 1
132 | else:
133 | # restart experiment
134 | score_list = []
135 | s_epoch = 0
136 |
137 | # Run training and validation
138 | for epoch in range(s_epoch, exp_dict["max_epoch"]):
139 | score_dict = {"epoch": epoch}
140 | score_dict.update(model.get_lr())
141 |
142 | # train
143 | score_dict.update(model.train_on_loader(train_loader))
144 |
145 | # validate
146 | score_dict.update(model.val_on_loader(val_loader))
147 | score_dict.update(model.test_on_loader(test_loader))
148 |
149 | # Add score_dict to score_list
150 | score_list += [score_dict]
151 |
152 | # Report
153 | score_df = pd.DataFrame(score_list)
154 | print(score_df.tail())
155 |
156 | # Save checkpoint
157 | hu.save_pkl(score_list_path, score_list)
158 | hu.torch_save(checkpoint_path, model.get_state_dict())
159 | print("Saved: %s" % savedir)
160 |
161 | if "accuracy" in exp_dict["target_loss"]:
162 | is_best = score_dict[exp_dict["target_loss"]] >= score_df[exp_dict["target_loss"]][:-1].max()
163 | else:
164 | is_best = score_dict[exp_dict["target_loss"]] <= score_df[exp_dict["target_loss"]][:-1].min()
165 |
166 | # Save best checkpoint
167 | if is_best:
168 | hu.save_pkl(os.path.join(savedir, "score_list_best.pkl"), score_list)
169 | hu.torch_save(os.path.join(savedir, "checkpoint_best.pth"), model.get_state_dict())
170 | print("Saved Best: %s" % savedir)
171 |
172 | # Check for end of training conditions
173 | if model.is_end_of_training():
174 | break
175 |
176 |
177 | if __name__ == '__main__':
178 | parser = argparse.ArgumentParser()
179 |
180 | parser.add_argument('-e', '--exp_group_list', nargs='+')
181 | parser.add_argument('-sb', '--savedir_base', required=True)
182 | parser.add_argument('-d', '--datadir', default='')
183 | parser.add_argument('-r', '--reset', default=0, type=int)
184 | parser.add_argument('-ei', '--exp_id', type=str, default=None)
185 | parser.add_argument('-j', '--run_jobs', type=int, default=0)
186 | parser.add_argument('-nw', '--num_workers', default=0, type=int)
187 | parser.add_argument('-p', '--pretrained_weights_dir', type=str, default=None)
188 |
189 | args = parser.parse_args()
190 |
191 | # Collect experiments
192 | # -------------------
193 | if args.exp_id is not None:
194 | # select one experiment
195 | savedir = os.path.join(args.savedir_base, args.exp_id)
196 | exp_dict = hu.load_json(os.path.join(savedir, 'exp_dict.json'))
197 |
198 | exp_list = [exp_dict]
199 |
200 | else:
201 | # select exp group
202 | exp_list = []
203 | for exp_group_name in args.exp_group_list:
204 | exp_list += exp_configs.EXP_GROUPS[exp_group_name]
205 |
206 |
207 | # Run experiments or View them
208 | # ----------------------------
209 | if args.run_jobs:
210 | pass  # job submission is stubbed out in this release
211 | else:
212 | # run experiments
213 | for exp_dict in exp_list:
214 | # do trainval
215 | trainval(exp_dict=exp_dict,
216 | savedir_base=args.savedir_base,
217 | datadir=args.datadir,
218 | reset=args.reset,
219 | num_workers=args.num_workers,
220 | pretrained_weights_dir=args.pretrained_weights_dir)
--------------------------------------------------------------------------------
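A hypothetical invocation of trainval.py, assuming an experiment group named "ssl" is defined in exp_configs.EXP_GROUPS (the group name, directories, and worker count below are placeholders):

    python trainval.py -e ssl -sb ./results -d ./data -nw 4

Each experiment is saved under a hash of its exp_dict inside ./results, so re-running the same command resumes a training experiment from its last checkpoint.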