├── .gitignore
├── LICENSE
├── README.md
├── backbone.py
├── configs.py
├── data_preparation
│   ├── cars.py
│   ├── input
│   │   └── .gitignore
│   ├── output
│   │   └── .gitignore
│   ├── places.py
│   ├── places_cdfsl_subset_16_class_1715_sample_seed_0.json
│   ├── places_plantae_subset_sampler.ipynb
│   ├── plantae.py
│   └── plantae_cdfsl_subset_69_class.json
├── datasets
│   ├── __init__.py
│   ├── dataloader.py
│   ├── datasets.py
│   ├── sampler.py
│   ├── split.py
│   ├── split_seed_1
│   │   ├── ChestX_labeled_80.csv
│   │   ├── ChestX_unlabeled_20.csv
│   │   ├── CropDisease_labeled_80.csv
│   │   ├── CropDisease_unlabeled_20.csv
│   │   ├── EuroSAT_labeled_80.csv
│   │   ├── EuroSAT_unlabeled_20.csv
│   │   ├── ISIC_labeled_80.csv
│   │   ├── ISIC_unlabeled_20.csv
│   │   ├── cars_labeled_80.csv
│   │   ├── cars_unlabeled_20.csv
│   │   ├── cub_labeled_80.csv
│   │   ├── cub_unlabeled_20.csv
│   │   ├── miniImageNet_test_labeled_80.csv
│   │   ├── miniImageNet_test_unlabeled_20.csv
│   │   ├── places_labeled_80.csv
│   │   ├── places_unlabeled_20.csv
│   │   ├── plantae_labeled_80.csv
│   │   ├── plantae_unlabeled_20.csv
│   │   ├── tieredImageNet_test_labeled_80.csv
│   │   └── tieredImageNet_test_unlabeled_20.csv
│   └── transforms.py
├── finetune.py
├── finetune.sh
├── io_utils.py
├── methods
│   ├── __init__.py
│   ├── baselinefinetune.py
│   ├── baselinetrain.py
│   ├── boil.py
│   ├── byol.py
│   ├── maml.py
│   ├── meta_template.py
│   └── protonet.py
├── model
│   ├── __init__.py
│   ├── base.py
│   ├── byol.py
│   ├── classifier_head.py
│   ├── moco.py
│   ├── simclr.py
│   └── simsiam.py
├── paths.py
├── pretrain.py
├── pretrain.sh
├── pretrain_new_lu20.py
├── requirements.txt
├── scheduler.py
├── setup.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.pt
3 | *.json
4 |
5 | scripts/
6 |
7 | # user-defined files
8 | *.log
9 | *.tar
10 | *.pkl
11 | *.csv
12 | *.JPEG
13 | wandb/
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 | db.sqlite3-journal
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109 | __pypackages__/
110 |
111 | # Celery stuff
112 | celerybeat-schedule
113 | celerybeat.pid
114 |
115 | # SageMath parsed files
116 | *.sage.py
117 |
118 | # Environments
119 | .env
120 | .venv
121 | env/
122 | venv/
123 | ENV/
124 | env.bak/
125 | venv.bak/
126 |
127 | # Spyder project settings
128 | .spyderproject
129 | .spyproject
130 |
131 | # Rope project settings
132 | .ropeproject
133 |
134 | # mkdocs documentation
135 | /site
136 |
137 | # mypy
138 | .mypy_cache/
139 | .dmypy.json
140 | dmypy.json
141 |
142 | # Pyre type checker
143 | .pyre/
144 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Official] Understanding Cross-Domain Few-Shot Learning Based on Domain Similarity and Few-shot Difficulty
2 |
3 | This repo contains the implementation of our [paper](https://arxiv.org/abs/2202.01339) accepted at NeurIPS 2022.
4 |
5 | ## Abstract
6 | Cross-domain few-shot learning (CD-FSL) has drawn increasing attention for handling large differences between the source and target domains--an important concern in real-world scenarios. To overcome these large differences, recent works have considered exploiting small-scale unlabeled data from the target domain during the pre-training stage. This data enables self-supervised pre-training on the target domain, in addition to supervised pre-training on the source domain. In this paper, we empirically investigate which pre-training is preferred based on domain similarity and few-shot difficulty of the target domain. We discover that the performance gain of self-supervised pre-training over supervised pre-training becomes large when the target domain is dissimilar to the source domain, or the target domain itself has low few-shot difficulty. We further design two pre-training schemes, mixed-supervised and two-stage learning, that improve performance. In this light, we present six findings for CD-FSL, which are supported by extensive experiments and analyses on three source and eight target benchmark datasets with varying levels of domain similarity and few-shot difficulty.
7 |
8 | ## Table of Contents
9 |
10 | * [Prerequisites](#prerequisites)
11 | * [Data Preparation](#data-preparation)
12 |     * [BSCD-FSL](#bscd-fsl)
13 |     * [Cars](#cars)
14 |     * [CUB](#cub-caltech-ucsd-birds-200-2011)
15 |     * [Places](#places-places-205)
16 |     * [Plantae](#plantae-from-inaturalist-2018)
17 | * [Usage](#usage)
18 | * [Model Checkpoints](#model-checkpoints)
19 | * [Attribution](#attribution)
20 | * [License](#license)
21 | * [Citation](#citation)
22 | * [Contact](#contact)
23 |
24 | ## Prerequisites
25 |
26 | Our code works on `torch>=1.8`. Install the required Python packages via
27 |
28 | ```sh
29 | pip install -r requirements.txt
30 | ```
31 |
32 | ## Data Preparation
33 |
34 | Prepare and place all dataset folders in `/data/cdfsl/`. You may specify custom locations in `configs.py`.
35 |
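For example, to keep the data under a custom root, edit the corresponding path variables in `configs.py` (variable names as defined in that file; the paths below are placeholders):

```python
# configs.py -- pointing datasets at a custom location (placeholder paths)
miniImageNet_path = '/my/data/miniImagenet'
cars_path = '/my/data/cars_cdfsl'
places_path = '/my/data/places_cdfsl'
plantae_path = '/my/data/plantae_cdfsl'
```
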
36 | ### BSCD-FSL
37 |
38 | Refer to the original BSCD-FSL [repository](https://github.com/IBM/cdfsl-benchmark).
39 |
40 | The dataset folders should be organized in `/data/cdfsl/` as follows:
41 |
42 | ```
43 | CropDiseases/train
44 | ├── Apple___Apple_scab
45 | │ ├── 00075aa8-d81a-4184-8541-b692b78d398a___FREC_Scab 3335.JPG
46 | │ ├── 0208f4eb-45a4-4399-904e-989ac2c6257c___FREC_Scab 3037.JPG
47 |
48 | EuroSAT
49 | ├── AnnualCrop
50 | │ ├── AnnualCrop_1.jpg
51 | │ ├── AnnualCrop_2.jpg
52 |
53 | ISIC
54 | ├── ATTRIBUTION.txt
55 | ├── ISIC2018_Task3_Training_GroundTruth.csv
56 | ├── ISIC2018_Task3_Training_LesionGroupings.csv
57 | ├── ISIC_0024306.jpg
58 | ├── ISIC_0024307.jpg
59 |
60 | chestX/images
61 | ├── 00000001_000.png
62 | ├── 00000001_001.png
63 | ```
64 |
65 | Note: `chestX/images/` should contain **all** images from the ChestX dataset (the dataset archive provided online will
66 | typically split these images across multiple folders).
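One way to consolidate the images (a sketch, assuming the official NIH archives were extracted into `images_001/` through `images_012/` subfolders; adjust to your actual layout):

```sh
cd /data/cdfsl/chestX
mkdir -p images
# flatten the per-archive image folders into a single images/ directory
mv images_*/images/*.png images/
```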
67 |
68 | ### Cars
69 |
70 | https://ai.stanford.edu/~jkrause/cars/car_dataset.html
71 |
72 | We use all images from both training and test sets for CD-FSL experiments.
73 | For convenience, we pre-process the data such that each image goes into its respective class folder.
74 |
75 | 1. Download [`car_ims.tgz`](http://ai.stanford.edu/~jkrause/car196/car_ims.tgz) (the tar of all images) and [`cars_annos.mat`](http://ai.stanford.edu/~jkrause/car196/cars_annos.mat) (all bounding boxes and labels for both training and test).
76 | 2. Copy `cars_annos.mat` and unzip `car_ims.tgz` into `./data_preparation/input/`. The directory should contain the following:
77 | ```
78 | data_preparation/input
79 | ├── cars_annos.mat
80 | ├── car_ims
81 | │ ├── 000001.jpg
82 | │ ├── 000002.jpg
83 | ```
84 | 3. Run `./data_preparation/cars.py` to generate the cars dataset folder at `./data_preparation/output/cars_cdfsl/`.
85 | 4. Move the `cars_cdfsl` directory to `/data/cdfsl/`. You may specify a custom location in `configs.py`.
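
Steps 3 and 4 as shell commands, assuming the default `/data/cdfsl/` location (`places.py` and `plantae.py` below follow the same pattern):

```sh
python data_preparation/cars.py
mv data_preparation/output/cars_cdfsl /data/cdfsl/
```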
86 |
87 | ### CUB (Caltech-UCSD Birds-200-2011)
88 |
89 | http://www.vision.caltech.edu/datasets/cub_200_2011/
90 |
91 | We use all images for CD-FSL experiments.
92 |
93 | 1. Download [`CUB_200_2011.tgz`](https://data.caltech.edu/records/20098).
94 | 2. Unzip the archive and copy the *enclosed* `CUB_200_2011/` folder to `/data/cdfsl/`. You may specify a custom location in `configs.py`. The directory should contain the following:
95 | ```
96 | CUB_200_2011/
97 | ├── attributes/
98 | ├── bounding_boxes.txt
99 | ├── classes.txt
100 | ├── image_class_labels.txt
101 | ├── images/
102 | ├── images.txt
103 | ├── parts/
104 | ├── README
105 | └── train_test_split.txt
106 | ```
107 |
108 | ### Places (Places 205)
109 |
110 | http://places.csail.mit.edu/user/
111 |
112 | Due to the size of the original dataset, we only use a subset of the training set for CD-FSL experiments.
113 | We use 27,440 images from 16 classes. Refer to the paper for details, or see the subset sampling code at
114 | `data_preparation/places_plantae_subset_sampler.ipynb`.
115 |
116 | 1. Download [places365standard_easyformat.tar](http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar).
117 | 2. Unzip the archive into `./data_preparation/input/`. The directory should contain the following:
118 | ```
119 | data_preparation/input/
120 | ├── places365_standard/
121 | │ ├── train/
122 | │ │ ├── airfield/
123 | │ │ ├── airplane_cabin/
124 | │   │   ├── airport_terminal/
125 | ```
126 | 3. Run `./data_preparation/places.py` to generate the places dataset folder at `./data_preparation/output/places_cdfsl/`.
127 | 4. Move the `places_cdfsl` directory to `/data/cdfsl/`. You may specify a custom location in `configs.py`.
128 |
129 | ### Plantae (from iNaturalist 2018)
130 |
131 | https://github.com/visipedia/inat_comp/tree/master/2018#Data
132 |
133 | Due to the size of the original dataset, we only use a subset of the training set (of the Plantae super category) for CD-FSL experiments.
134 | We use 26,650 images from 69 classes. Refer to the paper for details, or see the subset sampling code at
135 | `data_preparation/places_plantae_subset_sampler.ipynb`.
136 |
137 | 1. Download [`train_val2018.tar.gz`](https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz) (~120GB).
138 | 2. Unzip the archive and copy the enclosed `Plantae/` folder (~43GB) to `./data_preparation/input/`. The directory should contain the following:
139 | ```
140 | data_preparation/input
141 | └── Plantae/
142 | ├── 5221/
143 | ├── 5222/
144 | ├── 5223/
145 | ```
146 | 3. Run `./data_preparation/plantae.py` to generate the plantae dataset folder at `./data_preparation/output/plantae_cdfsl/`.
147 | 4. Move the `plantae_cdfsl` directory to `/data/cdfsl/`. You may specify a custom location in `configs.py`.
148 |
149 | ## Usage
150 |
151 | The main training scripts are `pretrain.py` and `finetune.py`. Refer to `pretrain.sh` and `finetune.sh` for example
152 | usages covering the main results in our paper, e.g., SL, SSL, MSL, two-stage SSL, and two-stage MSL.
153 | To see all CLI arguments, refer to `io_utils.py`.
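
For example (a sketch; each script collects the example commands, so you may prefer to copy individual lines out of it):

```sh
bash pretrain.sh    # pre-training: SL, SSL, MSL, and two-stage variants
bash finetune.sh    # few-shot evaluation via fine-tuning on the target datasets
```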
154 |
155 | ## Model Checkpoints
156 |
157 | | Backbone | Pretraining | Augmentation | Model Checkpoints |
158 | | :-------- | :------------: |:---------: |:--------------:|
159 | | ResNet10 | miniImageNet (SL) | default (strong) | [google drive](https://drive.google.com/file/d/1J4weUMgMhdjYe0sbPBNavaf5D7aRkAog/view?usp=sharing) |
160 | | ResNet10 | miniImageNet (SL) | base | [google drive](https://drive.google.com/file/d/11HSAg85vlS67sVsEgd-RYgksnlX61WOj/view?usp=sharing) |
161 | | ResNet18 | tieredImageNet (SL) | default (strong) | [google drive](https://drive.google.com/file/d/1hRbE5VwDvgsKV6E7okOgqJaNOVsSfitP/view?usp=share_link) |
162 | | ResNet18 | tieredImageNet (SL) | base | [google drive](https://drive.google.com/file/d/1-UOzOG-NkqRFXng-Zu3RAbuMhZJZfEhN/view?usp=share_link) |
163 | | ResNet18 | ImageNet (SL) | base | [torchvision](https://pytorch.org/vision/stable/models.html) |
164 |
165 | ## Attribution
166 |
167 | Below, we provide the licenses, attribution, citations, and URL of the datasets considered in our paper (if applicable).
168 |
169 | - **ImageNet** is available at https://image-net.org/. miniImageNet and tieredImageNet are subsets of ImageNet.
170 |     1. *Deng, Jia, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. "Imagenet: A large-scale hierarchical image database." In 2009 IEEE conference on computer vision and pattern recognition, pp. 248-255. IEEE, 2009.*
171 | - **CropDisease** refers to the Plant Disease dataset on Kaggle, licensed under GPL 2. It is available at https://www.kaggle.com/saroz014/plant-disease/.
172 | - **EuroSAT** is available at https://github.com/phelber/eurosat.
173 | 1. *Helber, Patrick, Benjamin Bischke, Andreas Dengel, and Damian Borth. "Eurosat: A novel dataset and deep learning benchmark for land use and land cover classification." IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing 12, no. 7 (2019): 2217-2226.*
174 | 2. *Helber, Patrick, Benjamin Bischke, Andreas Dengel, and Damian Borth. "Introducing eurosat: A novel dataset and deep learning benchmark for land use and land cover classification." In IGARSS 2018-2018 IEEE international geoscience and remote sensing symposium, pp. 204-207. IEEE, 2018.*
175 | - **ISIC** refers to the ISIC 2018 Challenge Task 3 dataset. It is available at https://challenge.isic-archive.com.
176 | 1. *Codella, Noel, Veronica Rotemberg, Philipp Tschandl, M. Emre Celebi, Stephen Dusza, David Gutman, Brian Helba et al. "Skin lesion analysis toward melanoma detection 2018: A challenge hosted by the international skin imaging collaboration (isic)." arXiv preprint arXiv:1902.03368 (2019).*
177 | 2. *Tschandl, Philipp, Cliff Rosendahl, and Harald Kittler. "The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions." Scientific data 5, no. 1 (2018): 1-9.*
178 | - **ChestX** refers to the ChestX-ray8 dataset provided by the NIH Clinical Center. The dataset is available at https://nihcc.app.box.com/v/ChestXray-NIHCC.
179 | 1. *Wang, Xiaosong, Yifan Peng, Le Lu, Zhiyong Lu, Mohammadhadi Bagheri, and Ronald M. Summers. "Chestx-ray8: Hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2097-2106. 2017.*
180 | - **Places** is licensed under CC BY. It is available at http://places.csail.mit.edu/user.
181 | 1. *Zhou, Bolei, Agata Lapedriza, Aditya Khosla, Aude Oliva, and Antonio Torralba. "Places: A 10 million image database for scene recognition." IEEE transactions on pattern analysis and machine intelligence 40, no. 6 (2017): 1452-1464.*
182 | - **Plantae** refers to the Plantae super category of the iNaturalist 2018 dataset. It is available at https://github.com/visipedia/inat_comp/tree/master/2018.
183 | 1. *Van Horn, Grant, Oisin Mac Aodha, Yang Song, Yin Cui, Chen Sun, Alex Shepard, Hartwig Adam, Pietro Perona, and Serge Belongie. "The inaturalist species classification and detection dataset." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 8769-8778. 2018.*
184 | - **CUB** refers to the Caltech-UCSD Birds-200-2011 dataset, licensed under CC0 (public domain). It is available at https://www.kaggle.com/datasets/veeralakrishna/200-bird-species-with-11788-images.
185 | 1. *Wah, Catherine, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. "The caltech-ucsd birds-200-2011 dataset." (2011).*
186 | - **Cars** is available at https://ai.stanford.edu/~jkrause/cars/car_dataset.html.
187 | 1. *Krause, Jonathan, Michael Stark, Jia Deng, and Li Fei-Fei. "3d object representations for fine-grained categorization." In Proceedings of the IEEE international conference on computer vision workshops, pp. 554-561. 2013.*
188 |
189 | ## License
190 |
191 | Distributed under the MIT License.
192 |
193 | ## Citation
194 |
195 | If you find this repo useful for your research, please consider citing our paper:
196 |
197 | ```
198 | @inproceedings{oh2022understanding,
199 | title={Understanding Cross-Domain Few-Shot Learning Based on Domain Similarity and Few-Shot Difficulty},
200 | author={Oh, Jaehoon and Kim, Sungnyun and Ho, Namgyu and Kim, Jin-Hwa and Song, Hwanjun and Yun, Se-Young},
201 | booktitle={Advances in Neural Information Processing Systems},
202 | year={2022}
203 | }
204 | ```
205 |
206 | ## Contact
207 | * Jaehoon Oh: jhoon.oh@kaist.ac.kr
208 | * Sungnyun Kim: ksn4397@kaist.ac.kr
209 | * Namgyu Ho: itsnamgyu@kaist.ac.kr
210 |
--------------------------------------------------------------------------------
/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import Tensor
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.hub import load_state_dict_from_url
10 | from torch.distributions import Bernoulli
11 | from torch.nn.utils.weight_norm import WeightNorm
12 |
13 |
14 | def init_layer(L):
15 | # Initialization using fan-in
16 | if isinstance(L, nn.Conv2d):
17 | n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels
18 | L.weight.data.normal_(0,math.sqrt(2.0/float(n)))
19 | elif isinstance(L, nn.BatchNorm2d):
20 | L.weight.data.fill_(1)
21 | L.bias.data.fill_(0)
22 |
23 | class distLinear(nn.Module):
24 | def __init__(self, indim, outdim):
25 | super(distLinear, self).__init__()
26 | self.L = nn.Linear( indim, outdim, bias = False)
27 |         self.class_wise_learnable_norm = True  # see issues #4 and #8 in the original repo
28 |         if self.class_wise_learnable_norm:
29 |             WeightNorm.apply(self.L, 'weight', dim=0)  # split the weight update into direction and norm components
30 | 
31 |         if outdim <= 200:
32 |             self.scale_factor = 2  # fixed scale factor to map cosine values into a reasonably large input for softmax
33 |         else:
34 |             self.scale_factor = 10  # a larger scale factor is required to handle >1000 output classes (e.g., Omniglot)
35 |
36 | def forward(self, x):
37 | x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x)
38 | x_normalized = x.div(x_norm+ 0.00001)
39 | if not self.class_wise_learnable_norm:
40 | L_norm = torch.norm(self.L.weight.data, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data)
41 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001)
42 |         cos_dist = self.L(x_normalized)  # matrix product via the linear layer; with WeightNorm this also scales the cosine similarity by a class-wise learnable norm (see issues #4 and #8 in the original repo)
43 | scores = self.scale_factor* (cos_dist)
44 |
45 | return scores
46 |
47 | class Flatten(nn.Module):
48 | def __init__(self):
49 | super(Flatten, self).__init__()
50 |
51 | def forward(self, x):
52 | return x.view(x.size(0), -1)
53 |
54 | # For meta-learning based algorithms (task-specific weight)
55 | class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight
56 | def __init__(self, in_features, out_features):
57 | super(Linear_fw, self).__init__(in_features, out_features)
58 | self.weight.fast = None #Lazy hack to add fast weight link
59 | self.bias.fast = None
60 |
61 | def forward(self, x):
62 | if self.weight.fast is not None and self.bias.fast is not None:
63 |             out = F.linear(x, self.weight.fast, self.bias.fast)  # weight.fast is the temporarily adapted fast weight
64 | else:
65 | out = super(Linear_fw, self).forward(x)
66 | return out
67 |
68 | class Conv2d_fw(nn.Conv2d): #used in MAML to forward input with fast weight
69 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, bias = True):
70 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)
71 | self.weight.fast = None
72 |         if self.bias is not None:
73 | self.bias.fast = None
74 |
75 | def forward(self, x):
76 | if self.bias is None:
77 | if self.weight.fast is not None:
78 | out = F.conv2d(x, self.weight.fast, None, stride= self.stride, padding=self.padding)
79 | else:
80 | out = super(Conv2d_fw, self).forward(x)
81 | else:
82 | if self.weight.fast is not None and self.bias.fast is not None:
83 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride= self.stride, padding=self.padding)
84 | else:
85 | out = super(Conv2d_fw, self).forward(x)
86 |
87 | return out
88 |
89 | class BatchNorm2d_fw(nn.BatchNorm2d): #used in MAML to forward input with fast weight
90 | def __init__(self, num_features):
91 | super(BatchNorm2d_fw, self).__init__(num_features)
92 | self.weight.fast = None
93 | self.bias.fast = None
94 |
95 | def forward(self, x):
96 | running_mean = torch.zeros(x.data.size()[1]).cuda()
97 | running_var = torch.ones(x.data.size()[1]).cuda()
98 | if self.weight.fast is not None and self.bias.fast is not None:
99 | out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training = True, momentum = 1)
100 | #batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py
101 | else:
102 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training = True, momentum = 1)
103 | return out
104 |
105 | # Simple ResNet Block
106 | class SimpleBlock(nn.Module):
107 | def __init__(self, method, indim, outdim, half_res, track_bn):
108 | super(SimpleBlock, self).__init__()
109 | self.indim = indim
110 | self.outdim = outdim
111 |
112 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
113 | self.BN1 = nn.BatchNorm2d(outdim, track_running_stats=track_bn)
114 | self.C2 = nn.Conv2d(outdim, outdim,kernel_size=3, padding=1, bias=False)
115 | self.BN2 = nn.BatchNorm2d(outdim, track_running_stats=track_bn)
116 |
117 | self.relu1 = nn.ReLU(inplace=True)
118 | self.relu2 = nn.ReLU(inplace=True)
119 |
120 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]
121 |
122 | self.half_res = half_res
123 |
124 | # if the input number of channels is not equal to the output, then need a 1x1 convolution
125 | if indim!=outdim:
126 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
127 | self.BNshortcut = nn.BatchNorm2d(outdim, track_running_stats=track_bn)
128 |
129 | self.parametrized_layers.append(self.shortcut)
130 | self.parametrized_layers.append(self.BNshortcut)
131 | self.shortcut_type = '1x1'
132 | else:
133 | self.shortcut_type = 'identity'
134 |
135 | for layer in self.parametrized_layers:
136 | init_layer(layer)
137 |
138 | def forward(self, x):
139 | out = self.C1(x)
140 | out = self.BN1(out)
141 | out = self.relu1(out)
142 | out = self.C2(out)
143 | out = self.BN2(out)
144 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x))
145 | out = out + short_out
146 | out = self.relu2(out)
147 |
148 | return out
149 |
150 | class ResNet(nn.Module):
151 | def __init__(self, method, block, list_of_num_layers, list_of_out_dims, flatten, track_bn, reinit_bn_stats):
152 | # list_of_num_layers specifies number of layers in each stage
153 | # list_of_out_dims specifies number of output channel for each stage
154 | super(ResNet,self).__init__()
155 | assert len(list_of_num_layers)==4, 'Can have only four stages'
156 |
157 | self.reinit_bn_stats = reinit_bn_stats
158 |
159 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
160 | bn1 = nn.BatchNorm2d(64, track_running_stats=track_bn)
161 |
162 | relu = nn.ReLU()
163 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
164 |
165 | init_layer(conv1)
166 | init_layer(bn1)
167 |
168 | trunk = [conv1, bn1, relu, pool1]
169 |
170 | indim = 64
171 | for i in range(4):
172 | for j in range(list_of_num_layers[i]):
173 | half_res = (i>=1) and (j==0)
174 | B = block(method, indim, list_of_out_dims[i], half_res, track_bn)
175 | trunk.append(B)
176 | indim = list_of_out_dims[i]
177 |
178 | if flatten:
179 | avgpool = nn.AvgPool2d(7)
180 | trunk.append(avgpool)
181 | trunk.append(Flatten())
182 | self.final_feat_dim = indim
183 | else:
184 | self.final_feat_dim = [indim, 7, 7]
185 |
186 | self.trunk = nn.Sequential(*trunk)
187 |
188 | def forward(self,x):
189 | if self.reinit_bn_stats:
190 | self._reinit_running_batch_statistics()
191 | out = self.trunk(x)
192 | return out
193 |
194 | def _reinit_running_batch_statistics(self):
195 | with torch.no_grad():
196 | self.trunk[1].running_mean.data.fill_(0.)
197 | self.trunk[1].running_var.data.fill_(1.)
198 |
199 | self.trunk[4].BN1.running_mean.data.fill_(0.)
200 | self.trunk[4].BN1.running_var.data.fill_(1.)
201 | self.trunk[4].BN2.running_mean.data.fill_(0.)
202 | self.trunk[4].BN2.running_var.data.fill_(1.)
203 |
204 | self.trunk[5].BN1.running_mean.data.fill_(0.)
205 | self.trunk[5].BN1.running_var.data.fill_(1.)
206 | self.trunk[5].BN2.running_mean.data.fill_(0.)
207 | self.trunk[5].BN2.running_var.data.fill_(1.)
208 | self.trunk[5].BNshortcut.running_mean.data.fill_(0.)
209 | self.trunk[5].BNshortcut.running_var.data.fill_(1.)
210 |
211 | self.trunk[6].BN1.running_mean.data.fill_(0.)
212 | self.trunk[6].BN1.running_var.data.fill_(1.)
213 | self.trunk[6].BN2.running_mean.data.fill_(0.)
214 | self.trunk[6].BN2.running_var.data.fill_(1.)
215 | self.trunk[6].BNshortcut.running_mean.data.fill_(0.)
216 | self.trunk[6].BNshortcut.running_var.data.fill_(1.)
217 |
218 | self.trunk[7].BN1.running_mean.data.fill_(0.)
219 | self.trunk[7].BN1.running_var.data.fill_(1.)
220 | self.trunk[7].BN2.running_mean.data.fill_(0.)
221 | self.trunk[7].BN2.running_var.data.fill_(1.)
222 | self.trunk[7].BNshortcut.running_mean.data.fill_(0.)
223 | self.trunk[7].BNshortcut.running_var.data.fill_(1.)
224 |
225 |
226 | def return_features(self, x, return_avg=False):
227 | flat = Flatten()
228 | m = nn.AdaptiveAvgPool2d((1,1))
229 |
230 | with torch.no_grad():
231 | block1_out = self.trunk[4](self.trunk[3](self.trunk[2](self.trunk[1](self.trunk[0](x)))))
232 | block2_out = self.trunk[5](block1_out)
233 | block3_out = self.trunk[6](block2_out)
234 | block4_out = self.trunk[7](block3_out)
235 |
236 | if return_avg:
237 | return flat(m(block1_out)), flat(m(block2_out)), flat(m(block3_out)), flat(m(block4_out))
238 | else:
239 | return flat(block1_out), flat(block2_out), flat(block3_out), flat(block4_out)
240 |
241 | def forward_bodyfreeze(self,x):
242 | flat = Flatten()
243 | m = nn.AdaptiveAvgPool2d((1,1))
244 |
245 | with torch.no_grad():
246 | block1_out = self.trunk[4](self.trunk[3](self.trunk[2](self.trunk[1](self.trunk[0](x)))))
247 | block2_out = self.trunk[5](block1_out)
248 | block3_out = self.trunk[6](block2_out)
249 |
250 | out = self.trunk[7].C1(block3_out)
251 | out = self.trunk[7].BN1(out)
252 | out = self.trunk[7].relu1(out)
253 |
254 | out = self.trunk[7].C2(out)
255 | out = self.trunk[7].BN2(out)
256 | short_out = self.trunk[7].BNshortcut(self.trunk[7].shortcut(block3_out))
257 | out = out + short_out
258 | out = self.trunk[7].relu2(out)
259 |
260 | return flat(m(out))
261 |
262 | def ResNet10(method='baseline', track_bn=True, reinit_bn_stats=False):
263 | return ResNet(method, block=SimpleBlock, list_of_num_layers=[1,1,1,1], list_of_out_dims=[64,128,256,512], flatten=True, track_bn=track_bn, reinit_bn_stats=reinit_bn_stats)
264 |
265 | # -*- coding: utf-8 -*-
266 | # https://github.com/ElementAI/embedding-propagation/blob/master/src/models/backbones/resnet12.py
267 |
268 | class Block(torch.nn.Module):
269 | def __init__(self, ni, no, stride, dropout, track_bn, reinit_bn_stats):
270 | super().__init__()
271 | self.reinit_bn_stats = reinit_bn_stats
272 |
273 | self.dropout = nn.Dropout2d(dropout) if dropout > 0 else lambda x: x
274 | self.C0 = nn.Conv2d(ni, no, 3, stride, padding=1, bias=False)
275 | self.BN0 = nn.BatchNorm2d(no, track_running_stats=track_bn)
276 | self.C1 = nn.Conv2d(no, no, 3, 1, padding=1, bias=False)
277 | self.BN1 = nn.BatchNorm2d(no, track_running_stats=track_bn)
278 | self.C2 = nn.Conv2d(no, no, 3, 1, padding=1, bias=False)
279 | self.BN2 = nn.BatchNorm2d(no, track_running_stats=track_bn)
280 | if stride == 2 or ni != no:
281 | self.shortcut = nn.Conv2d(ni, no, 1, stride=1, padding=0, bias=False)
282 | self.BNshortcut = nn.BatchNorm2d(no, track_running_stats=track_bn)
283 |
284 | def get_parameters(self):
285 | return self.parameters()
286 |
287 | def forward(self, x):
288 | if self.reinit_bn_stats:
289 | self._reinit_running_batch_statistics()
290 |
291 | out = self.C0(x)
292 | out = self.BN0(out)
293 | out = F.relu(out)
294 | out = self.dropout(out)
295 | out = self.C1(out)
296 | out = self.BN1(out)
297 | out = F.relu(out)
298 | out = self.dropout(out)
299 | out = self.C2(out)
300 | out = self.BN2(out)
301 | out += self.BNshortcut(self.shortcut(x))
302 | out = F.relu(out)
303 |
304 | return out
305 |
306 | def _reinit_running_batch_statistics(self):
307 | with torch.no_grad():
308 | self.BN0.running_mean.data.fill_(0.)
309 | self.BN0.running_var.data.fill_(1.)
310 | self.BN1.running_mean.data.fill_(0.)
311 | self.BN1.running_var.data.fill_(1.)
312 | self.BN2.running_mean.data.fill_(0.)
313 | self.BN2.running_var.data.fill_(1.)
314 | self.BNshortcut.running_mean.data.fill_(0.)
315 | self.BNshortcut.running_var.data.fill_(1.)
316 |
317 | class ResNet12(torch.nn.Module):
318 | def __init__(self, track_bn, reinit_bn_stats, width=1, dropout=0):
319 | super().__init__()
320 | self.final_feat_dim = 512
321 |         assert width == 1  # comment out this assert to use other width variants of this model
322 | self.widths = [x * int(width) for x in [64, 128, 256]]
323 | self.widths.append(self.final_feat_dim * width)
324 | # self.bn_out = nn.BatchNorm1d(self.final_feat_dim)
325 |
326 | start_width = 3
327 | for i in range(len(self.widths)):
328 | setattr(self, "group_%d" %i, Block(start_width, self.widths[i], 1, dropout, track_bn, reinit_bn_stats))
329 | start_width = self.widths[i]
330 |
331 | def add_classifier(self, nclasses, name="classifier", modalities=None):
332 | setattr(self, name, torch.nn.Linear(self.final_feat_dim, nclasses))
333 |
334 | def up_to_embedding(self, x):
335 | """ Applies the four residual groups
336 | Args:
337 | x: input images
338 | n: number of few-shot classes
339 | k: number of images per few-shot class
340 | """
341 | for i in range(len(self.widths)):
342 | x = getattr(self, "group_%d" % i)(x)
343 | x = F.max_pool2d(x, 3, 2, 1)
344 | return x
345 |
346 | def forward(self, x):
347 | """Main Pytorch forward function
348 | Returns: class logits
349 | Args:
350 | x: input mages
351 | """
352 | *args, c, h, w = x.size()
353 | x = x.view(-1, c, h, w)
354 | x = self.up_to_embedding(x)
355 | # return F.relu(self.bn_out(x.mean(3).mean(2)), True)
356 | return F.relu(x.mean(3).mean(2), True)
357 |
358 |
359 | class ResNet18(torchvision.models.resnet.ResNet):
360 | def __init__(self, track_bn=True):
361 | def norm_layer(*args, **kwargs):
362 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
363 | super().__init__(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], norm_layer=norm_layer)
364 | del self.fc
365 | self.final_feat_dim = 512
366 |
367 | def load_imagenet_weights(self, progress=True):
368 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet18'],
369 | progress=progress)
370 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
371 | if len(missing) > 0:
372 | raise AssertionError('Model code may be incorrect')
373 |
374 | def _forward_impl(self, x: Tensor) -> Tensor:
375 | # See note [TorchScript super()]
376 | x = self.conv1(x)
377 | x = self.bn1(x)
378 | x = self.relu(x)
379 | x = self.maxpool(x)
380 |
381 | x = self.layer1(x)
382 | x = self.layer2(x)
383 | x = self.layer3(x)
384 | x = self.layer4(x)
385 |
386 | x = self.avgpool(x)
387 | x = torch.flatten(x, 1)
388 | # x = self.fc(x)
389 |
390 | return x
391 |
392 | class ResNet50(torchvision.models.resnet.ResNet):
393 | def __init__(self, track_bn=True):
394 | def norm_layer(*args, **kwargs):
395 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn)
396 | super().__init__(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], norm_layer=norm_layer)
397 | del self.fc
398 | self.final_feat_dim = 2048
399 |
400 | def load_imagenet_weights(self, progress=True):
401 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet50'],
402 | progress=progress)
403 | missing, unexpected = self.load_state_dict(state_dict, strict=False)
404 | if len(missing) > 0:
405 | raise AssertionError('Model code may be incorrect')
406 |
407 | def _forward_impl(self, x: Tensor) -> Tensor:
408 | # See note [TorchScript super()]
409 | x = self.conv1(x)
410 | x = self.bn1(x)
411 | x = self.relu(x)
412 | x = self.maxpool(x)
413 |
414 | x = self.layer1(x)
415 | x = self.layer2(x)
416 | x = self.layer3(x)
417 | x = self.layer4(x)
418 |
419 | x = self.avgpool(x)
420 | x = torch.flatten(x, 1)
421 | # x = self.fc(x)
422 |
423 | return x
424 |
425 |
426 | ##########################################################################################################
427 | # code from https://github.com/WangYueFt/rfs/blob/f8c837ba93c62dd0ac68a2f4019c619aa86b8421/models/resnet.py#L88
428 |
429 | def conv3x3(in_planes, out_planes, stride=1):
430 | """3x3 convolution with padding"""
431 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
432 | padding=1, bias=False)
433 |
434 |
435 | class SELayer(nn.Module):
436 | def __init__(self, channel, reduction=16):
437 | super(SELayer, self).__init__()
438 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
439 | self.fc = nn.Sequential(
440 | nn.Linear(channel, channel // reduction),
441 | nn.ReLU(inplace=True),
442 | nn.Linear(channel // reduction, channel),
443 | nn.Sigmoid()
444 | )
445 |
446 | def forward(self, x):
447 | b, c, _, _ = x.size()
448 | y = self.avg_pool(x).view(b, c)
449 | y = self.fc(y).view(b, c, 1, 1)
450 | return x * y
451 |
452 |
453 | class DropBlock(nn.Module):
454 | def __init__(self, block_size):
455 | super(DropBlock, self).__init__()
456 |
457 | self.block_size = block_size
458 | #self.gamma = gamma
459 | #self.bernouli = Bernoulli(gamma)
460 |
461 | def forward(self, x, gamma):
462 | # shape: (bsize, channels, height, width)
463 |
464 | if self.training:
465 | batch_size, channels, height, width = x.shape
466 |
467 | bernoulli = Bernoulli(gamma)
468 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
469 | block_mask = self._compute_block_mask(mask)
470 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
471 | count_ones = block_mask.sum()
472 |
473 | return block_mask * x * (countM / count_ones)
474 | else:
475 | return x
476 |
477 | def _compute_block_mask(self, mask):
478 | left_padding = int((self.block_size-1) / 2)
479 | right_padding = int(self.block_size / 2)
480 |
481 | batch_size, channels, height, width = mask.shape
482 | #print ("mask", mask[0][0])
483 | non_zero_idxs = mask.nonzero()
484 | nr_blocks = non_zero_idxs.shape[0]
485 |
486 | offsets = torch.stack(
487 | [
488 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding,
489 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding
490 | ]
491 | ).t().cuda()
492 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
493 |
494 | if nr_blocks > 0:
495 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
496 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
497 | offsets = offsets.long()
498 |
499 | block_idxs = non_zero_idxs + offsets
500 | #block_idxs += left_padding
501 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
502 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
503 | else:
504 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
505 |
506 | block_mask = 1 - padded_mask#[:height, :width]
507 | return block_mask
508 |
509 |
510 | class BasicBlock(torch.nn.Module):
511 | expansion = 1
512 |
513 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False,
514 | block_size=1, use_se=False):
515 | super(BasicBlock, self).__init__()
516 | self.conv1 = conv3x3(inplanes, planes)
517 | self.bn1 = nn.BatchNorm2d(planes)
518 | self.relu = nn.LeakyReLU(0.1)
519 | self.conv2 = conv3x3(planes, planes)
520 | self.bn2 = nn.BatchNorm2d(planes)
521 | self.conv3 = conv3x3(planes, planes)
522 | self.bn3 = nn.BatchNorm2d(planes)
523 | self.maxpool = nn.MaxPool2d(stride)
524 | self.downsample = downsample
525 | self.stride = stride
526 | self.drop_rate = drop_rate
527 | self.num_batches_tracked = 0
528 | self.drop_block = drop_block
529 | self.block_size = block_size
530 | self.DropBlock = DropBlock(block_size=self.block_size)
531 | self.use_se = use_se
532 | if self.use_se:
533 | self.se = SELayer(planes, 4)
534 |
535 | def forward(self, x):
536 | self.num_batches_tracked += 1
537 |
538 | residual = x
539 |
540 | out = self.conv1(x)
541 | out = self.bn1(out)
542 | out = self.relu(out)
543 |
544 | out = self.conv2(out)
545 | out = self.bn2(out)
546 | out = self.relu(out)
547 |
548 | out = self.conv3(out)
549 | out = self.bn3(out)
550 | if self.use_se:
551 | out = self.se(out)
552 |
553 | if self.downsample is not None:
554 | residual = self.downsample(x)
555 | out += residual
556 | out = self.relu(out)
557 | out = self.maxpool(out)
558 |
559 | if self.drop_rate > 0:
560 |             if self.drop_block:
561 | feat_size = out.size()[2]
562 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
563 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
564 | out = self.DropBlock(out, gamma=gamma)
565 | else:
566 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
567 |
568 | return out
569 |
570 |
571 | class ResNet18_84x84(torch.nn.Module):
572 | def __init__(self, track_bn=True, block=BasicBlock, n_blocks=[1,1,2,2], keep_prob=1.0, avg_pool=True, drop_rate=0.1,
573 | dropblock_size=5, num_classes=-1, use_se=False):
574 | super(ResNet18_84x84, self).__init__()
575 | self.final_feat_dim = 640
576 |
577 | self.inplanes = 3
578 | self.use_se = use_se
579 | self.layer1 = self._make_layer(block, n_blocks[0], 64,
580 | stride=2, drop_rate=drop_rate)
581 | self.layer2 = self._make_layer(block, n_blocks[1], 160,
582 | stride=2, drop_rate=drop_rate)
583 | self.layer3 = self._make_layer(block, n_blocks[2], 320,
584 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
585 | self.layer4 = self._make_layer(block, n_blocks[3], 640,
586 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
587 | if avg_pool:
588 | # self.avgpool = nn.AvgPool2d(5, stride=1)
589 | self.avgpool = nn.AdaptiveAvgPool2d(1)
590 | self.keep_prob = keep_prob
591 | self.keep_avg_pool = avg_pool
592 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
593 | self.drop_rate = drop_rate
594 |
595 | for m in self.modules():
596 | if isinstance(m, nn.Conv2d):
597 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
598 | elif isinstance(m, nn.BatchNorm2d):
599 | nn.init.constant_(m.weight, 1)
600 | nn.init.constant_(m.bias, 0)
601 | if not track_bn:
602 | m.track_running_stats = False
603 |
604 | # self.num_classes = num_classes
605 | # if self.num_classes > 0:
606 | # self.classifier = nn.Linear(640, self.num_classes)
607 |
608 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
609 | downsample = None
610 | if stride != 1 or self.inplanes != planes * block.expansion:
611 | downsample = nn.Sequential(
612 | nn.Conv2d(self.inplanes, planes * block.expansion,
613 | kernel_size=1, stride=1, bias=False),
614 | nn.BatchNorm2d(planes * block.expansion),
615 | )
616 |
617 | layers = []
618 | if n_block == 1:
619 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se)
620 | else:
621 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se)  # keyword arg avoids passing use_se into the drop_block slot
622 | layers.append(layer)
623 | self.inplanes = planes * block.expansion
624 |
625 | for i in range(1, n_block):
626 | if i == n_block - 1:
627 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block,
628 | block_size=block_size, use_se=self.use_se)
629 | else:
630 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se)
631 | layers.append(layer)
632 |
633 | return nn.Sequential(*layers)
634 |
635 | def forward(self, x, is_feat=False):
636 | x = self.layer1(x)
637 | # f0 = x
638 | x = self.layer2(x)
639 | # f1 = x
640 | x = self.layer3(x)
641 | # f2 = x
642 | x = self.layer4(x)
643 | # f3 = x
644 | if self.keep_avg_pool:
645 | x = self.avgpool(x)
646 | x = x.view(x.size(0), -1)
647 | # feat = x
648 | # if self.num_classes > 0:
649 | # x = self.classifier(x)
650 |
651 | # if is_feat:
652 | # return [f0, f1, f2, f3, feat], x
653 | # else:
654 | # return x
655 | return x
656 |
657 |
658 | _backbone_class_map = {
659 | 'resnet10': ResNet10,
660 | 'resnet18': ResNet18,
661 | 'resnet50': ResNet50,
662 | }
663 |
664 |
665 | def get_backbone_class(key):
666 | if key in _backbone_class_map:
667 | return _backbone_class_map[key]
668 | else:
669 | raise ValueError('Invalid backbone: {}'.format(key))
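670 | 
671 | # Usage sketch (illustrative, not part of the original file): instantiate a
672 | # backbone by key and run a forward pass. With flatten=True, ResNet10 maps a
673 | # 3x224x224 input to a (batch, final_feat_dim) == (batch, 512) feature vector.
674 | #
675 | #   model = get_backbone_class('resnet10')()
676 | #   x = torch.randn(2, 3, 224, 224)
677 | #   feats = model(x)  # shape: (2, 512)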
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | save_dir = './logs'
2 |
3 | miniImageNet_path = '/data/cdfsl/miniImagenet'
4 | miniImageNet_test_path = '/data/cdfsl/miniImagenet_test'
5 | tieredImageNet_path = '/data/cdfsl/tieredImagenet/train'
6 | tieredImageNet_test_path = '/data/cdfsl/tieredImagenet/test'
7 | ImageNet_path = '/data/cdfsl/Imagenet/train'
8 | DTD_path = '/ssd/dtd/images/'
9 |
10 | ISIC_path = "/data/cdfsl/ISIC"
11 | ChestX_path = "/data/cdfsl/chestX"
12 | CropDisease_path = "/data/cdfsl/CropDiseases"
13 | EuroSAT_path = "/data/cdfsl/EuroSAT"
14 |
15 | cars_path = '/data/cdfsl/cars_cdfsl'
16 | cub_path = '/data/cdfsl/CUB_200_2011/images'
17 | places_path = '/data/cdfsl/places_cdfsl'
18 | plantae_path = '/data/cdfsl/plantae_cdfsl'
19 |
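20 | # Note: the paths above are machine-specific defaults. Point each variable at
21 | # the corresponding dataset folder prepared via the README's Data Preparation steps.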
--------------------------------------------------------------------------------
/data_preparation/cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from collections import defaultdict
4 |
5 | import scipy.io
6 | from tqdm import tqdm
7 |
8 | ROOT = os.path.dirname(os.path.abspath(__file__))
9 | METADATA_PATH = os.path.join(ROOT, "input", "cars_annos.mat")
10 | SOURCE_DIR = os.path.join(ROOT, "input", "car_ims")
11 | TARGET_DIR = os.path.join(ROOT, "output", "cars_cdfsl")
12 |
13 | if not os.path.isfile(METADATA_PATH):
14 | raise Exception("Could not find metadata file at `{}`".format(METADATA_PATH))
15 | if not os.path.isdir(SOURCE_DIR):
16 | raise Exception("could not find image folder at `{}`".format(SOURCE_DIR))
17 |
18 | metadata = scipy.io.loadmat(METADATA_PATH)
19 | metadata = metadata["annotations"][0]
20 |
21 | paths_by_class = defaultdict(list)
22 | total = len(metadata)
23 | for m in metadata:
24 |     path, _, _, _, _, cls, test = m  # (relative path, bbox x1, y1, x2, y2, class id, test-split flag)
25 | cls = cls.item()
26 | path = path.item()
27 | path = os.path.basename(path)
28 | paths_by_class[str(cls)].append(path)
29 |
30 | print("Copying images to `{}`...".format(TARGET_DIR))
31 | with tqdm(total=total) as pbar:
32 | for cls, paths in paths_by_class.items():
33 | for path in paths:
34 | source_path = os.path.join(SOURCE_DIR, path)
35 | target_path = os.path.join(TARGET_DIR, cls, path)
36 | os.makedirs(os.path.dirname(target_path), exist_ok=True)
37 | shutil.copy(source_path, target_path)
38 | pbar.update()
39 | print("Complete")
40 |
--------------------------------------------------------------------------------
/data_preparation/input/.gitignore:
--------------------------------------------------------------------------------
1 | *
--------------------------------------------------------------------------------
/data_preparation/output/.gitignore:
--------------------------------------------------------------------------------
1 | *
--------------------------------------------------------------------------------
/data_preparation/places.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import shutil
4 |
5 | from tqdm import tqdm
6 |
7 | ROOT = os.path.dirname(os.path.abspath(__file__))
8 | SUBSET_PATH = os.path.join(ROOT, "places_cdfsl_subset_16_class_1715_sample_seed_0.json")
9 | SOURCE_DIR = os.path.join(ROOT, "input", "places365_standard/train")
10 | TARGET_DIR = os.path.join(ROOT, "output", "places_cdfsl")
11 |
12 | if not os.path.isfile(SUBSET_PATH):
13 | raise Exception("Could not find subset file at `{}`".format(SUBSET_PATH))
14 | if not os.path.isdir(SOURCE_DIR):
15 | raise Exception("could not find image folder at `{}`".format(SOURCE_DIR))
16 |
17 | with open(SUBSET_PATH) as f:
18 | subset_data = json.load(f)
19 | all_paths = []
20 | for paths in subset_data.values():
21 | all_paths.extend(paths)
22 |
23 | print("Copying images to {}...".format(TARGET_DIR))
24 | for p in tqdm(all_paths):
25 | source_path = os.path.join(SOURCE_DIR, p)
26 | target_path = os.path.join(TARGET_DIR, p)
27 | os.makedirs(os.path.dirname(target_path), exist_ok=True)
28 | shutil.copy(source_path, target_path)
29 | print("Complete")
30 |
--------------------------------------------------------------------------------
/data_preparation/plantae.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import shutil
4 |
5 | from tqdm import tqdm
6 |
7 | ROOT = os.path.dirname(os.path.abspath(__file__))
8 | SUBSET_PATH = os.path.join(ROOT, "plantae_cdfsl_subset_69_class.json")
9 | SOURCE_DIR = os.path.join(ROOT, "input", "Plantae")
10 | TARGET_DIR = os.path.join(ROOT, "output", "plantae_cdfsl")
11 |
12 | if not os.path.isfile(SUBSET_PATH):
13 | raise Exception("Could not find subset file at `{}`".format(SUBSET_PATH))
14 | if not os.path.isdir(SOURCE_DIR):
15 | raise Exception("could not find image folder at `{}`".format(SOURCE_DIR))
16 |
17 | with open(SUBSET_PATH) as f:
18 | subset_data = json.load(f)
19 | all_paths = []
20 | for paths in subset_data.values():
21 | all_paths.extend(paths)
22 |
23 | print("Copying images to {}...".format(TARGET_DIR))
24 | for p in tqdm(all_paths):
25 | source_path = os.path.join(SOURCE_DIR, p)
26 | target_path = os.path.join(TARGET_DIR, p)
27 | os.makedirs(os.path.dirname(target_path), exist_ok=True)
28 | shutil.copy(source_path, target_path)
29 | print("Complete")
30 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sungnyun/understanding-cdfsl/fdd45d8e5af0cb35bdb0b2b4046ed8089abdf304/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, MutableMapping
2 | from weakref import WeakValueDictionary
3 |
4 | import torch
5 | import torch.utils.data
6 | from torch.utils.data import Dataset
7 |
8 | from datasets.datasets import dataset_class_map
9 | from datasets.sampler import EpisodicBatchSampler
10 | from datasets.split import split_dataset
11 | from datasets.transforms import get_composed_transform
12 |
13 | _unlabeled_dataset_cache: MutableMapping[Tuple[str, str, int, bool, int, int], Dataset] = WeakValueDictionary()
14 |
15 | DEFAULT_IMAGE_SIZE = 224
16 |
17 |
18 | class ToSiamese:
19 | """
20 | A wrapper for torchvision transform. The transform is applied twice for
21 | SimCLR training
22 | """
23 |
24 | def __init__(self, transform, transform2=None):
25 | self.transform = transform
26 |
27 | if transform2 is not None:
28 | self.transform2 = transform2
29 | else:
30 | self.transform2 = transform
31 |
32 | def __call__(self, img):
33 | return self.transform(img), self.transform2(img)
34 |
35 |
36 | def get_default_dataset(dataset_name: str, augmentation: str, image_size: int = None, siamese=False):
37 | """
38 | :param augmentation: One of {'base', 'strong', None, 'none'}
39 | """
40 | if image_size is None:
41 | print('Using default image size: {}'.format(DEFAULT_IMAGE_SIZE))
42 | image_size = DEFAULT_IMAGE_SIZE
43 |
44 | try:
45 | dataset_cls = dataset_class_map[dataset_name]
46 | except KeyError as e:
47 | raise ValueError('Unsupported dataset: {}'.format(dataset_name))
48 |
49 | transform = get_composed_transform(augmentation, image_size=image_size)
50 | if siamese:
51 | transform = ToSiamese(transform)
52 | return dataset_cls(transform=transform)
53 |
54 |
55 | def get_split_dataset(dataset_name: str, augmentation: str, image_size: int = None, siamese=False,
56 | unlabeled_ratio: int = 20, seed=1):
57 |     # If the cache key scheme changes, simply remove the cache; it is not worth maintaining.
58 |     cache_key = (dataset_name, augmentation, image_size, siamese, unlabeled_ratio, seed)  # include seed so splits with different seeds are cached separately
59 | if cache_key not in _unlabeled_dataset_cache:
60 | dataset = get_default_dataset(dataset_name=dataset_name, augmentation=augmentation, image_size=image_size,
61 | siamese=siamese)
62 | unlabeled, labeled = split_dataset(dataset, ratio=unlabeled_ratio, seed=seed)
63 | # Cross-reference so that strong ref persists if either split is currently referenced
64 | unlabeled.counterpart = labeled
65 | labeled.counterpart = unlabeled
66 | _unlabeled_dataset_cache[cache_key] = unlabeled
67 |
68 | unlabeled = _unlabeled_dataset_cache[cache_key]
69 | labeled = unlabeled.counterpart
70 |
71 | return unlabeled, labeled
72 |
73 |
74 | def get_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None, siamese=False,
75 | num_workers=2, shuffle=True, drop_last=False):
76 | dataset = get_default_dataset(dataset_name=dataset_name, augmentation=augmentation, image_size=image_size,
77 | siamese=siamese)
78 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
79 | shuffle=shuffle, drop_last=drop_last)
80 |
81 |
82 | def get_split_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None, siamese=False,
83 | unlabeled_ratio: int = 20, num_workers=2, shuffle=True, drop_last=False, seed=1):
84 | unlabeled, labeled = get_split_dataset(dataset_name, augmentation, image_size=image_size, siamese=siamese,
85 | unlabeled_ratio=unlabeled_ratio, seed=seed)
86 | dataloaders = []
87 | for dataset in [unlabeled, labeled]:
88 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
89 | shuffle=shuffle, drop_last=drop_last)
90 | dataloaders.append(dataloader)
91 | return dataloaders
92 |
93 |
94 | def get_labeled_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None, siamese=False,
95 | unlabeled_ratio: int = 20, num_workers=2, shuffle=True, drop_last=False, split_seed=1):
96 | unlabeled, labeled = get_split_dataloader(dataset_name, augmentation, batch_size, image_size, siamese=siamese,
97 | unlabeled_ratio=unlabeled_ratio,
98 | num_workers=num_workers, shuffle=shuffle, drop_last=drop_last,
99 | seed=split_seed)
100 | return labeled
101 |
102 |
103 | def get_unlabeled_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None,
104 | siamese=False, unlabeled_ratio: int = 20, num_workers=2,
105 | shuffle=True, drop_last=True, split_seed=1):
106 | unlabeled, labeled = get_split_dataloader(dataset_name, augmentation, batch_size, image_size, siamese=siamese,
107 | unlabeled_ratio=unlabeled_ratio,
108 | num_workers=num_workers, shuffle=shuffle, drop_last=drop_last,
109 | seed=split_seed)
110 | return unlabeled
111 |
112 |
113 | def get_episodic_dataloader(dataset_name: str, n_way: int, n_shot: int, support: bool, n_episodes=600, n_query_shot=15,
114 | augmentation: str = None, image_size: int = None, num_workers=2, n_epochs=1,
115 | episode_seed=0):
116 | dataset = get_default_dataset(dataset_name=dataset_name, augmentation=augmentation, image_size=image_size,
117 | siamese=False)
118 | sampler = EpisodicBatchSampler(dataset, n_way=n_way, n_shot=n_shot, n_query_shot=n_query_shot,
119 | n_episodes=n_episodes, support=support, n_epochs=n_epochs, seed=episode_seed)
120 | return torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_sampler=sampler)
121 |
122 |
123 | def get_labeled_episodic_dataloader(dataset_name: str, n_way: int, n_shot: int, support: bool, n_episodes=600,
124 | n_query_shot=15, n_epochs=1, augmentation: str = None, image_size: int = None,
125 | unlabeled_ratio: int = 20, num_workers=2, split_seed=1, episode_seed=0):
126 | unlabeled, labeled = get_split_dataset(dataset_name, augmentation, image_size=image_size, siamese=False,
127 | unlabeled_ratio=unlabeled_ratio, seed=split_seed)
128 | sampler = EpisodicBatchSampler(labeled, n_way=n_way, n_shot=n_shot, n_query_shot=n_query_shot,
129 | n_episodes=n_episodes, support=support, n_epochs=n_epochs, seed=episode_seed)
130 | return torch.utils.data.DataLoader(labeled, num_workers=num_workers, batch_sampler=sampler)
131 |
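132 | # Usage sketch (illustrative): a 5-way 5-shot support-set loader over the
133 | # labeled 80% split of EuroSAT; each yielded batch is one episode's support set.
134 | #
135 | #   loader = get_labeled_episodic_dataloader('EuroSAT', n_way=5, n_shot=5, support=True)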
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | All dataset classes in unified `torchvision.datasets.ImageFolder` format!
3 | """
4 |
5 | import os
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from torchvision.datasets import ImageFolder
10 |
11 | from configs import *
12 |
13 |
14 | class MiniImageNetDataset(ImageFolder):
15 | name = "miniImageNet"
16 |
17 | def __init__(self, root=miniImageNet_path, *args, **kwargs):
18 | super().__init__(root=root, *args, **kwargs)
19 |
20 |
21 | class MiniImageNetTestDataset(ImageFolder):
22 | name = "miniImageNet_test"
23 |
24 | def __init__(self, root=miniImageNet_test_path, *args, **kwargs):
25 | super().__init__(root=root, *args, **kwargs)
26 |
27 |
28 | class TieredImageNetDataset(ImageFolder):
29 | name = "tieredImageNet"
30 |
31 | def __init__(self, root=tieredImageNet_path, *args, **kwargs):
32 | super().__init__(root=root, *args, **kwargs)
33 |
34 |
35 | class TieredImageNetTestDataset(ImageFolder):
36 | name = "tieredImageNet_test"
37 |
38 | def __init__(self, root=tieredImageNet_test_path, *args, **kwargs):
39 | super().__init__(root=root, *args, **kwargs)
40 |
41 |
42 | class ImageNetDataset(ImageFolder):
43 | name = "ImageNet"
44 |
45 | def __init__(self, root=ImageNet_path, *args, **kwargs):
46 | super().__init__(root=root, *args, **kwargs)
47 |
48 |
49 | class CropDiseaseDataset(ImageFolder):
50 | name = "CropDisease"
51 |
52 | def __init__(self, root=CropDisease_path, *args, **kwargs):
53 | super().__init__(root=os.path.join(root, "dataset", "train"), *args, **kwargs)
54 |
55 |
56 | class EuroSATDataset(ImageFolder):
57 | name = "EuroSAT"
58 |
59 | def __init__(self, root=EuroSAT_path, *args, **kwargs):
60 | super().__init__(root, *args, **kwargs)
61 |
62 |
63 | class ISICDataset(ImageFolder):
64 |     """
65 |     Implementation note: functions for finding data files have been customized so that data is selected based on
66 |     the given CSV file.
67 |     """
68 |     name = "ISIC"
69 |
70 | def __init__(self, root=ISIC_path, *args, **kwargs):
71 | csv_path = os.path.join(root, "ISIC2018_Task3_Training_GroundTruth.csv")
72 | self.metadata = pd.read_csv(csv_path)
73 | super().__init__(root, *args, **kwargs)
74 |
75 | def make_dataset(self, root, *args, **kwargs):
76 | paths = np.asarray(self.metadata.iloc[:, 0])
77 | labels = np.asarray(self.metadata.iloc[:, 1:])
78 | labels = (labels != 0).argmax(axis=1)
79 |
80 | samples = []
81 | for path, label in zip(paths, labels):
82 | path = os.path.join(root, path + ".jpg")
83 | samples.append((path, label))
84 | samples.sort()
85 |
86 | return samples
87 |
88 | def find_classes(self, _):
89 | classes = self.metadata.columns[1:].tolist()
90 | classes.sort()
91 | class_to_idx = dict()
92 | for i, cls in enumerate(classes):
93 | class_to_idx[cls] = i
94 | return classes, class_to_idx
95 |
96 | _find_classes = find_classes # compatibility with earlier versions
97 |
98 |
99 | class ChestXDataset(ImageFolder):
100 |     """
101 |     Implementation note: functions for finding data files have been customized so that data is selected based on
102 |     the given CSV file.
103 |     """
104 |     name = "ChestX"
105 |
106 | def __init__(self, root=ChestX_path, *args, **kwargs):
107 | csv_path = os.path.join(root, "Data_Entry_2017.csv")
108 | images_root = os.path.join(root, "images")
109 | # self.used_labels = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumonia",
110 | # "Pneumothorax"]
111 | self.used_labels = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule",
112 | "Pneumothorax"]
113 | self.labels_maps = {"Atelectasis": 0, "Cardiomegaly": 1, "Effusion": 2, "Infiltration": 3, "Mass": 4,
114 | "Nodule": 5, "Pneumothorax": 6}
115 | self.metadata = pd.read_csv(csv_path)
116 | super().__init__(images_root, *args, **kwargs)
117 |
118 | def make_dataset(self, root, *args, **kwargs):
119 | samples = []
120 | paths = np.asarray(self.metadata.iloc[:, 0])
121 | labels = np.asarray(self.metadata.iloc[:, 1])
122 | for path, label in zip(paths, labels):
123 | label = label.split("|")
124 |             # used_labels already excludes "No Finding" and "Pneumonia"
125 |             if len(label) == 1 and label[0] in self.used_labels:
126 | path = os.path.join(root, path)
127 | label = self.labels_maps[label[0]]
128 | samples.append((path, label))
129 | samples.sort()
130 | return samples
131 |
132 | def find_classes(self, _):
133 | return self.used_labels, self.labels_maps
134 |
135 | _find_classes = find_classes # compatibility with earlier versions
136 |
137 |
138 | class CarsDataset(ImageFolder):
139 | name = "cars"
140 |
141 | def __init__(self, root=cars_path, *args, **kwargs):
142 | super().__init__(root=root, *args, **kwargs)
143 |
144 |
145 | class CUBDataset(ImageFolder):
146 | name = "cub"
147 |
148 | def __init__(self, root=cub_path, *args, **kwargs):
149 | super().__init__(root=root, *args, **kwargs)
150 |
151 |
152 | class PlacesDataset(ImageFolder):
153 | name = "places"
154 |
155 | def __init__(self, root=places_path, *args, **kwargs):
156 | super().__init__(root=root, *args, **kwargs)
157 |
158 |
159 | class PlantaeDataset(ImageFolder):
160 | name = "plantae"
161 |
162 | def __init__(self, root=plantae_path, *args, **kwargs):
163 | super().__init__(root=root, *args, **kwargs)
164 |
165 |
166 | dataset_classes = [
167 | MiniImageNetDataset,
168 | MiniImageNetTestDataset,
169 | TieredImageNetDataset,
170 | TieredImageNetTestDataset,
171 | ImageNetDataset,
172 | CropDiseaseDataset,
173 | EuroSATDataset,
174 | ISICDataset,
175 | ChestXDataset,
176 | CarsDataset,
177 | CUBDataset,
178 | PlacesDataset,
179 | PlantaeDataset,
180 | ]
181 |
182 | dataset_class_map = {
183 | cls.name: cls for cls in dataset_classes
184 | }
185 |
--------------------------------------------------------------------------------
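
Note: a minimal lookup sketch (illustrative, not a file in the repo). dataset_class_map is how the dataloader layer resolves dataset_name strings into the ImageFolder subclasses above; the example assumes the corresponding *_path entry in configs.py points at downloaded data.

    from datasets.datasets import dataset_class_map

    dataset_cls = dataset_class_map["EuroSAT"]  # -> EuroSATDataset
    dataset = dataset_cls()                     # root defaults to EuroSAT_path
    print(len(dataset.classes), "classes /", len(dataset.samples), "images")
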
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import numpy as np
4 | from torch.utils.data import Sampler
5 | from torchvision.datasets import ImageFolder
6 |
7 |
8 | class EpisodeSampler:
9 | """
10 |     Stable sampler for support and query indices. Used by the episodic batch sampler so that the support and query
11 |     sets can be sampled from independent data loaders with the same per-episode splits, i.e., support and query do not overlap.
12 | """
13 |
14 | def __init__(self, dataset: ImageFolder, n_way: int, n_shot: int, n_query_shot: int, n_episodes: int,
15 | seed: int = 0):
16 | self.dataset = dataset
17 | self.n_classes = len(dataset.classes)
18 | self.w = n_way
19 | self.s = n_shot
20 | self.q = n_query_shot
21 | self.n_episodes = n_episodes
22 | self.seed = seed
23 |
24 | rs = np.random.RandomState(seed)
25 | self.episode_seeds = []
26 | for i in range(n_episodes):
27 | self.episode_seeds.append(rs.randint(2 ** 32 - 1))
28 |
29 | self.indices_by_class = defaultdict(list)
30 | for index, (path, label) in enumerate(dataset.samples):
31 | self.indices_by_class[label].append(index)
32 |
33 | def __getitem__(self, index):
34 | """
35 | :param index:
36 |         :return: support: ndarray[w, s], query: ndarray[w, q]
37 | """
38 | rs = np.random.RandomState(self.episode_seeds[index])
39 | selected_classes = rs.permutation(self.n_classes)[:self.w]
40 | indices = []
41 | for cls in selected_classes:
42 | candidates: int = len(self.indices_by_class[cls])
43 | choices: int = self.s + self.q
44 |
45 | if candidates > choices:
46 | indices.append(
47 | rs.choice(self.indices_by_class[cls], choices, replace=False))
48 | else:
49 | indices.append(
50 | rs.choice(self.indices_by_class[cls], choices, replace=True))
51 |
52 | episode = np.stack(indices)
53 | support = episode[:, :self.s]
54 | query = episode[:, self.s:]
55 | return support, query
56 |
57 | def __len__(self):
58 | return self.n_episodes
59 |
60 |
61 | class EpisodicBatchSampler(Sampler):
62 | """
63 |     For each episode, the same batch of indices is yielded n_epochs times in a row. For batch-training within
64 |     episodes, you need to divide the sampled data (from the dataloader) into smaller batches yourself.
65 |
66 |     For classification-based training, note that you need to remap the class labels to [0, 0, ..., 1, ..., w-1].
67 |     This is also why inter-episode batches are not supported by the sampler: remapping labels across episodes is harder.
68 | """
69 |
70 | def __init__(self, dataset: ImageFolder, n_way: int, n_shot: int, n_query_shot: int, n_episodes: int, support: bool,
71 | n_epochs=1, seed=0):
72 | super().__init__(dataset)
73 | self.dataset = dataset
74 |
75 | self.w = n_way
76 | self.s = n_shot
77 | self.q = n_query_shot
78 | self.episode_sampler = EpisodeSampler(dataset, n_way, n_shot, n_query_shot, n_episodes, seed)
79 |
80 | self.n_episodes = n_episodes
81 | self.n_epochs = n_epochs
82 | self.support = support
83 |
84 | def __len__(self):
85 | return self.n_episodes * self.n_epochs
86 |
87 | def __iter__(self):
88 | for i in range(self.n_episodes):
89 | support, query = self.episode_sampler[i]
90 | indices = support if self.support else query
91 | indices = indices.flatten()
92 | for j in range(self.n_epochs):
93 | yield indices
94 |
--------------------------------------------------------------------------------
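
Note: a sanity-check sketch (illustrative, not a file in the repo). Two EpisodicBatchSampler instances sharing a seed but with opposite support flags walk through identical episodes with disjoint index sets, which is what lets finetune.py drive separate support and query loaders. Disjointness holds whenever every sampled class has at least n_shot + n_query_shot images; otherwise EpisodeSampler falls back to sampling with replacement.

    from datasets.sampler import EpisodicBatchSampler

    # dataset: any ImageFolder-style dataset from datasets/datasets.py
    s_sampler = EpisodicBatchSampler(dataset, n_way=5, n_shot=5, n_query_shot=15,
                                     n_episodes=3, support=True, seed=0)
    q_sampler = EpisodicBatchSampler(dataset, n_way=5, n_shot=5, n_query_shot=15,
                                     n_episodes=3, support=False, seed=0)

    for s_idx, q_idx in zip(s_sampler, q_sampler):
        assert set(s_idx).isdisjoint(q_idx)  # same episode, different images
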
/datasets/split.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | from typing import List, Tuple
4 |
5 | import pandas as pd
6 | from numpy.random import RandomState
7 | from torchvision.datasets import ImageFolder
8 |
9 | DIRNAME = os.path.dirname(os.path.abspath(__file__))
10 |
11 | DATASETS_WITH_DEFAULT_SPLITS = [
12 | "miniImageNet",
13 | "miniImageNet_test",
14 | "tieredImageNet",
15 | "tieredImageNet_test",
16 | "CropDisease",
17 | "EuroSAT",
18 | "ISIC",
19 | "ChestX",
20 | ]
21 |
22 |
23 | def split_dataset(dataset: ImageFolder, ratio=20, seed=1):
24 | """
25 | :param dataset:
26 | :param ratio: Ratio of unlabeled portion
27 | :param seed:
28 | :return: unlabeled_dataset, labeled_dataset
29 | """
30 | assert (0 <= ratio <= 100)
31 |
32 | # Check default splits
33 | unlabeled_path = _get_split_path(dataset, ratio, seed, True)
34 | labeled_path = _get_split_path(dataset, ratio, seed, False)
35 | for path in [unlabeled_path, labeled_path]:
36 | if ratio == 20 and seed == 1 and dataset.name in DATASETS_WITH_DEFAULT_SPLITS and not os.path.exists(path):
37 | raise Exception("Default split file missing: {}".format(path))
38 |
39 | if os.path.exists(unlabeled_path) and os.path.exists(labeled_path):
40 | print("Loading unlabeled split from {}".format(unlabeled_path))
41 | print("Loading labeled split from {}".format(labeled_path))
42 | unlabeled = _load_split(unlabeled_path)
43 | labeled = _load_split(labeled_path)
44 | else:
45 | unlabeled, labeled = _get_split(dataset, ratio, seed)
46 |         print("Generating and saving unlabeled split to {}".format(unlabeled_path))
47 |         print("Generating and saving labeled split to {}".format(labeled_path))
48 | _save_split(unlabeled, unlabeled_path)
49 | _save_split(labeled, labeled_path)
50 |
51 | ud = copy.deepcopy(dataset)
52 | ld = copy.deepcopy(dataset)
53 |
54 | _apply_split(ud, unlabeled)
55 | _apply_split(ld, labeled)
56 |
57 | return ud, ld
58 |
59 |
60 | def _get_split(dataset: ImageFolder, ratio: int, seed: int) -> Tuple[List[str], List[str]]:
61 | img_paths = []
62 | for path, label in dataset.samples:
63 | root_with_slash = os.path.join(dataset.root, "")
64 | img_paths.append(path.replace(root_with_slash, ""))
65 | img_paths.sort()
66 | # Assert uniqueness
67 | assert (len(img_paths) == len(set(img_paths)))
68 |
69 | rs = RandomState(seed)
70 | unlabeled_count = len(img_paths) * ratio // 100
71 | unlabeled_paths = set(rs.choice(img_paths, unlabeled_count, replace=False))
72 | labeled_paths = set(img_paths) - unlabeled_paths
73 |
74 | return sorted(list(unlabeled_paths)), sorted(list(labeled_paths))
75 |
76 |
77 | def _save_split(split: List, path):
78 | df = pd.DataFrame({
79 | "img_path": split
80 | })
81 | df.to_csv(path)
82 |
83 |
84 | def _load_split(path) -> List[str]:
85 | df = pd.read_csv(path)
86 |     return df["img_path"].tolist()
87 |
88 |
89 | def _get_split_path(dataset: ImageFolder, ratio: int, seed=1, unlabeled=True, makedirs=True):
90 | if unlabeled:
91 | basename = '{}_unlabeled_{}.csv'.format(dataset.name, ratio)
92 | else:
93 | basename = '{}_labeled_{}.csv'.format(dataset.name, 100 - ratio)
94 | path = os.path.join(DIRNAME, 'split_seed_{}'.format(seed), basename)
95 | if makedirs:
96 | os.makedirs(os.path.dirname(path), exist_ok=True)
97 | return path
98 |
99 |
100 | def _apply_split(dataset: ImageFolder, split: List[str]):
101 | img_paths = []
102 | for path, label in dataset.samples:
103 | root_with_slash = os.path.join(dataset.root, "")
104 | img_paths.append(path.replace(root_with_slash, ""))
105 |
106 | split_set = set(split)
107 | samples = []
108 | for path, sample in zip(img_paths, dataset.samples):
109 |         if dataset.name == 'ISIC' and len(split) > 0 and '.jpg' not in split[0]:  # HOTFIX: paths in ISIC's default split file lack ".jpg"
110 | path = path.replace('.jpg', '')
111 | if path in split_set:
112 | samples.append(sample)
113 |
114 | dataset.samples = samples
115 | dataset.imgs = samples
116 | dataset.targets = [s[1] for s in samples]
117 |
--------------------------------------------------------------------------------
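
Note: a usage sketch (illustrative, not a file in the repo). split_dataset is deterministic and self-caching: the first call for a given (dataset, ratio, seed) writes the split CSVs under split_seed_<seed>/ and later calls reload them, so labeled/unlabeled membership never drifts between runs. The assertions assume the default split CSVs exist or can be generated for the dataset at hand.

    from datasets.split import split_dataset

    unlabeled, labeled = split_dataset(dataset, ratio=20, seed=1)
    assert len(unlabeled) + len(labeled) == len(dataset)

    # A repeat call reads the cached CSVs and reproduces the exact same split.
    unlabeled2, _ = split_dataset(dataset, ratio=20, seed=1)
    assert unlabeled.samples == unlabeled2.samples
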
/datasets/split_seed_1/ISIC_unlabeled_20.csv:
--------------------------------------------------------------------------------
1 | img_path
2 | ISIC_0028058
3 | ISIC_0029999
4 | ISIC_0029695
5 | ISIC_0033303
6 | ISIC_0025777
7 | ISIC_0025563
8 | ISIC_0028185
9 | ISIC_0024758
10 | ISIC_0032674
11 | ISIC_0030458
12 | ISIC_0030532
13 | ISIC_0030370
14 | ISIC_0025134
15 | ISIC_0033986
16 | ISIC_0027242
17 | ISIC_0033035
18 | ISIC_0034007
19 | ISIC_0025028
20 | ISIC_0034198
21 | ISIC_0031865
22 | ISIC_0032023
23 | ISIC_0033227
24 | ISIC_0031642
25 | ISIC_0027194
26 | ISIC_0025901
27 | ISIC_0031278
28 | ISIC_0032601
29 | ISIC_0030270
30 | ISIC_0029166
31 | ISIC_0033398
32 | ISIC_0033490
33 | ISIC_0030927
34 | ISIC_0029994
35 | ISIC_0028277
36 | ISIC_0027402
37 | ISIC_0028054
38 | ISIC_0028736
39 | ISIC_0029823
40 | ISIC_0031883
41 | ISIC_0029056
42 | ISIC_0032406
43 | ISIC_0028685
44 | ISIC_0028383
45 | ISIC_0024656
46 | ISIC_0029851
47 | ISIC_0027628
48 | ISIC_0026565
49 | ISIC_0032105
50 | ISIC_0026332
51 | ISIC_0034214
52 | ISIC_0027658
53 | ISIC_0032894
54 | ISIC_0033933
55 | ISIC_0025751
56 | ISIC_0029383
57 | ISIC_0034151
58 | ISIC_0029555
59 | ISIC_0032044
60 | ISIC_0032496
61 | ISIC_0029076
62 | ISIC_0028870
63 | ISIC_0029719
64 | ISIC_0025917
65 | ISIC_0024735
66 | ISIC_0034079
67 | ISIC_0028256
68 | ISIC_0029957
69 | ISIC_0029531
70 | ISIC_0032029
71 | ISIC_0025516
72 | ISIC_0025222
73 | ISIC_0028568
74 | ISIC_0031461
75 | ISIC_0024546
76 | ISIC_0030856
77 | ISIC_0030926
78 | ISIC_0026919
79 | ISIC_0026800
80 | ISIC_0025306
81 | ISIC_0029580
82 | ISIC_0024362
83 | ISIC_0027326
84 | ISIC_0034305
85 | ISIC_0024514
86 | ISIC_0026417
87 | ISIC_0028969
88 | ISIC_0032987
89 | ISIC_0029709
90 | ISIC_0025746
91 | ISIC_0024731
92 | ISIC_0032367
93 | ISIC_0025003
94 | ISIC_0031006
95 | ISIC_0025571
96 | ISIC_0025929
97 | ISIC_0029139
98 | ISIC_0031981
99 | ISIC_0025899
100 | ISIC_0031511
101 | ISIC_0026607
102 | ISIC_0028894
103 | ISIC_0029363
104 | ISIC_0030262
105 | ISIC_0032579
106 | ISIC_0026506
107 | ISIC_0031437
108 | ISIC_0033205
109 | ISIC_0025284
110 | ISIC_0027076
111 | ISIC_0028631
112 | ISIC_0025125
113 | ISIC_0033050
114 | ISIC_0030228
115 | ISIC_0031949
116 | ISIC_0033924
117 | ISIC_0033123
118 | ISIC_0025773
119 | ISIC_0033536
120 | ISIC_0024927
121 | ISIC_0032942
122 | ISIC_0029769
123 | ISIC_0032124
124 | ISIC_0028312
125 | ISIC_0027354
126 | ISIC_0026084
127 | ISIC_0030798
128 | ISIC_0028372
129 | ISIC_0024642
130 | ISIC_0026948
131 | ISIC_0031567
132 | ISIC_0025389
133 | ISIC_0031802
134 | ISIC_0026913
135 | ISIC_0031267
136 | ISIC_0027875
137 | ISIC_0028330
138 | ISIC_0032849
139 | ISIC_0024696
140 | ISIC_0025043
141 | ISIC_0024419
142 | ISIC_0025374
143 | ISIC_0028229
144 | ISIC_0025243
145 | ISIC_0026605
146 | ISIC_0030275
147 | ISIC_0033915
148 | ISIC_0032097
149 | ISIC_0029417
150 | ISIC_0031285
151 | ISIC_0031693
152 | ISIC_0024563
153 | ISIC_0026544
154 | ISIC_0025645
155 | ISIC_0025458
156 | ISIC_0033732
157 | ISIC_0030666
158 | ISIC_0025298
159 | ISIC_0032511
160 | ISIC_0026152
161 | ISIC_0026864
162 | ISIC_0029411
163 | ISIC_0026278
164 | ISIC_0028916
165 | ISIC_0028928
166 | ISIC_0027970
167 | ISIC_0031966
168 | ISIC_0027709
169 | ISIC_0030830
170 | ISIC_0024438
171 | ISIC_0026768
172 | ISIC_0028437
173 | ISIC_0024747
174 | ISIC_0028466
175 | ISIC_0026335
176 | ISIC_0025974
177 | ISIC_0030171
178 | ISIC_0029684
179 | ISIC_0031986
180 | ISIC_0026481
181 | ISIC_0029040
182 | ISIC_0029687
183 | ISIC_0029626
184 | ISIC_0031891
185 | ISIC_0029620
186 | ISIC_0027031
187 | ISIC_0034137
188 | ISIC_0033096
189 | ISIC_0029169
190 | ISIC_0033586
191 | ISIC_0025600
192 | ISIC_0032813
193 | ISIC_0025965
194 | ISIC_0032954
195 | ISIC_0031361
196 | ISIC_0027736
197 | ISIC_0027852
198 | ISIC_0024333
199 | ISIC_0030614
200 | ISIC_0024797
201 | ISIC_0032253
202 | ISIC_0034257
203 | ISIC_0028349
204 | ISIC_0033004
205 | ISIC_0026431
206 | ISIC_0026772
207 | ISIC_0033093
208 | ISIC_0030149
209 | ISIC_0030356
210 | ISIC_0028763
211 | ISIC_0029336
212 | ISIC_0031927
213 | ISIC_0024645
214 | ISIC_0028507
215 | ISIC_0032431
216 | ISIC_0027957
217 | ISIC_0024430
218 | ISIC_0030950
219 | ISIC_0027331
220 | ISIC_0026692
221 | ISIC_0026271
222 | ISIC_0027266
223 | ISIC_0031563
224 | ISIC_0025275
225 | ISIC_0028614
226 | ISIC_0032879
227 | ISIC_0025894
228 | ISIC_0025241
229 | ISIC_0030869
230 | ISIC_0024829
231 | ISIC_0025659
232 | ISIC_0030699
233 | ISIC_0027444
234 | ISIC_0025494
235 | ISIC_0032327
236 | ISIC_0032018
237 | ISIC_0031051
238 | ISIC_0033683
239 | ISIC_0025183
240 | ISIC_0027600
241 | ISIC_0032450
242 | ISIC_0033221
243 | ISIC_0029138
244 | ISIC_0030248
245 | ISIC_0027735
246 | ISIC_0031185
247 | ISIC_0032151
248 | ISIC_0030605
249 | ISIC_0029062
250 | ISIC_0034243
251 | ISIC_0031203
252 | ISIC_0032959
253 | ISIC_0028690
254 | ISIC_0027848
255 | ISIC_0030178
256 | ISIC_0032928
257 | ISIC_0025088
258 | ISIC_0026469
259 | ISIC_0030025
260 | ISIC_0025066
261 | ISIC_0024820
262 | ISIC_0025911
263 | ISIC_0028585
264 | ISIC_0031319
265 | ISIC_0027197
266 | ISIC_0028236
267 | ISIC_0026458
268 | ISIC_0024499
269 | ISIC_0025048
270 | ISIC_0032609
271 | ISIC_0026176
272 | ISIC_0026517
273 | ISIC_0032113
274 | ISIC_0027944
275 | ISIC_0029344
276 | ISIC_0030502
277 | ISIC_0027506
278 | ISIC_0033183
279 | ISIC_0030325
280 | ISIC_0034106
281 | ISIC_0033241
282 | ISIC_0024715
283 | ISIC_0024884
284 | ISIC_0024998
285 | ISIC_0030580
286 | ISIC_0032715
287 | ISIC_0029233
288 | ISIC_0030231
289 | ISIC_0032739
290 | ISIC_0033069
291 | ISIC_0025463
292 | ISIC_0031787
293 | ISIC_0033796
294 | ISIC_0027439
295 | ISIC_0024896
296 | ISIC_0025881
297 | ISIC_0028880
298 | ISIC_0027435
299 | ISIC_0033633
300 | ISIC_0030556
301 | ISIC_0031058
302 | ISIC_0026702
303 | ISIC_0033264
304 | ISIC_0027430
305 | ISIC_0029346
306 | ISIC_0025049
307 | ISIC_0026254
308 | ISIC_0024488
309 | ISIC_0027853
310 | ISIC_0028644
311 | ISIC_0027022
312 | ISIC_0029796
313 | ISIC_0025359
314 | ISIC_0032663
315 | ISIC_0024959
316 | ISIC_0024454
317 | ISIC_0025014
318 | ISIC_0024675
319 | ISIC_0024329
320 | ISIC_0025591
321 | ISIC_0026263
322 | ISIC_0032331
323 | ISIC_0026566
324 | ISIC_0029934
325 | ISIC_0025505
326 | ISIC_0028814
327 | ISIC_0026589
328 | ISIC_0025498
329 | ISIC_0032604
330 | ISIC_0033407
331 | ISIC_0032238
332 | ISIC_0029911
333 | ISIC_0032546
334 | ISIC_0032068
335 | ISIC_0029114
336 | ISIC_0026999
337 | ISIC_0028557
338 | ISIC_0032988
339 | ISIC_0027104
340 | ISIC_0033471
341 | ISIC_0026817
342 | ISIC_0032847
343 | ISIC_0026214
344 | ISIC_0032136
345 | ISIC_0032374
346 | ISIC_0026012
347 | ISIC_0025770
348 | ISIC_0028106
349 | ISIC_0027339
350 | ISIC_0027077
351 | ISIC_0030724
352 | ISIC_0034120
353 | ISIC_0024336
354 | ISIC_0026612
355 | ISIC_0031608
356 | ISIC_0025532
357 | ISIC_0026635
358 | ISIC_0027719
359 | ISIC_0033424
360 | ISIC_0024479
361 | ISIC_0033901
362 | ISIC_0027547
363 | ISIC_0029261
364 | ISIC_0030374
365 | ISIC_0028048
366 | ISIC_0030226
367 | ISIC_0032699
368 | ISIC_0032372
369 | ISIC_0025113
370 | ISIC_0025775
371 | ISIC_0028208
372 | ISIC_0031533
373 | ISIC_0024799
374 | ISIC_0031053
375 | ISIC_0025307
376 | ISIC_0026623
377 | ISIC_0033061
378 | ISIC_0033299
379 | ISIC_0025278
380 | ISIC_0029710
381 | ISIC_0026686
382 | ISIC_0032110
383 | ISIC_0030891
384 | ISIC_0030097
385 | ISIC_0029013
386 | ISIC_0031167
387 | ISIC_0034165
388 | ISIC_0032219
389 | ISIC_0028416
390 | ISIC_0031044
391 | ISIC_0025969
392 | ISIC_0031640
393 | ISIC_0027651
394 | ISIC_0024504
395 | ISIC_0025364
396 | ISIC_0024946
397 | ISIC_0031615
398 | ISIC_0024625
399 | ISIC_0031737
400 | ISIC_0026968
401 | ISIC_0033661
402 | ISIC_0033283
403 | ISIC_0026226
404 | ISIC_0026258
405 | ISIC_0024677
406 | ISIC_0024891
407 | ISIC_0027136
408 | ISIC_0032035
409 | ISIC_0034024
410 | ISIC_0034153
411 | ISIC_0024382
412 | ISIC_0030640
413 | ISIC_0031168
414 | ISIC_0029399
415 | ISIC_0026674
416 | ISIC_0031950
417 | ISIC_0030451
418 | ISIC_0029041
419 | ISIC_0024681
420 | ISIC_0024386
421 | ISIC_0031956
422 | ISIC_0032337
423 | ISIC_0026476
424 | ISIC_0034164
425 | ISIC_0027092
426 | ISIC_0031695
427 | ISIC_0027204
428 | ISIC_0031880
429 | ISIC_0031092
430 | ISIC_0029786
431 | ISIC_0024495
432 | ISIC_0029917
433 | ISIC_0027198
434 | ISIC_0034236
435 | ISIC_0033106
436 | ISIC_0027476
437 | ISIC_0026365
438 | ISIC_0028755
439 | ISIC_0031483
440 | ISIC_0025254
441 | ISIC_0030030
442 | ISIC_0032500
443 | ISIC_0034105
444 | ISIC_0030540
445 | ISIC_0033670
446 | ISIC_0032313
447 | ISIC_0027412
448 | ISIC_0029872
449 | ISIC_0031851
450 | ISIC_0026416
451 | ISIC_0032738
452 | ISIC_0025823
453 | ISIC_0031736
454 | ISIC_0028008
455 | ISIC_0030285
456 | ISIC_0033860
457 | ISIC_0025179
458 | ISIC_0025511
459 | ISIC_0031321
460 | ISIC_0025115
461 | ISIC_0029421
462 | ISIC_0026122
463 | ISIC_0031507
464 | ISIC_0033662
465 | ISIC_0029101
466 | ISIC_0028271
467 | ISIC_0029885
468 | ISIC_0027376
469 | ISIC_0031544
470 | ISIC_0033225
471 | ISIC_0031025
472 | ISIC_0033814
473 | ISIC_0027456
474 | ISIC_0030019
475 | ISIC_0027132
476 | ISIC_0028158
477 | ISIC_0026920
478 | ISIC_0033869
479 | ISIC_0031269
480 | ISIC_0027923
481 | ISIC_0029762
482 | ISIC_0032638
483 | ISIC_0031911
484 | ISIC_0030887
485 | ISIC_0026802
486 | ISIC_0031635
487 | ISIC_0030839
488 | ISIC_0031681
489 | ISIC_0028343
490 | ISIC_0029222
491 | ISIC_0025382
492 | ISIC_0033576
493 | ISIC_0025375
494 | ISIC_0033671
495 | ISIC_0025442
496 | ISIC_0025305
497 | ISIC_0027021
498 | ISIC_0033079
499 | ISIC_0031496
500 | ISIC_0030982
501 | ISIC_0033894
502 | ISIC_0032522
503 | ISIC_0024567
504 | ISIC_0028216
505 | ISIC_0027603
506 | ISIC_0030955
507 | ISIC_0026001
508 | ISIC_0032506
509 | ISIC_0030813
510 | ISIC_0026019
511 | ISIC_0027367
512 | ISIC_0024706
513 | ISIC_0027556
514 | ISIC_0029583
515 | ISIC_0029933
516 | ISIC_0028440
517 | ISIC_0030694
518 | ISIC_0025170
519 | ISIC_0026130
520 | ISIC_0030795
521 | ISIC_0033103
522 | ISIC_0029836
523 | ISIC_0026739
524 | ISIC_0032488
525 | ISIC_0028609
526 | ISIC_0027634
527 | ISIC_0027668
528 | ISIC_0029471
529 | ISIC_0025196
530 | ISIC_0030056
531 | ISIC_0031864
532 | ISIC_0032791
533 | ISIC_0028361
534 | ISIC_0029296
535 | ISIC_0026603
536 | ISIC_0029385
537 | ISIC_0028347
538 | ISIC_0025418
539 | ISIC_0024937
540 | ISIC_0031315
541 | ISIC_0032710
542 | ISIC_0028360
543 | ISIC_0026144
544 | ISIC_0030276
545 | ISIC_0031406
546 | ISIC_0029689
547 | ISIC_0031392
548 | ISIC_0032696
549 | ISIC_0033827
550 | ISIC_0030663
551 | ISIC_0029676
552 | ISIC_0028087
553 | ISIC_0031068
554 | ISIC_0034239
555 | ISIC_0029844
556 | ISIC_0031971
557 | ISIC_0027488
558 | ISIC_0028925
559 | ISIC_0030679
560 | ISIC_0033596
561 | ISIC_0026240
562 | ISIC_0027841
563 | ISIC_0032101
564 | ISIC_0032472
565 | ISIC_0027447
566 | ISIC_0030201
567 | ISIC_0026937
568 | ISIC_0024601
569 | ISIC_0034033
570 | ISIC_0027182
571 | ISIC_0033260
572 | ISIC_0033085
573 | ISIC_0028654
574 | ISIC_0030382
575 | ISIC_0031699
576 | ISIC_0026752
577 | ISIC_0031901
578 | ISIC_0026745
579 | ISIC_0032288
580 | ISIC_0031341
581 | ISIC_0033057
582 | ISIC_0024962
583 | ISIC_0026391
584 | ISIC_0026815
585 | ISIC_0030709
586 | ISIC_0032230
587 | ISIC_0033137
588 | ISIC_0033491
589 | ISIC_0030505
590 | ISIC_0028307
591 | ISIC_0030289
592 | ISIC_0026380
593 | ISIC_0024583
594 | ISIC_0028560
595 | ISIC_0033966
596 | ISIC_0031541
597 | ISIC_0024600
598 | ISIC_0031154
599 | ISIC_0030608
600 | ISIC_0033235
601 | ISIC_0031041
602 | ISIC_0027823
603 | ISIC_0031295
604 | ISIC_0027630
605 | ISIC_0027851
606 | ISIC_0031245
607 | ISIC_0025675
608 | ISIC_0028915
609 | ISIC_0028469
610 | ISIC_0033315
611 | ISIC_0025620
612 | ISIC_0028391
613 | ISIC_0026032
614 | ISIC_0026663
615 | ISIC_0027701
616 | ISIC_0032047
617 | ISIC_0031712
618 | ISIC_0030192
619 | ISIC_0030357
620 | ISIC_0033278
621 | ISIC_0030575
622 | ISIC_0033766
623 | ISIC_0025703
624 | ISIC_0025550
625 | ISIC_0031926
626 | ISIC_0025139
627 | ISIC_0024455
628 | ISIC_0027953
629 | ISIC_0029574
630 | ISIC_0024490
631 | ISIC_0025226
632 | ISIC_0025793
633 | ISIC_0027232
634 | ISIC_0029916
635 | ISIC_0030274
636 | ISIC_0027940
637 | ISIC_0030048
638 | ISIC_0026341
639 | ISIC_0030083
640 | ISIC_0024570
641 | ISIC_0027605
642 | ISIC_0029725
643 | ISIC_0026810
644 | ISIC_0028998
645 | ISIC_0031872
646 | ISIC_0027336
647 | ISIC_0028811
648 | ISIC_0032707
649 | ISIC_0031098
650 | ISIC_0026556
651 | ISIC_0026701
652 | ISIC_0025684
653 | ISIC_0025806
654 | ISIC_0026597
655 | ISIC_0033269
656 | ISIC_0031380
657 | ISIC_0027760
658 | ISIC_0032067
659 | ISIC_0033678
660 | ISIC_0027743
661 | ISIC_0028272
662 | ISIC_0031509
663 | ISIC_0027560
664 | ISIC_0031734
665 | ISIC_0034194
666 | ISIC_0033927
667 | ISIC_0024357
668 | ISIC_0030755
669 | ISIC_0032098
670 | ISIC_0031939
671 | ISIC_0032232
672 | ISIC_0033367
673 | ISIC_0031176
674 | ISIC_0026275
675 | ISIC_0027019
676 | ISIC_0032154
677 | ISIC_0029884
678 | ISIC_0025820
679 | ISIC_0031937
680 | ISIC_0033686
681 | ISIC_0032610
682 | ISIC_0027639
683 | ISIC_0024474
684 | ISIC_0026628
685 | ISIC_0028977
686 | ISIC_0024730
687 | ISIC_0026908
688 | ISIC_0027800
689 | ISIC_0028562
690 | ISIC_0025660
691 | ISIC_0028064
692 | ISIC_0031543
693 | ISIC_0027843
694 | ISIC_0032213
695 | ISIC_0027462
696 | ISIC_0027632
697 | ISIC_0028306
698 | ISIC_0034276
699 | ISIC_0025159
700 | ISIC_0032498
701 | ISIC_0033863
702 | ISIC_0031485
703 | ISIC_0026564
704 | ISIC_0024359
705 | ISIC_0031229
706 | ISIC_0033798
707 | ISIC_0026483
708 | ISIC_0030596
709 | ISIC_0030948
710 | ISIC_0033829
711 | ISIC_0028301
712 | ISIC_0029021
713 | ISIC_0028006
714 | ISIC_0029397
715 | ISIC_0033452
716 | ISIC_0030003
717 | ISIC_0027147
718 | ISIC_0032104
719 | ISIC_0025518
720 | ISIC_0032070
721 | ISIC_0033224
722 | ISIC_0033730
723 | ISIC_0028154
724 | ISIC_0028300
725 | ISIC_0034051
726 | ISIC_0033331
727 | ISIC_0027897
728 | ISIC_0032355
729 | ISIC_0031230
730 | ISIC_0024821
731 | ISIC_0032569
732 | ISIC_0027015
733 | ISIC_0024662
734 | ISIC_0031268
735 | ISIC_0032424
736 | ISIC_0029807
737 | ISIC_0033027
738 | ISIC_0027626
739 | ISIC_0024365
740 | ISIC_0029745
741 | ISIC_0024587
742 | ISIC_0033952
743 | ISIC_0027654
744 | ISIC_0025935
745 | ISIC_0027088
746 | ISIC_0033109
747 | ISIC_0029454
748 | ISIC_0032949
749 | ISIC_0027881
750 | ISIC_0031631
751 | ISIC_0034310
752 | ISIC_0025582
753 | ISIC_0028353
754 | ISIC_0025836
755 | ISIC_0031297
756 | ISIC_0025111
757 | ISIC_0025835
758 | ISIC_0031160
759 | ISIC_0033232
760 | ISIC_0032547
761 | ISIC_0029972
762 | ISIC_0026997
763 | ISIC_0029960
764 | ISIC_0027360
765 | ISIC_0025673
766 | ISIC_0029712
767 | ISIC_0032922
768 | ISIC_0026770
769 | ISIC_0024306
770 | ISIC_0033689
771 | ISIC_0028874
772 | ISIC_0030350
773 | ISIC_0029378
774 | ISIC_0024482
775 | ISIC_0025328
776 | ISIC_0033130
777 | ISIC_0033010
778 | ISIC_0025592
779 | ISIC_0031369
780 | ISIC_0026715
781 | ISIC_0030214
782 | ISIC_0027314
783 | ISIC_0030257
784 | ISIC_0030849
785 | ISIC_0026690
786 | ISIC_0025654
787 | ISIC_0029722
788 | ISIC_0028581
789 | ISIC_0030980
790 | ISIC_0030746
791 | ISIC_0024838
792 | ISIC_0033724
793 | ISIC_0024686
794 | ISIC_0033825
795 | ISIC_0027652
796 | ISIC_0031439
797 | ISIC_0026987
798 | ISIC_0025236
799 | ISIC_0026932
800 | ISIC_0026832
801 | ISIC_0034270
802 | ISIC_0034316
803 | ISIC_0032917
804 | ISIC_0025365
805 | ISIC_0028482
806 | ISIC_0031396
807 | ISIC_0024379
808 | ISIC_0030278
809 | ISIC_0027479
810 | ISIC_0026910
811 | ISIC_0031542
812 | ISIC_0025558
813 | ISIC_0028700
814 | ISIC_0033140
815 | ISIC_0026170
816 | ISIC_0028909
817 | ISIC_0032986
818 | ISIC_0032886
819 | ISIC_0033615
820 | ISIC_0033684
821 | ISIC_0025596
822 | ISIC_0027301
823 | ISIC_0029263
824 | ISIC_0030954
825 | ISIC_0024761
826 | ISIC_0030733
827 | ISIC_0025121
828 | ISIC_0031920
829 | ISIC_0027226
830 | ISIC_0026944
831 | ISIC_0033475
832 | ISIC_0029978
833 | ISIC_0027436
834 | ISIC_0034195
835 | ISIC_0032342
836 | ISIC_0029006
837 | ISIC_0027083
838 | ISIC_0026528
839 | ISIC_0032850
840 | ISIC_0031403
841 | ISIC_0031850
842 | ISIC_0032465
843 | ISIC_0029586
844 | ISIC_0028025
845 | ISIC_0033630
846 | ISIC_0031829
847 | ISIC_0030057
848 | ISIC_0028179
849 | ISIC_0025509
850 | ISIC_0025711
851 | ISIC_0026005
852 | ISIC_0032129
853 | ISIC_0030469
854 | ISIC_0029024
855 | ISIC_0033883
856 | ISIC_0029184
857 | ISIC_0030233
858 | ISIC_0024469
859 | ISIC_0026353
860 | ISIC_0031272
861 | ISIC_0030648
862 | ISIC_0027666
863 | ISIC_0029600
864 | ISIC_0032906
865 | ISIC_0028424
866 | ISIC_0030172
867 | ISIC_0027718
868 | ISIC_0029407
869 | ISIC_0024947
870 | ISIC_0031056
871 | ISIC_0025649
872 | ISIC_0028786
873 | ISIC_0025194
874 | ISIC_0028233
875 | ISIC_0024837
876 | ISIC_0030337
877 | ISIC_0032308
878 | ISIC_0031789
879 | ISIC_0031555
880 | ISIC_0032181
881 | ISIC_0026101
882 | ISIC_0031090
883 | ISIC_0033538
884 | ISIC_0030334
885 | ISIC_0026415
886 | ISIC_0029987
887 | ISIC_0027472
888 | ISIC_0024572
889 | ISIC_0028069
890 | ISIC_0025117
891 | ISIC_0024666
892 | ISIC_0025766
893 | ISIC_0025258
894 | ISIC_0026466
895 | ISIC_0026105
896 | ISIC_0031774
897 | ISIC_0032430
898 | ISIC_0031306
899 | ISIC_0027643
900 | ISIC_0033978
901 | ISIC_0029733
902 | ISIC_0032017
903 | ISIC_0030870
904 | ISIC_0026406
905 | ISIC_0032784
906 | ISIC_0031030
907 | ISIC_0029481
908 | ISIC_0033578
909 | ISIC_0028537
910 | ISIC_0033279
911 | ISIC_0026497
912 | ISIC_0032709
913 | ISIC_0027470
914 | ISIC_0032772
915 | ISIC_0033495
916 | ISIC_0024394
917 | ISIC_0026630
918 | ISIC_0032796
919 | ISIC_0032318
920 | ISIC_0028082
921 | ISIC_0024621
922 | ISIC_0032064
923 | ISIC_0027423
924 | ISIC_0028924
925 | ISIC_0026326
926 | ISIC_0026308
927 | ISIC_0027280
928 | ISIC_0025581
929 | ISIC_0031583
930 | ISIC_0024834
931 | ISIC_0030494
932 | ISIC_0028710
933 | ISIC_0027904
934 | ISIC_0029536
935 | ISIC_0027987
936 | ISIC_0025624
937 | ISIC_0031049
938 | ISIC_0031391
939 | ISIC_0030084
940 | ISIC_0027340
941 | ISIC_0030689
942 | ISIC_0029117
943 | ISIC_0033642
944 | ISIC_0025472
945 | ISIC_0029031
946 | ISIC_0034072
947 | ISIC_0033652
948 | ISIC_0030542
949 | ISIC_0025528
950 | ISIC_0030910
951 | ISIC_0029249
952 | ISIC_0028561
953 | ISIC_0031189
954 | ISIC_0026288
955 | ISIC_0033648
956 | ISIC_0028326
957 | ISIC_0032869
958 | ISIC_0027788
959 | ISIC_0031372
960 | ISIC_0025844
961 | ISIC_0034155
962 | ISIC_0027050
963 | ISIC_0025859
964 | ISIC_0031647
965 | ISIC_0027123
966 | ISIC_0033528
967 | ISIC_0032266
968 | ISIC_0033116
969 | ISIC_0026687
970 | ISIC_0029492
971 | ISIC_0028832
972 | ISIC_0027830
973 | ISIC_0026935
974 | ISIC_0031753
975 | ISIC_0028429
976 | ISIC_0033787
977 | ISIC_0027072
978 | ISIC_0025962
979 | ISIC_0029124
980 | ISIC_0032202
981 | ISIC_0032049
982 | ISIC_0026149
983 | ISIC_0031148
984 | ISIC_0027457
985 | ISIC_0034082
986 | ISIC_0027692
987 | ISIC_0029352
988 | ISIC_0028331
989 | ISIC_0033926
990 | ISIC_0024732
991 | ISIC_0028400
992 | ISIC_0029880
993 | ISIC_0026229
994 | ISIC_0030529
995 | ISIC_0034304
996 | ISIC_0031360
997 | ISIC_0032902
998 | ISIC_0032989
999 | ISIC_0030483
1000 | ISIC_0031204
1001 | ISIC_0026848
1002 | ISIC_0032491
1003 | ISIC_0027821
1004 | ISIC_0031337
1005 | ISIC_0034163
1006 | ISIC_0030113
1007 | ISIC_0029275
1008 | ISIC_0031325
1009 | ISIC_0030693
1010 | ISIC_0030462
1011 | ISIC_0033672
1012 | ISIC_0030854
1013 | ISIC_0031547
1014 | ISIC_0031669
1015 | ISIC_0031782
1016 | ISIC_0026147
1017 | ISIC_0029706
1018 | ISIC_0025515
1019 | ISIC_0024780
1020 | ISIC_0027617
1021 | ISIC_0032713
1022 | ISIC_0033154
1023 | ISIC_0028902
1024 | ISIC_0027541
1025 | ISIC_0031962
1026 | ISIC_0033714
1027 | ISIC_0031754
1028 | ISIC_0029366
1029 | ISIC_0028713
1030 | ISIC_0024437
1031 | ISIC_0029716
1032 | ISIC_0025647
1033 | ISIC_0026490
1034 | ISIC_0033597
1035 | ISIC_0033092
1036 | ISIC_0031034
1037 | ISIC_0033308
1038 | ISIC_0032252
1039 | ISIC_0024553
1040 | ISIC_0033468
1041 | ISIC_0024335
1042 | ISIC_0025425
1043 | ISIC_0030133
1044 | ISIC_0031876
1045 | ISIC_0031713
1046 | ISIC_0031879
1047 | ISIC_0029148
1048 | ISIC_0030477
1049 | ISIC_0030279
1050 | ISIC_0034280
1051 | ISIC_0034167
1052 | ISIC_0028972
1053 | ISIC_0024939
1054 | ISIC_0032544
1055 | ISIC_0032458
1056 | ISIC_0029051
1057 | ISIC_0033606
1058 | ISIC_0025174
1059 | ISIC_0029229
1060 | ISIC_0031471
1061 | ISIC_0031919
1062 | ISIC_0030425
1063 | ISIC_0028435
1064 | ISIC_0027865
1065 | ISIC_0026809
1066 | ISIC_0027873
1067 | ISIC_0028999
1068 | ISIC_0025166
1069 | ISIC_0032091
1070 | ISIC_0028411
1071 | ISIC_0031524
1072 | ISIC_0029191
1073 | ISIC_0026616
1074 | ISIC_0027065
1075 | ISIC_0031281
1076 | ISIC_0025154
1077 | ISIC_0025549
1078 | ISIC_0029019
1079 | ISIC_0029631
1080 | ISIC_0030808
1081 | ISIC_0030485
1082 | ISIC_0025486
1083 | ISIC_0027766
1084 | ISIC_0030799
1085 | ISIC_0026677
1086 | ISIC_0024308
1087 | ISIC_0031480
1088 | ISIC_0028317
1089 | ISIC_0029878
1090 | ISIC_0031265
1091 | ISIC_0025354
1092 | ISIC_0029340
1093 | ISIC_0028797
1094 | ISIC_0034038
1095 | ISIC_0033360
1096 | ISIC_0028169
1097 | ISIC_0029116
1098 | ISIC_0026857
1099 | ISIC_0034264
1100 | ISIC_0026737
1101 | ISIC_0028103
1102 | ISIC_0032490
1103 | ISIC_0024616
1104 | ISIC_0025924
1105 | ISIC_0025398
1106 | ISIC_0026246
1107 | ISIC_0030970
1108 | ISIC_0028131
1109 | ISIC_0030181
1110 | ISIC_0025535
1111 | ISIC_0031126
1112 | ISIC_0029161
1113 | ISIC_0025780
1114 | ISIC_0026039
1115 | ISIC_0026934
1116 | ISIC_0026928
1117 | ISIC_0032103
1118 | ISIC_0030765
1119 | ISIC_0028868
1120 | ISIC_0028655
1121 | ISIC_0033868
1122 | ISIC_0027140
1123 | ISIC_0027315
1124 | ISIC_0025068
1125 | ISIC_0031750
1126 | ISIC_0029416
1127 | ISIC_0028196
1128 | ISIC_0031116
1129 | ISIC_0033709
1130 | ISIC_0033747
1131 | ISIC_0027620
1132 | ISIC_0027747
1133 | ISIC_0033180
1134 | ISIC_0031138
1135 | ISIC_0029265
1136 | ISIC_0025297
1137 | ISIC_0025220
1138 | ISIC_0033504
1139 | ISIC_0032683
1140 | ISIC_0033420
1141 | ISIC_0027960
1142 | ISIC_0034096
1143 | ISIC_0031170
1144 | ISIC_0031724
1145 | ISIC_0026790
1146 | ISIC_0027454
1147 | ISIC_0025046
1148 | ISIC_0024337
1149 | ISIC_0025277
1150 | ISIC_0030145
1151 | ISIC_0028268
1152 | ISIC_0029866
1153 | ISIC_0025422
1154 | ISIC_0030111
1155 | ISIC_0024462
1156 | ISIC_0031913
1157 | ISIC_0026773
1158 | ISIC_0025137
1159 | ISIC_0029760
1160 | ISIC_0033333
1161 | ISIC_0027131
1162 | ISIC_0033000
1163 | ISIC_0033976
1164 | ISIC_0030766
1165 | ISIC_0031000
1166 | ISIC_0033421
1167 | ISIC_0031674
1168 | ISIC_0026350
1169 | ISIC_0026286
1170 | ISIC_0032747
1171 | ISIC_0026989
1172 | ISIC_0025652
1173 | ISIC_0026383
1174 | ISIC_0030993
1175 | ISIC_0027959
1176 | ISIC_0024661
1177 | ISIC_0025050
1178 | ISIC_0031499
1179 | ISIC_0034277
1180 | ISIC_0033342
1181 | ISIC_0025029
1182 | ISIC_0026282
1183 | ISIC_0027261
1184 | ISIC_0028646
1185 | ISIC_0026520
1186 | ISIC_0029700
1187 | ISIC_0034088
1188 | ISIC_0028176
1189 | ISIC_0030148
1190 | ISIC_0025564
1191 | ISIC_0032583
1192 | ISIC_0025052
1193 | ISIC_0029582
1194 | ISIC_0026173
1195 | ISIC_0030492
1196 | ISIC_0033460
1197 | ISIC_0024571
1198 | ISIC_0034084
1199 | ISIC_0026734
1200 | ISIC_0027445
1201 | ISIC_0032004
1202 | ISIC_0028839
1203 | ISIC_0027368
1204 | ISIC_0034177
1205 | ISIC_0024519
1206 | ISIC_0032132
1207 | ISIC_0029529
1208 | ISIC_0028253
1209 | ISIC_0026364
1210 | ISIC_0032392
1211 | ISIC_0030757
1212 | ISIC_0028017
1213 | ISIC_0027210
1214 | ISIC_0029228
1215 | ISIC_0028144
1216 | ISIC_0031688
1217 | ISIC_0032612
1218 | ISIC_0027442
1219 | ISIC_0024590
1220 | ISIC_0024401
1221 | ISIC_0026186
1222 | ISIC_0031333
1223 | ISIC_0032156
1224 | ISIC_0032165
1225 | ISIC_0026586
1226 | ISIC_0029842
1227 | ISIC_0026827
1228 | ISIC_0027919
1229 | ISIC_0031639
1230 | ISIC_0033917
1231 | ISIC_0026309
1232 | ISIC_0029652
1233 | ISIC_0028238
1234 | ISIC_0028884
1235 | ISIC_0032541
1236 | ISIC_0030223
1237 | ISIC_0028722
1238 | ISIC_0030659
1239 | ISIC_0027523
1240 | ISIC_0033876
1241 | ISIC_0028639
1242 | ISIC_0025309
1243 | ISIC_0031159
1244 | ISIC_0033677
1245 | ISIC_0030907
1246 | ISIC_0029601
1247 | ISIC_0032798
1248 | ISIC_0026813
1249 | ISIC_0033440
1250 | ISIC_0025490
1251 | ISIC_0025565
1252 | ISIC_0026828
1253 | ISIC_0031505
1254 | ISIC_0033634
1255 | ISIC_0028865
1256 | ISIC_0028407
1257 | ISIC_0033339
1258 | ISIC_0032824
1259 | ISIC_0026742
1260 | ISIC_0033665
1261 | ISIC_0033946
1262 | ISIC_0029294
1263 | ISIC_0031603
1264 | ISIC_0029096
1265 | ISIC_0028789
1266 | ISIC_0027409
1267 | ISIC_0031500
1268 | ISIC_0024909
1269 | ISIC_0028724
1270 | ISIC_0030269
1271 | ISIC_0033363
1272 | ISIC_0028348
1273 | ISIC_0033895
1274 | ISIC_0029867
1275 | ISIC_0030965
1276 | ISIC_0027673
1277 | ISIC_0025586
1278 | ISIC_0032034
1279 | ISIC_0033328
1280 | ISIC_0027455
1281 | ISIC_0028105
1282 | ISIC_0030522
1283 | ISIC_0030707
1284 | ISIC_0030664
1285 | ISIC_0024634
1286 | ISIC_0026010
1287 | ISIC_0026799
1288 | ISIC_0031379
1289 | ISIC_0032977
1290 | ISIC_0027935
1291 | ISIC_0033873
1292 | ISIC_0028316
1293 | ISIC_0034061
1294 | ISIC_0032615
1295 | ISIC_0030373
1296 | ISIC_0033377
1297 | ISIC_0033527
1298 | ISIC_0032964
1299 | ISIC_0030667
1300 | ISIC_0026505
1301 | ISIC_0032316
1302 | ISIC_0030895
1303 | ISIC_0027389
1304 | ISIC_0030371
1305 | ISIC_0029758
1306 | ISIC_0026207
1307 | ISIC_0033274
1308 | ISIC_0028970
1309 | ISIC_0026092
1310 | ISIC_0033785
1311 | ISIC_0033660
1312 | ISIC_0029494
1313 | ISIC_0029659
1314 | ISIC_0030400
1315 | ISIC_0024390
1316 | ISIC_0027793
1317 | ISIC_0027095
1318 | ISIC_0031661
1319 | ISIC_0024817
1320 | ISIC_0034233
1321 | ISIC_0030838
1322 | ISIC_0025037
1323 | ISIC_0029088
1324 | ISIC_0029167
1325 | ISIC_0028776
1326 | ISIC_0033789
1327 | ISIC_0027063
1328 | ISIC_0027165
1329 | ISIC_0033948
1330 | ISIC_0027428
1331 | ISIC_0027490
1332 | ISIC_0026361
1333 | ISIC_0031120
1334 | ISIC_0028640
1335 | ISIC_0024433
1336 | ISIC_0028756
1337 | ISIC_0024470
1338 | ISIC_0034054
1339 | ISIC_0031767
1340 | ISIC_0030721
1341 | ISIC_0027505
1342 | ISIC_0030292
1343 | ISIC_0025625
1344 | ISIC_0029893
1345 | ISIC_0031948
1346 | ISIC_0024778
1347 | ISIC_0026303
1348 | ISIC_0024324
1349 | ISIC_0029464
1350 | ISIC_0026443
1351 | ISIC_0025331
1352 | ISIC_0032595
1353 | ISIC_0030241
1354 | ISIC_0027420
1355 | ISIC_0032050
1356 | ISIC_0030416
1357 | ISIC_0027467
1358 | ISIC_0030246
1359 | ISIC_0034020
1360 | ISIC_0032194
1361 | ISIC_0024755
1362 | ISIC_0033048
1363 | ISIC_0034060
1364 | ISIC_0028678
1365 | ISIC_0030668
1366 | ISIC_0028050
1367 | ISIC_0024484
1368 | ISIC_0028835
1369 | ISIC_0025646
1370 | ISIC_0027854
1371 | ISIC_0025897
1372 | ISIC_0034030
1373 | ISIC_0031859
1374 | ISIC_0033239
1375 | ISIC_0026075
1376 | ISIC_0027902
1377 | ISIC_0024789
1378 | ISIC_0028273
1379 | ISIC_0031300
1380 | ISIC_0034267
1381 | ISIC_0026791
1382 | ISIC_0025086
1383 | ISIC_0028950
1384 | ISIC_0031303
1385 | ISIC_0031535
1386 | ISIC_0030365
1387 | ISIC_0024383
1388 | ISIC_0030995
1389 | ISIC_0031353
1390 | ISIC_0033916
1391 | ISIC_0032054
1392 | ISIC_0033171
1393 | ISIC_0026509
1394 | ISIC_0032956
1395 | ISIC_0030391
1396 | ISIC_0030312
1397 | ISIC_0025058
1398 | ISIC_0032619
1399 | ISIC_0027943
1400 | ISIC_0025005
1401 | ISIC_0032448
1402 | ISIC_0033741
1403 | ISIC_0031248
1404 | ISIC_0031490
1405 | ISIC_0025798
1406 | ISIC_0030634
1407 | ISIC_0026392
1408 | ISIC_0027671
1409 | ISIC_0032529
1410 | ISIC_0031513
1411 | ISIC_0026460
1412 | ISIC_0024375
1413 | ISIC_0025755
1414 | ISIC_0027481
1415 | ISIC_0028183
1416 | ISIC_0028265
1417 | ISIC_0027527
1418 | ISIC_0025837
1419 | ISIC_0028807
1420 | ISIC_0031984
1421 | ISIC_0031514
1422 | ISIC_0024810
1423 | ISIC_0031259
1424 | ISIC_0032837
1425 | ISIC_0027886
1426 | ISIC_0031395
1427 | ISIC_0030635
1428 | ISIC_0025104
1429 | ISIC_0029146
1430 | ISIC_0029380
1431 | ISIC_0034186
1432 | ISIC_0028101
1433 | ISIC_0024367
1434 | ISIC_0026081
1435 | ISIC_0030004
1436 | ISIC_0034255
1437 | ISIC_0026766
1438 | ISIC_0029813
1439 | ISIC_0034056
1440 | ISIC_0030841
1441 | ISIC_0025873
1442 | ISIC_0029048
1443 | ISIC_0028762
1444 | ISIC_0027154
1445 | ISIC_0027292
1446 | ISIC_0028760
1447 | ISIC_0031519
1448 | ISIC_0025874
1449 | ISIC_0033767
1450 | ISIC_0030518
1451 | ISIC_0031792
1452 | ISIC_0029325
1453 | ISIC_0028679
1454 | ISIC_0024687
1455 | ISIC_0030594
1456 | ISIC_0026718
1457 | ISIC_0028218
1458 | ISIC_0029426
1459 | ISIC_0025234
1460 | ISIC_0028468
1461 | ISIC_0033779
1462 | ISIC_0030478
1463 | ISIC_0032100
1464 | ISIC_0024551
1465 | ISIC_0025745
1466 | ISIC_0033520
1467 | ISIC_0032243
1468 | ISIC_0032152
1469 | ISIC_0032909
1470 | ISIC_0033560
1471 | ISIC_0028518
1472 | ISIC_0032767
1473 | ISIC_0029427
1474 | ISIC_0028649
1475 | ISIC_0027867
1476 | ISIC_0033163
1477 | ISIC_0033230
1478 | ISIC_0024420
1479 | ISIC_0026608
1480 | ISIC_0032616
1481 | ISIC_0026486
1482 | ISIC_0032660
1483 | ISIC_0032761
1484 | ISIC_0029715
1485 | ISIC_0025081
1486 | ISIC_0029177
1487 | ISIC_0031666
1488 | ISIC_0028982
1489 | ISIC_0033149
1490 | ISIC_0033920
1491 | ISIC_0033668
1492 | ISIC_0031592
1493 | ISIC_0028066
1494 | ISIC_0034269
1495 | ISIC_0030983
1496 | ISIC_0031312
1497 | ISIC_0028180
1498 | ISIC_0033544
1499 | ISIC_0025674
1500 | ISIC_0033395
1501 | ISIC_0033426
1502 | ISIC_0032512
1503 | ISIC_0033465
1504 | ISIC_0032361
1505 | ISIC_0025412
1506 | ISIC_0027731
1507 | ISIC_0028642
1508 | ISIC_0028274
1509 | ISIC_0025560
1510 | ISIC_0026954
1511 | ISIC_0027229
1512 | ISIC_0029086
1513 | ISIC_0025945
1514 | ISIC_0025439
1515 | ISIC_0027825
1516 | ISIC_0030821
1517 | ISIC_0032031
1518 | ISIC_0024577
1519 | ISIC_0029655
1520 | ISIC_0025853
1521 | ISIC_0028616
1522 | ISIC_0025324
1523 | ISIC_0028746
1524 | ISIC_0033336
1525 | ISIC_0030017
1526 | ISIC_0029197
1527 | ISIC_0032197
1528 | ISIC_0032910
1529 | ISIC_0028744
1530 | ISIC_0032875
1531 | ISIC_0025797
1532 | ISIC_0033236
1533 | ISIC_0027921
1534 | ISIC_0024358
1535 | ISIC_0029711
1536 | ISIC_0029400
1537 | ISIC_0027400
1538 | ISIC_0025695
1539 | ISIC_0032190
1540 | ISIC_0027784
1541 | ISIC_0028497
1542 | ISIC_0029901
1543 | ISIC_0027528
1544 | ISIC_0025938
1545 | ISIC_0030000
1546 | ISIC_0032037
1547 | ISIC_0024773
1548 | ISIC_0034091
1549 | ISIC_0033401
1550 | ISIC_0024377
1551 | ISIC_0033942
1552 | ISIC_0029358
1553 | ISIC_0025513
1554 | ISIC_0025896
1555 | ISIC_0029465
1556 | ISIC_0029930
1557 | ISIC_0024964
1558 | ISIC_0032207
1559 | ISIC_0032561
1560 | ISIC_0030844
1561 | ISIC_0029478
1562 | ISIC_0027342
1563 | ISIC_0032952
1564 | ISIC_0029785
1565 | ISIC_0030784
1566 | ISIC_0024708
1567 | ISIC_0029391
1568 | ISIC_0029847
1569 | ISIC_0025778
1570 | ISIC_0033175
1571 | ISIC_0027662
1572 | ISIC_0033749
1573 | ISIC_0026289
1574 | ISIC_0029224
1575 | ISIC_0029365
1576 | ISIC_0028149
1577 | ISIC_0025339
1578 | ISIC_0024480
1579 | ISIC_0033437
1580 | ISIC_0027408
1581 | ISIC_0032019
1582 | ISIC_0032832
1583 | ISIC_0033371
1584 | ISIC_0027159
1585 | ISIC_0025792
1586 | ISIC_0030457
1587 | ISIC_0028599
1588 | ISIC_0032794
1589 | ISIC_0031866
1590 | ISIC_0033943
1591 | ISIC_0034107
1592 | ISIC_0029932
1593 | ISIC_0028151
1594 | ISIC_0030147
1595 | ISIC_0032876
1596 | ISIC_0027416
1597 | ISIC_0028224
1598 | ISIC_0028444
1599 | ISIC_0026478
1600 | ISIC_0026905
1601 | ISIC_0032319
1602 | ISIC_0033413
1603 | ISIC_0030273
1604 | ISIC_0034103
1605 | ISIC_0030173
1606 | ISIC_0028126
1607 | ISIC_0033454
1608 | ISIC_0026193
1609 | ISIC_0032642
1610 | ISIC_0028023
1611 | ISIC_0025488
1612 | ISIC_0034307
1613 | ISIC_0029310
1614 | ISIC_0028908
1615 | ISIC_0033921
1616 | ISIC_0034295
1617 | ISIC_0026028
1618 | ISIC_0027190
1619 | ISIC_0027300
1620 | ISIC_0029069
1621 | ISIC_0033384
1622 | ISIC_0028362
1623 | ISIC_0029686
1624 | ISIC_0032831
1625 | ISIC_0032651
1626 | ISIC_0029822
1627 | ISIC_0030951
1628 | ISIC_0030242
1629 | ISIC_0032186
1630 | ISIC_0025027
1631 | ISIC_0024679
1632 | ISIC_0032492
1633 | ISIC_0032329
1634 | ISIC_0031040
1635 | ISIC_0024618
1636 | ISIC_0032422
1637 | ISIC_0024348
1638 | ISIC_0031339
1639 | ISIC_0031351
1640 | ISIC_0031502
1641 | ISIC_0032915
1642 | ISIC_0029059
1643 | ISIC_0027097
1644 | ISIC_0033202
1645 | ISIC_0025038
1646 | ISIC_0030900
1647 | ISIC_0025250
1648 | ISIC_0029562
1649 | ISIC_0033516
1650 | ISIC_0032493
1651 | ISIC_0029871
1652 | ISIC_0033405
1653 | ISIC_0025999
1654 | ISIC_0024825
1655 | ISIC_0025120
1656 | ISIC_0025124
1657 | ISIC_0029136
1658 | ISIC_0034015
1659 | ISIC_0031364
1660 | ISIC_0029523
1661 | ISIC_0028040
1662 | ISIC_0030578
1663 | ISIC_0030316
1664 | ISIC_0031139
1665 | ISIC_0032096
1666 | ISIC_0026898
1667 | ISIC_0024685
1668 | ISIC_0026557
1669 | ISIC_0028745
1670 | ISIC_0026454
1671 | ISIC_0028032
1672 | ISIC_0026872
1673 | ISIC_0030344
1674 | ISIC_0026109
1675 | ISIC_0034159
1676 | ISIC_0031451
1677 | ISIC_0027927
1678 | ISIC_0030414
1679 | ISIC_0025900
1680 | ISIC_0031367
1681 | ISIC_0029196
1682 | ISIC_0031916
1683 | ISIC_0033213
1684 | ISIC_0026865
1685 | ISIC_0028514
1686 | ISIC_0029401
1687 | ISIC_0031834
1688 | ISIC_0025888
1689 | ISIC_0026853
1690 | ISIC_0032817
1691 | ISIC_0031384
1692 | ISIC_0028627
1693 | ISIC_0029444
1694 | ISIC_0031190
1695 | ISIC_0030847
1696 | ISIC_0026069
1697 | ISIC_0030990
1698 | ISIC_0030778
1699 | ISIC_0027149
1700 | ISIC_0029250
1701 | ISIC_0032711
1702 | ISIC_0028739
1703 | ISIC_0033301
1704 | ISIC_0032427
1705 | ISIC_0030662
1706 | ISIC_0029584
1707 | ISIC_0031586
1708 | ISIC_0028438
1709 | ISIC_0024411
1710 | ISIC_0029252
1711 | ISIC_0027112
1712 | ISIC_0030815
1713 | ISIC_0033963
1714 | ISIC_0031131
1715 | ISIC_0027259
1716 | ISIC_0033063
1717 | ISIC_0031825
1718 | ISIC_0028651
1719 | ISIC_0032108
1720 | ISIC_0033972
1721 | ISIC_0027542
1722 | ISIC_0027135
1723 | ISIC_0026808
1724 | ISIC_0024623
1725 | ISIC_0025506
1726 | ISIC_0032225
1727 | ISIC_0026584
1728 | ISIC_0031769
1729 | ISIC_0033150
1730 | ISIC_0032568
1731 | ISIC_0032294
1732 | ISIC_0026964
1733 | ISIC_0025480
1734 | ISIC_0025394
1735 | ISIC_0033040
1736 | ISIC_0032326
1737 | ISIC_0032908
1738 | ISIC_0028370
1739 | ISIC_0034217
1740 | ISIC_0031601
1741 | ISIC_0026397
1742 | ISIC_0028941
1743 | ISIC_0024907
1744 | ISIC_0029734
1745 | ISIC_0032059
1746 | ISIC_0026525
1747 | ISIC_0030164
1748 | ISIC_0033335
1749 | ISIC_0024930
1750 | ISIC_0033304
1751 | ISIC_0029284
1752 | ISIC_0029886
1753 | ISIC_0027421
1754 | ISIC_0033135
1755 | ISIC_0032921
1756 | ISIC_0029129
1757 | ISIC_0031417
1758 | ISIC_0026302
1759 | ISIC_0032396
1760 | ISIC_0028778
1761 | ISIC_0034230
1762 | ISIC_0026792
1763 | ISIC_0032648
1764 | ISIC_0025886
1765 | ISIC_0030351
1766 | ISIC_0030176
1767 | ISIC_0033124
1768 | ISIC_0031431
1769 | ISIC_0027308
1770 | ISIC_0028907
1771 | ISIC_0024449
1772 | ISIC_0029119
1773 | ISIC_0031033
1774 | ISIC_0029557
1775 | ISIC_0027758
1776 | ISIC_0027911
1777 | ISIC_0032177
1778 | ISIC_0028726
1779 | ISIC_0026477
1780 | ISIC_0029281
1781 | ISIC_0029477
1782 | ISIC_0029026
1783 | ISIC_0027138
1784 | ISIC_0025833
1785 | ISIC_0031730
1786 | ISIC_0033115
1787 | ISIC_0033822
1788 | ISIC_0030384
1789 | ISIC_0024996
1790 | ISIC_0026594
1791 | ISIC_0027608
1792 | ISIC_0030091
1793 | ISIC_0033848
1794 | ISIC_0030908
1795 | ISIC_0030088
1796 | ISIC_0030094
1797 | ISIC_0028867
1798 | ISIC_0025493
1799 | ISIC_0025497
1800 | ISIC_0032643
1801 | ISIC_0030656
1802 | ISIC_0027878
1803 | ISIC_0031282
1804 | ISIC_0026140
1805 | ISIC_0033609
1806 | ISIC_0031705
1807 | ISIC_0030032
1808 | ISIC_0026805
1809 | ISIC_0030341
1810 | ISIC_0028502
1811 | ISIC_0028133
1812 | ISIC_0025184
1813 | ISIC_0025761
1814 | ISIC_0032889
1815 | ISIC_0029677
1816 | ISIC_0032261
1817 | ISIC_0032973
1818 | ISIC_0026000
1819 | ISIC_0033793
1820 | ISIC_0026398
1821 | ISIC_0031238
1822 | ISIC_0033635
1823 | ISIC_0029637
1824 | ISIC_0032094
1825 | ISIC_0025138
1826 | ISIC_0031643
1827 | ISIC_0024897
1828 | ISIC_0030022
1829 | ISIC_0030037
1830 | ISIC_0024415
1831 | ISIC_0031804
1832 | ISIC_0025358
1833 | ISIC_0025617
1834 | ISIC_0024974
1835 | ISIC_0032751
1836 | ISIC_0031762
1837 | ISIC_0029474
1838 | ISIC_0029894
1839 | ISIC_0025149
1840 | ISIC_0028815
1841 | ISIC_0033382
1842 | ISIC_0028853
1843 | ISIC_0029660
1844 | ISIC_0029058
1845 | ISIC_0029983
1846 | ISIC_0029888
1847 | ISIC_0025791
1848 | ISIC_0029729
1849 | ISIC_0027043
1850 | ISIC_0026573
1851 | ISIC_0030471
1852 | ISIC_0027114
1853 | ISIC_0025175
1854 | ISIC_0026774
1855 | ISIC_0031799
1856 | ISIC_0032948
1857 | ISIC_0029962
1858 | ISIC_0026285
1859 | ISIC_0025002
1860 | ISIC_0026047
1861 | ISIC_0028024
1862 | ISIC_0033309
1863 | ISIC_0032102
1864 | ISIC_0026495
1865 | ISIC_0029312
1866 | ISIC_0029923
1867 | ISIC_0029451
1868 | ISIC_0032453
1869 | ISIC_0024811
1870 | ISIC_0034193
1871 | ISIC_0027609
1872 | ISIC_0032903
1873 | ISIC_0032203
1874 | ISIC_0027946
1875 | ISIC_0030767
1876 | ISIC_0024489
1877 | ISIC_0031570
1878 | ISIC_0027388
1879 | ISIC_0026182
1880 | ISIC_0030119
1881 | ISIC_0024915
1882 | ISIC_0026339
1883 | ISIC_0028287
1884 | ISIC_0030935
1885 | ISIC_0030728
1886 | ISIC_0026274
1887 | ISIC_0031080
1888 | ISIC_0024326
1889 | ISIC_0031489
1890 | ISIC_0032825
1891 | ISIC_0027576
1892 | ISIC_0026824
1893 | ISIC_0030290
1894 | ISIC_0032168
1895 | ISIC_0026484
1896 | ISIC_0029285
1897 | ISIC_0027716
1898 | ISIC_0034219
1899 | ISIC_0024795
1900 | ISIC_0032046
1901 | ISIC_0033104
1902 | ISIC_0027518
1903 | ISIC_0024458
1904 | ISIC_0025195
1905 | ISIC_0030729
1906 | ISIC_0024575
1907 | ISIC_0024402
1908 | ISIC_0031982
1909 | ISIC_0031019
1910 | ISIC_0028191
1911 | ISIC_0034070
1912 | ISIC_0032016
1913 | ISIC_0025373
1914 | ISIC_0031806
1915 | ISIC_0029501
1916 | ISIC_0027066
1917 | ISIC_0030527
1918 | ISIC_0030713
1919 | ISIC_0025021
1920 | ISIC_0028365
1921 | ISIC_0034279
1922 | ISIC_0028871
1923 | ISIC_0026104
1924 | ISIC_0024767
1925 | ISIC_0029679
1926 | ISIC_0027250
1927 | ISIC_0033843
1928 | ISIC_0029097
1929 | ISIC_0029577
1930 | ISIC_0030154
1931 | ISIC_0029330
1932 | ISIC_0031776
1933 | ISIC_0025520
1934 | ISIC_0025678
1935 | ISIC_0024549
1936 | ISIC_0028345
1937 | ISIC_0032267
1938 | ISIC_0031561
1939 | ISIC_0025985
1940 | ISIC_0031892
1941 | ISIC_0033810
1942 | ISIC_0024954
1943 | ISIC_0028715
1944 | ISIC_0024461
1945 | ISIC_0027734
1946 | ISIC_0033801
1947 | ISIC_0025456
1948 | ISIC_0027353
1949 | ISIC_0025352
1950 | ISIC_0034138
1951 | ISIC_0025612
1952 | ISIC_0032212
1953 | ISIC_0026550
1954 | ISIC_0024321
1955 | ISIC_0026975
1956 | ISIC_0033858
1957 | ISIC_0026150
1958 | ISIC_0027404
1959 | ISIC_0025918
1960 | ISIC_0032995
1961 | ISIC_0031414
1962 | ISIC_0025643
1963 | ISIC_0025784
1964 | ISIC_0033593
1965 | ISIC_0024848
1966 | ISIC_0028380
1967 | ISIC_0024737
1968 | ISIC_0031077
1969 | ISIC_0033226
1970 | ISIC_0026867
1971 | ISIC_0033835
1972 | ISIC_0027912
1973 | ISIC_0027183
1974 | ISIC_0028721
1975 | ISIC_0025337
1976 | ISIC_0030427
1977 | ISIC_0030591
1978 | ISIC_0025416
1979 | ISIC_0026194
1980 | ISIC_0034025
1981 | ISIC_0025059
1982 | ISIC_0027355
1983 | ISIC_0032457
1984 | ISIC_0024868
1985 | ISIC_0030848
1986 | ISIC_0032399
1987 | ISIC_0029070
1988 | ISIC_0027868
1989 | ISIC_0032750
1990 | ISIC_0026822
1991 | ISIC_0027268
1992 | ISIC_0029680
1993 | ISIC_0029702
1994 | ISIC_0033700
1995 | ISIC_0027768
1996 | ISIC_0032057
1997 | ISIC_0028531
1998 | ISIC_0024450
1999 | ISIC_0029362
2000 | ISIC_0027201
2001 | ISIC_0032119
2002 | ISIC_0025211
2003 | ISIC_0029498
2004 | ISIC_0024541
2005 |
--------------------------------------------------------------------------------
/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 |
3 |
4 | def parse_transform(transform: str, image_size=224, **transform_kwargs):
5 | if transform == 'RandomColorJitter':
6 | return transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.0)], p=1.0)
7 | elif transform == 'RandomGrayscale':
8 | return transforms.RandomGrayscale(p=0.1)
9 | elif transform == 'RandomGaussianBlur':
10 | return transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5))], p=0.3)
11 | elif transform == 'RandomCrop':
12 | return transforms.RandomCrop(image_size)
13 | elif transform == 'RandomResizedCrop':
14 | return transforms.RandomResizedCrop(image_size)
15 | elif transform == 'CenterCrop':
16 | return transforms.CenterCrop(image_size)
17 | elif transform == 'Resize_up':
18 | return transforms.Resize(
19 | [int(image_size * 1.15),
20 | int(image_size * 1.15)])
21 | elif transform == 'Normalize':
22 | return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
23 | elif transform == 'Resize':
24 | return transforms.Resize(
25 | [int(image_size),
26 | int(image_size)])
27 | elif transform == 'RandomRotation':
28 | return transforms.RandomRotation(degrees=10)
29 | else:
30 | method = getattr(transforms, transform)
31 | return method(**transform_kwargs)
32 |
33 |
34 | def get_composed_transform(augmentation: str = None, image_size=224) -> transforms.Compose:
35 | if augmentation == 'base':
36 | transform_list = ['RandomResizedCrop', 'RandomColorJitter', 'RandomHorizontalFlip', 'ToTensor',
37 | 'Normalize']
38 | elif augmentation == 'strong':
39 | transform_list = ['RandomResizedCrop', 'RandomColorJitter', 'RandomGrayscale', 'RandomGaussianBlur',
40 | 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
41 | elif augmentation is None or augmentation.lower() == 'none':
42 | transform_list = ['Resize', 'ToTensor', 'Normalize']
43 | else:
44 | raise ValueError('Unsupported augmentation: {}'.format(augmentation))
45 |
46 | transform_funcs = [parse_transform(x, image_size=image_size) for x in transform_list]
47 | transform = transforms.Compose(transform_funcs)
48 | return transform
49 |
--------------------------------------------------------------------------------
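
Note: a usage sketch (illustrative, not a file in the repo). get_composed_transform is the single entry point the dataloaders use: 'base' is a crop/jitter/flip recipe, 'strong' adds grayscale and Gaussian blur on top, and None yields the deterministic evaluation pipeline (Resize, ToTensor, Normalize).

    from datasets.transforms import get_composed_transform

    train_tf = get_composed_transform("strong", image_size=224)
    eval_tf = get_composed_transform(None, image_size=224)

    # e.g. ImageFolder(root, transform=train_tf) for pre-training,
    # and eval_tf for episodic evaluation loaders
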
/finetune.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import math
4 | import os
5 |
6 | import pandas as pd
7 | import torch.nn as nn
8 |
9 | from backbone import get_backbone_class
10 | from datasets.dataloader import get_labeled_episodic_dataloader
11 | from io_utils import parse_args
12 | from model import get_model_class
13 | from model.classifier_head import get_classifier_head_class
14 | from paths import get_output_directory, get_ft_output_directory, get_ft_train_history_path, get_ft_test_history_path, \
15 | get_final_pretrain_state_path, get_pretrain_state_path, get_ft_params_path
16 | from utils import *
17 |
18 |
19 | def main(params):
20 | base_output_dir = get_output_directory(params)
21 | output_dir = get_ft_output_directory(params)
22 | print('Running fine-tune with output folder:')
23 | print(output_dir)
24 | print()
25 |
26 | # Settings
27 | n_episodes = 600
28 | bs = params.ft_batch_size
29 | w = params.n_way
30 | s = params.n_shot
31 | q = params.n_query_shot
32 | # Whether to optimize for fixed features (when there is no augmentation and only head is updated)
33 | use_fixed_features = params.ft_augmentation is None and params.ft_parts == 'head'
34 |
35 | # Model
36 | backbone = get_backbone_class(params.backbone)()
37 | body = get_model_class(params.model)(backbone, params)
38 | if params.ft_features is not None:
39 | if params.ft_features not in body.supported_feature_selectors:
40 | raise ValueError(
41 | 'Feature selector "{}" is not supported for model "{}"'.format(params.ft_features, params.model))
42 |
43 | # Dataloaders
44 | # Note that both dataloaders sample identical episodes, via episode_seed
45 | support_epochs = 1 if use_fixed_features else params.ft_epochs
46 | support_loader = get_labeled_episodic_dataloader(params.target_dataset, n_way=w, n_shot=s, support=True,
47 | n_query_shot=q, n_episodes=n_episodes, n_epochs=support_epochs,
48 | augmentation=params.ft_augmentation,
49 | unlabeled_ratio=params.unlabeled_ratio,
50 | num_workers=params.num_workers,
51 | split_seed=params.split_seed, episode_seed=params.ft_episode_seed)
52 | query_loader = get_labeled_episodic_dataloader(params.target_dataset, n_way=w, n_shot=s, support=False,
53 | n_query_shot=q, n_episodes=n_episodes, augmentation=None,
54 | unlabeled_ratio=params.unlabeled_ratio,
55 | num_workers=params.num_workers,
56 | split_seed=params.split_seed,
57 | episode_seed=params.ft_episode_seed)
58 | assert (len(query_loader) == n_episodes)
59 | assert (len(support_loader) == n_episodes * support_epochs)
60 |
61 | query_iterator = iter(query_loader)
62 | support_iterator = iter(support_loader)
63 | support_batches = math.ceil(w * s / bs)
64 |
65 | # Output (history, params)
66 | train_history_path = get_ft_train_history_path(output_dir)
67 | test_history_path = get_ft_test_history_path(output_dir)
68 | params_path = get_ft_params_path(output_dir)
69 | print('Saving finetune params to {}'.format(params_path))
70 | print('Saving finetune train history to {}'.format(train_history_path))
71 |     print('Saving finetune test history to {}'.format(test_history_path))
72 | with open(params_path, 'w') as f:
73 | json.dump(vars(params), f, indent=4)
74 | df_train = pd.DataFrame(None, index=list(range(1, n_episodes + 1)),
75 | columns=['epoch{}'.format(e + 1) for e in range(params.ft_epochs)])
76 | df_test = pd.DataFrame(None, index=list(range(1, n_episodes + 1)),
77 | columns=['epoch{}'.format(e + 1) for e in range(params.ft_epochs)])
78 |
79 | # Pre-train state
80 | if params.ft_pretrain_epoch is None:
81 | body_state_path = get_final_pretrain_state_path(base_output_dir)
82 | else:
83 | body_state_path = get_pretrain_state_path(base_output_dir, params.ft_pretrain_epoch)
84 | if not os.path.exists(body_state_path):
85 | raise ValueError('Invalid pre-train state path: ' + body_state_path)
86 | print('Using pre-train state:')
87 | print(body_state_path)
88 | print()
89 | state = torch.load(body_state_path)
90 |
91 | # HOTFIX
92 | # print("HOTFIX: removing classifier weights from state")
93 | # del state["classifier.weight"]
94 | # del state["classifier.bias"]
95 |
96 | # Loss function
97 | loss_fn = nn.CrossEntropyLoss().cuda()
98 |
99 | print('Starting fine-tune')
100 | if use_fixed_features:
101 | print('Running optimized fixed-feature fine-tuning (no augmentation, fixed body)')
102 | print()
103 |
104 | for episode in range(n_episodes):
105 |         # Reset models for each episode
106 |         # (the commented HOTFIX below works around stale classifier weights in the saved state)
107 |
108 | # HOTFIX: load state dict non-strict to ignore classifier weights
109 | # body.load_state_dict(copy.deepcopy(state), strict=False) # note, override model.load_state_dict to change this behavior.
110 | body.load_state_dict(copy.deepcopy(state)) # note, override model.load_state_dict to change this behavior.
111 | head = get_classifier_head_class(params.ft_head)(body.final_feat_dim, params.n_way, params) # TODO: apply ft_features
112 | body.cuda()
113 | head.cuda()
114 |
115 | opt_params = []
116 | if params.ft_train_head:
117 | opt_params.append({'params': head.parameters()})
118 | if params.ft_train_body:
119 | opt_params.append({'params': body.parameters()})
120 | optimizer = torch.optim.SGD(opt_params, lr=params.ft_lr, momentum=0.9, dampening=0.9, weight_decay=0.001)
121 |
122 |         # Labels are always [0, ..., 0, 1, ..., 1, ..., w-1], i.e., each of the w classes repeated s times
123 | x_support = None
124 | f_support = None
125 | y_support = torch.arange(w).repeat_interleave(s).cuda()
126 | x_query = next(query_iterator)[0].cuda()
127 |
128 | f_query = None
129 | y_query = torch.arange(w).repeat_interleave(q).cuda()
130 |
131 | if use_fixed_features: # load data and extract features once per episode
132 | with torch.no_grad():
133 | x_support, _ = next(support_iterator)
134 | x_support = x_support.cuda()
135 | f_support = body.forward_features(x_support, params.ft_features)
136 | f_query = body.forward_features(x_query, params.ft_features)
137 |
138 | train_acc_history = []
139 | test_acc_history = []
140 | for epoch in range(params.ft_epochs):
141 | # Train
142 | body.train()
143 | head.train()
144 |
145 | if not use_fixed_features: # load data every epoch
146 | x_support, _ = next(support_iterator)
147 | x_support = x_support.cuda()
148 |
149 | total_loss = 0
150 | correct = 0
151 | indices = np.random.permutation(w * s)
152 | for i in range(support_batches):
153 | start_index = i * bs
154 | end_index = min(i * bs + bs, w * s)
155 | batch_indices = indices[start_index:end_index]
156 | y = y_support[batch_indices]
157 |
158 | if use_fixed_features:
159 | f = f_support[batch_indices]
160 | else:
161 | f = body.forward_features(x_support[batch_indices], params.ft_features)
162 | p = head(f)
163 |
164 | correct += torch.eq(y, p.argmax(dim=1)).sum()
165 | loss = loss_fn(p, y)
166 | optimizer.zero_grad()
167 | loss.backward()
168 | optimizer.step()
169 |
170 | total_loss += loss.item()
171 |
172 | train_loss = total_loss / support_batches
173 | train_acc = correct / (w * s)
174 |
175 | # Evaluation
176 | body.eval()
177 | head.eval()
178 |
179 | if params.ft_intermediate_test or epoch == params.ft_epochs - 1:
180 | with torch.no_grad():
181 | if not use_fixed_features:
182 | f_query = body.forward_features(x_query, params.ft_features)
183 | p_query = head(f_query)
184 | test_acc = torch.eq(y_query, p_query.argmax(dim=1)).sum() / (w * q)
185 | else:
186 | test_acc = torch.tensor(0)
187 |
188 | print_epoch_logs = False
189 | if print_epoch_logs and (epoch + 1) % 10 == 0:
190 | fmt = 'Epoch {:03d}: Loss={:6.3f} Train ACC={:6.3f} Test ACC={:6.3f}'
191 | print(fmt.format(epoch + 1, train_loss, train_acc, test_acc))
192 |
193 | train_acc_history.append(train_acc.item())
194 | test_acc_history.append(test_acc.item())
195 |
196 | df_train.loc[episode + 1] = train_acc_history
197 | df_train.to_csv(train_history_path)
198 | df_test.loc[episode + 1] = test_acc_history
199 | df_test.to_csv(test_history_path)
200 |
201 | fmt = 'Episode {:03d}: train_loss={:6.4f} train_acc={:6.2f} test_acc={:6.2f}'
202 | print(fmt.format(episode, train_loss, train_acc_history[-1] * 100, test_acc_history[-1] * 100))
203 |
204 |     fmt = 'Final Results: Acc={:5.2f} 95% CI={:5.2f}'
205 | print(fmt.format(df_test.mean()[-1] * 100, 1.96 * df_test.std()[-1] / np.sqrt(n_episodes) * 100))
206 |
207 | print('Saved history to:')
208 | print(train_history_path)
209 | print(test_history_path)
210 | df_train.to_csv(train_history_path)
211 | df_test.to_csv(test_history_path)
212 |
213 |
214 | if __name__ == '__main__':
215 | np.random.seed(10)
216 | params = parse_args()
217 |
218 | targets = params.target_dataset
219 | if targets is None:
220 |         targets = [targets]  # wrap so the loop below runs once
221 | elif len(targets) > 1:
222 | print('#' * 80)
223 | print("Running finetune iteratively for multiple target datasets: {}".format(targets))
224 | print('#' * 80)
225 |
226 | for target in targets:
227 | params.target_dataset = target
228 | main(params)
229 |
--------------------------------------------------------------------------------
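The fixed-feature path above (ft_augmentation=None, ft_parts='head') extracts body features once per episode and then trains only the head. A minimal, self-contained sketch of that inner loop, using random stand-in features rather than the repository's data pipeline (shapes and hyperparameters mirror the defaults above):

import torch
import torch.nn as nn

w, s, dim = 5, 5, 512                                  # n_way, n_shot, feature dim
f_support = torch.randn(w * s, dim)                    # frozen-body features, computed once
y_support = torch.arange(w).repeat_interleave(s)       # [0,...,0, 1,...,1, ..., w-1]

head = nn.Linear(dim, w)
opt = torch.optim.SGD(head.parameters(), lr=1e-2, momentum=0.9,
                      dampening=0.9, weight_decay=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(100):                               # ft_epochs
    perm = torch.randperm(w * s)
    for i in range(0, w * s, 4):                       # ft_batch_size
        idx = perm[i:i + 4]
        loss = loss_fn(head(f_support[idx]), y_support[idx])
        opt.zero_grad()
        loss.backward()
        opt.step()
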
/finetune.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 |
4 | SOURCES=("miniImageNet" "tieredImageNet" "ImageNet")
5 | SOURCE=${SOURCES[2]}
6 |
7 | TARGETS=("CropDisease" "ISIC" "EuroSAT" "ChestX" "places" "cub" "plantae" "cars")
8 | TARGET=${TARGETS[0]}
9 |
10 | # BACKBONE=resnet10 # for mini
11 | BACKBONE=resnet18 # for tiered and full imagenet
12 |
13 | N_SHOT=5
14 |
15 |
16 | # Source SL
17 | python finetune.py --ls --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "base" --tag "default" --n_shot $N_SHOT
18 |
19 | # Target SSL
20 | python finetune.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default" --n_shot $N_SHOT
21 | python finetune.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default" --n_shot $N_SHOT
22 |
23 | # MSL (Source SL + Target SSL)
24 | python finetune.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78" --n_shot $N_SHOT
25 | python finetune.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78" --n_shot $N_SHOT
26 |
27 | # Two-Stage SSL (Source SL -> Target SSL)
28 | python finetune.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default" --n_shot $N_SHOT
29 | python finetune.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default" --n_shot $N_SHOT
30 |
31 | # Two-Stage MSL (Source SL -> Source SL + Target SSL)
32 | python finetune.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78" --n_shot $N_SHOT
33 | python finetune.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78" --n_shot $N_SHOT
34 |
--------------------------------------------------------------------------------
/io_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def parse_args():
5 | parser = argparse.ArgumentParser(description='CD-FSL')
6 |     parser.add_argument('--dataset', default='miniImageNet', help='Dataset used to train the base model')
7 | parser.add_argument('--backbone', default='resnet10', help='Refer to backbone._backbone_class_map')
8 |     parser.add_argument('--model', default='base', help='Pre-training method; refer to model._model_class_map')
9 |
10 | parser.add_argument('--num_classes', default=200, type=int,
11 |                         help='Deprecated. Value is overwritten based on `source_dataset`')
12 |
13 | parser.add_argument('--source_dataset', default='miniImageNet')
14 | parser.add_argument('--target_dataset', type=str,
15 | nargs='+') # replaces dataset_names / HOTFIX: changed to list to allow for multiple targets with one CLI command
16 | parser.add_argument('--imagenet_pretrained', action="store_true", help='Use ImageNet pretrained weights')
17 |
18 | # Split related params
19 | parser.add_argument('--unlabeled_ratio', default=20, type=int,
20 | help='Percentage of dataset used for unlabeled split')
21 | parser.add_argument('--split_seed', default=1, type=int,
22 | help='Random seed used for split. If set to 1 and unlabeled_ratio==20, will use split defined by STARTUP')
23 |
24 | # Pre-train params (determines pre-trained model output directory)
25 | # These must be specified during evaluation and fine-tuning to select pre-trained model
26 | parser.add_argument('--pls', action='store_true',
27 | help='Second-step pre-training on top of model trained with source labeled data')
28 | parser.add_argument('--put', action='store_true',
29 | help='Second-step pre-training on top of model trained with target unlabeled data')
30 | parser.add_argument('--pmsl', action='store_true',
31 | help='Second-step pre-training on top of model trained with MSL (instead of pls_put)')
32 | parser.add_argument('--ls', action='store_true', help='Use labeled source data for pre-training')
33 | parser.add_argument('--us', action='store_true', help='Use unlabeled source data for pre-training')
34 | parser.add_argument('--ut', action='store_true', help='Use unlabeled target data for pre-training')
35 | parser.add_argument('--tag', default='default', type=str,
36 | help='Tag used to differentiate output directories for pre-trained models') # similar to aug_mode
37 | parser.add_argument('--pls_tag', default=None, type=str, help='Deprecated. Please use `previous_tag`.')
38 | parser.add_argument('--previous_tag', default=None, type=str,
39 | help='Tag of pre-trained previous model for pls, put, pmsl. Uses --tag by default.')
40 |
41 | # Pre-train params (non-identifying, i.e., does not affect output directory)
42 |     # (you must specify --tag to differentiate models with different non-identifying parameters)
43 | parser.add_argument('--augmentation', default='strong', type=str,
44 | help="Augmentation used for pre-training {'base', 'strong'}") # similar to aug_mode
45 | parser.add_argument('--batch_size', default=64, type=int,
46 | help='Batch size for pre-training.') # similar to aug_mode
47 | parser.add_argument('--ls_batch_size', default=None, type=int,
48 | help='Batch size for LS source pre-training.') # if None, reverts to batch_size
49 | parser.add_argument('--lr', default=None, type=float, help='LR for pre-training.')
50 | parser.add_argument('--gamma', default=0.5, type=float, help='Gamma value for {LS,US} + UT.') # similar to aug_mode
51 | parser.add_argument('--gamma_schedule', default=None, type=str, help='None | "linear"')
52 | parser.add_argument('--epochs', default=1000, type=int, help='Pre-training epochs.') # similar to aug_mode
53 | parser.add_argument('--model_save_interval', default=50, type=int,
54 | help='Save model state every N epochs during pre-training.') # similar to aug_mode
55 | parser.add_argument('--optimizer', default=None, type=str,
56 | help="Optimizer used during pre-training {'sgd', 'adam'}. Default if None") # similar to aug_mode
57 | parser.add_argument('--scheduler', default="MultiStepLR", type=str,
58 | help="Scheduler to use (refer to `pretrain.py`)")
59 | parser.add_argument('--scheduler_milestones', default=[400, 600, 800], type=int, nargs="+",
60 | help="Milestones for (Repeated)MultiStepLR scheduler")
61 | parser.add_argument('--num_workers', default=None, type=int)
62 |
63 | # Fine-tune params
64 | parser.add_argument('--n_shot', default=5, type=int, help='number of labeled data in each class, same as n_support')
65 | parser.add_argument('--n_way', default=5, type=int)
66 | parser.add_argument('--n_query_shot', default=15, type=int)
67 |
68 | parser.add_argument('--ft_tag', default='default', type=str,
69 | help='Tag used to differentiate output directories for fine-tuned models')
70 | parser.add_argument('--ft_head', default='linear', help='See `model.classifier_head.CLASSIFIER_HEAD_CLASS_MAP`')
71 | parser.add_argument('--ft_epochs', default=100, type=int)
72 | parser.add_argument('--ft_pretrain_epoch', default=None, type=int)
73 | parser.add_argument('--ft_batch_size', default=4, type=int)
74 | parser.add_argument('--ft_lr', default=1e-2, type=float, help='Learning rate for fine-tuning')
75 | parser.add_argument('--ft_augmentation', default=None, type=str,
76 | help="Augmentation used for fine-tuning {None, 'base', 'strong'}")
77 | parser.add_argument('--ft_parts', default='head', type=str, help="Where to fine-tune: {'full', 'body', 'head'}")
78 | parser.add_argument('--ft_features', default=None, type=str,
79 | help='Specify which features to use from the base model (see model/base.py)')
80 | parser.add_argument('--ft_intermediate_test', action='store_true', help='Evaluate on query set during fine-tuning')
81 | parser.add_argument('--ft_episode_seed', default=0, type=int)
82 |
83 | # Model parameters (make sure to prepend with `model_`)
84 | parser.add_argument('--model_simclr_projection_dim', default=128, type=int)
85 | parser.add_argument('--model_simclr_temperature', default=1.0, type=float)
86 |
87 | # Batch normalization (likely deprecated)
88 | parser.add_argument('--track_bn', action='store_true', help='tracking BN stats')
89 | parser.add_argument('--freeze_bn', action='store_true',
90 | help='freeze bn stats, i.e., use accumulated stats of pretrained model during inference. Note, track_bn must be on to do this.')
91 | parser.add_argument('--reinit_bn_stats', action='store_true',
92 | help='Re-initialize BN running statistics every iteration')
93 |
94 | params = parser.parse_args()
95 |
96 | # Double-checking parameters
97 | if params.freeze_bn and not params.track_bn:
98 | raise AssertionError('Invalid parameter combination')
99 | if params.reinit_bn_stats:
100 |         raise AssertionError('reinit_bn_stats is not supported; please consult the authors.')
101 | if params.ut and not params.target_dataset:
102 | raise AssertionError('Invalid parameter combination')
103 | if params.ft_parts not in ["head", "body", "full"]:
104 | raise AssertionError('Invalid params.ft_parts: {}'.format(params.ft_parts))
105 |
106 | # pls, put, pmsl parameters
107 | if sum((params.pls, params.put, params.pmsl)) > 1:
108 | raise AssertionError('You may only specify one of params.{pls,put,pmsl}')
109 |
110 | # Assign num_classes (*_new)
111 | if params.source_dataset == 'miniImageNet':
112 | params.num_classes = 64
113 | elif params.source_dataset == 'tieredImageNet':
114 | params.num_classes = 351
115 | elif params.source_dataset == 'ImageNet':
116 | params.num_classes = 1000
117 | elif params.source_dataset == 'CropDisease':
118 | params.num_classes = 38
119 | elif params.source_dataset == 'EuroSAT':
120 | params.num_classes = 10
121 | elif params.source_dataset == 'ISIC':
122 | params.num_classes = 7
123 | elif params.source_dataset == 'ChestX':
124 | params.num_classes = 7
125 | elif params.source_dataset == 'places':
126 | params.num_classes = 16
127 | elif params.source_dataset == 'plantae':
128 | params.num_classes = 69
129 | elif params.source_dataset == 'cars':
130 | params.num_classes = 196
131 | elif params.source_dataset == 'cub':
132 | params.num_classes = 200
133 | elif params.source_dataset == 'none':
134 | params.num_classes = 5
135 | else:
136 | raise ValueError('Invalid `source_dataset` argument: {}'.format(params.source_dataset))
137 |
138 | # Default workers
139 | if params.num_workers is None:
140 | params.num_workers = 3
141 | if params.target_dataset in ["cars", "cub", "plantae"]:
142 | params.num_workers = 4
143 | if params.target_dataset in ["ChestX"]:
144 | params.num_workers = 6
145 | print("Using default num_workers={}".format(params.num_workers))
146 | params.num_workers *= 2 # TEMP
147 |
148 | # Default optimizers
149 | if params.optimizer is None:
150 | if params.model in ['simsiam', 'byol']:
151 | params.optimizer = 'adam' if not params.ls else 'sgd'
152 | else:
153 | params.optimizer = 'sgd'
154 | print("Using default optimizer for model {}: {}".format(params.model, params.optimizer))
155 |
156 | # Default learning rate
157 | if params.lr is None:
158 | if params.model in ['simsiam', 'byol']:
159 | params.lr = 3e-4 if not params.ls else 0.1
160 | elif params.model in ['moco']:
161 | params.lr = 0.01
162 | else:
163 | params.lr = 0.1
164 | print("Using default lr for model {}: {}".format(params.model, params.lr))
165 |
166 | # Default ls_batch_size
167 | if params.ls_batch_size is None:
168 | params.ls_batch_size = params.batch_size
169 |
170 | params.ft_train_body = params.ft_parts in ['body', 'full']
171 | params.ft_train_head = params.ft_parts in ['head', 'full']
172 |
173 | if params.previous_tag is None:
174 | if params.pls_tag: # support for deprecated argument (changed 5/8/2022)
175 | print("Warning: params.pls_tag is deprecated. Please use params.previous_tag")
176 | params.previous_tag = params.pls_tag
177 | elif params.pls or params.put or params.pmsl:
178 | print("Using params.tag for params.previous_tag")
179 | params.previous_tag = params.tag
180 |
181 | return params
182 |
--------------------------------------------------------------------------------
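A hypothetical invocation showing how parse_args() resolves its derived defaults (num_classes from the source dataset; optimizer and learning rate from the model); the flag values are examples only:

import sys
from io_utils import parse_args

sys.argv = ['finetune.py', '--source_dataset', 'miniImageNet',
            '--target_dataset', 'EuroSAT', '--model', 'byol', '--ut']
params = parse_args()
assert params.num_classes == 64    # derived from source_dataset
assert params.optimizer == 'adam'  # byol without --ls defaults to adam
assert params.lr == 3e-4           # matching model-specific default
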
/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from . import meta_template
2 | from . import maml
3 | from . import boil
4 | from . import protonet
5 |
6 |
7 |
--------------------------------------------------------------------------------
/methods/baselinefinetune.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from methods.meta_template import MetaTemplate
8 |
9 | class BaselineFinetune(MetaTemplate):
10 | def __init__(self, model_func, n_way, n_support, loss_type = "softmax"):
11 | super(BaselineFinetune, self).__init__( model_func, n_way, n_support)
12 | self.loss_type = loss_type
13 |
14 | def set_forward(self, x, is_feature = True):
15 |         return self.set_forward_adaptation(x, is_feature)  # Baseline always performs adaptation
16 |
17 | def set_forward_adaptation(self, x, is_feature = True):
18 |         assert is_feature == True, 'Baseline only supports testing with precomputed features'
19 | z_support, z_query = self.parse_feature(x,is_feature)
20 |
21 | z_support = z_support.contiguous().view(self.n_way* self.n_support, -1 )
22 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 )
23 |
24 | y_support = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support ))
25 | y_support = Variable(y_support.cuda())
26 |
27 |         if self.loss_type == 'softmax':
28 |             linear_clf = nn.Linear(self.feat_dim, self.n_way)
29 |         elif self.loss_type == 'dist':
30 |             linear_clf = backbone.distLinear(self.feat_dim, self.n_way)
31 |         else:
32 |             raise ValueError('Unsupported loss_type: {}'.format(self.loss_type))
33 |
34 |         linear_clf = linear_clf.cuda()
35 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001)
36 |
37 | loss_function = nn.CrossEntropyLoss()
38 | loss_function = loss_function.cuda()
39 |
40 | batch_size = 4
41 | support_size = self.n_way* self.n_support
42 | for epoch in range(100):
43 | rand_id = np.random.permutation(support_size)
44 | for i in range(0, support_size , batch_size):
45 | set_optimizer.zero_grad()
46 | selected_id = torch.from_numpy( rand_id[i: min(i+batch_size, support_size) ]).cuda()
47 | z_batch = z_support[selected_id]
48 |
49 | scores = linear_clf(z_batch)
50 |
51 | y_batch = y_support[selected_id]
52 |
53 | loss = loss_function(scores,y_batch)
54 | loss.backward()
55 | set_optimizer.step()
56 |
57 | scores = linear_clf(z_query)
58 | return scores
59 |
60 | def set_forward_loss(self,x):
61 |         raise ValueError('Baseline predicts on pretrained features and does not support fine-tuning the backbone')
62 |
--------------------------------------------------------------------------------
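A sketch of how BaselineFinetune consumes a precomputed episode (is_feature=True); the feature tensor follows the [n_way, n_support + n_query, feat_dim] convention of parse_feature. The stub backbone is hypothetical, and a CUDA device is assumed since the method moves tensors to the GPU:

import torch
from methods.baselinefinetune import BaselineFinetune

n_way, n_support, n_query, feat_dim = 5, 5, 15, 512

class StubBackbone:                       # stands in for a real model_func
    final_feat_dim = feat_dim

clf = BaselineFinetune(StubBackbone(), n_way, n_support)
clf.n_query = n_query
z = torch.randn(n_way, n_support + n_query, feat_dim)
scores = clf.set_forward(z, is_feature=True)  # trains a fresh linear head on the support features
print(scores.shape)                           # torch.Size([n_way * n_query, n_way])
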
/methods/baselinetrain.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import utils
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 | class BaselineTrain(nn.Module):
11 | def __init__(self, model_func, num_class, loss_type = 'softmax'):
12 | super(BaselineTrain, self).__init__()
13 | self.feature = model_func
14 |
15 | if loss_type == 'softmax':
16 | self.classifier = nn.Linear(self.feature.final_feat_dim, num_class)
17 | self.classifier.bias.data.fill_(0)
18 | elif loss_type == 'dist': #Baseline ++
19 | self.classifier = backbone.distLinear(self.feature.final_feat_dim, num_class)
20 |
21 | self.loss_type = loss_type #'softmax' #'dist'
22 | self.num_class = num_class
23 | self.loss_fn = nn.CrossEntropyLoss()
24 | self.top1 = utils.AverageMeter()
25 |
26 | def forward(self,x):
27 | x = Variable(x.cuda())
28 | out = self.feature.forward(x)
29 | scores = self.classifier.forward(out)
30 | return scores
31 |
32 | def forward_loss(self, x, y):
33 | y = Variable(y.cuda())
34 |
35 | scores = self.forward(x)
36 |
37 | _, predicted = torch.max(scores.data, 1)
38 | correct = predicted.eq(y.data).cpu().sum()
39 | self.top1.update(correct.item()*100 / (y.size(0)+0.0), y.size(0))
40 |
41 | return self.loss_fn(scores, y)
42 |
43 | def train_loop(self, epoch, train_loader, optimizer, scheduler):
44 | print_freq = 10
45 | avg_loss=0
46 | for i, (x,y) in enumerate(train_loader):
47 | optimizer.zero_grad()
48 | loss = self.forward_loss(x, y)
49 | loss.backward()
50 | optimizer.step()
51 |
52 | avg_loss = avg_loss+loss.item()
53 | if i % print_freq==0:
54 | #print(optimizer.state_dict()['param_groups'][0]['lr'])
55 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} | Top1 Val {:f} | Top1 Avg {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1), self.top1.val, self.top1.avg))
56 | if scheduler is not None:
57 | scheduler.step()
58 |
59 | def test_loop(self, val_loader):
60 |         return -1  # no validation; the model is simply saved during training
61 |
62 |
--------------------------------------------------------------------------------
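A sketch of supervised pre-training with BaselineTrain. The tiny backbone and in-memory dataloader are hypothetical stand-ins for the repository's real backbone and dataloaders, and a CUDA device is assumed because forward() moves inputs to the GPU:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from methods.baselinetrain import BaselineTrain

class TinyBackbone(nn.Module):
    final_feat_dim = 64
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 64, 3),
                                 nn.AdaptiveAvgPool2d(1), nn.Flatten())
    def forward(self, x):
        return self.net(x)

model = BaselineTrain(TinyBackbone(), num_class=10).cuda()
loader = DataLoader(TensorDataset(torch.randn(32, 3, 32, 32),
                                  torch.randint(0, 10, (32,))), batch_size=8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model.train_loop(epoch=0, train_loader=loader, optimizer=opt, scheduler=None)
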
/methods/boil.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 |
11 | class BOIL(MetaTemplate):
12 | def __init__(self, model_func, n_way, n_support, approx = False):
13 | super(BOIL, self).__init__( model_func, n_way, n_support, change_way = False)
14 |
15 | self.loss_fn = nn.CrossEntropyLoss()
16 | self.classifier = backbone.Linear_fw(self.feat_dim, n_way)
17 | self.classifier.bias.data.fill_(0)
18 |
19 | self.n_task = 4
20 | self.task_update_num = 1
21 | self.train_lr = 0.5
22 | self.approx = approx #first order approx.
23 |
24 | def forward(self,x):
25 | out = self.feature.forward(x)
26 | scores = self.classifier.forward(out)
27 | return scores
28 |
29 | def set_forward(self, x, is_feature = False):
30 |         assert is_feature == False, 'BOIL does not support fixed features'
31 |
32 | x = x.cuda()
33 | x_var = Variable(x)
34 | x_a_i = x_var[:,:self.n_support,:,:,:].contiguous().view( self.n_way* self.n_support, *x.size()[2:]) #support data
35 | x_b_i = x_var[:,self.n_support:,:,:,:].contiguous().view( self.n_way* self.n_query, *x.size()[2:]) #query data
36 | y_a_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_support ) )).cuda() #label for support data
37 |
38 |         fast_parameters = list(self.parameters()) #the first gradient calculated in line 45 is based on the original weights
39 | len_parameters = len(fast_parameters)
40 | for weight in self.parameters():
41 | weight.fast = None
42 | self.zero_grad()
43 |
44 | for task_step in range(self.task_update_num):
45 | scores = self.forward(x_a_i)
46 | set_loss = self.loss_fn( scores, y_a_i)
47 | grad = torch.autograd.grad(set_loss, fast_parameters, create_graph=True) #build full graph support gradient of gradient
48 | if self.approx:
49 | grad = [ g.detach() for g in grad ] #do not calculate gradient of gradient if using first order approximation
50 | fast_parameters = []
51 | for k, weight in enumerate(self.parameters()):
52 | #for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py
53 | if k == len_parameters-2 or k == len_parameters-1:
54 | weight.fast = weight
55 | else:
56 | if weight.fast is None:
57 | weight.fast = weight - self.train_lr * grad[k] #create weight.fast
58 | else:
59 | weight.fast = weight.fast - self.train_lr * grad[k] #create an updated weight.fast, note the '-' is not merely minus value, but to create a new weight.fast
60 | fast_parameters.append(weight.fast) #gradients calculated in line 45 are based on newest fast weight, but the graph will retain the link to old weight.fasts
61 |
62 | scores = self.forward(x_b_i)
63 | return scores
64 |
65 |     def set_forward_adaptation(self,x, is_feature = False): #override parent function
66 |         raise ValueError('BOIL performs further adaptation simply by increasing task_update_num')
67 |
68 | def set_forward_loss(self, x):
69 | scores = self.set_forward(x, is_feature = False)
70 | y_b_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_query ) )).cuda()
71 | loss = self.loss_fn(scores, y_b_i)
72 |
73 | return loss
74 |
75 |     def train_loop(self, epoch, train_loader, optimizer): #override parent function
76 | print_freq = 10
77 | avg_loss=0
78 | task_count = 0
79 | loss_all = []
80 | optimizer.zero_grad()
81 |
82 | #train
83 | for i, (x,_) in enumerate(train_loader):
84 | self.n_query = x.size(1) - self.n_support
85 |             assert self.n_way == x.size(0), "BOIL does not support way change"
86 |
87 | loss = self.set_forward_loss(x)
88 | avg_loss = avg_loss+loss.item()#.data[0]
89 | loss_all.append(loss)
90 |
91 | task_count += 1
92 |
93 |             if task_count == self.n_task: #BOIL updates several tasks at once
94 | loss_q = torch.stack(loss_all).sum(0)
95 | loss_q.backward()
96 |
97 | optimizer.step()
98 | task_count = 0
99 | loss_all = []
100 | optimizer.zero_grad()
101 | if i % print_freq==0:
102 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1)))
103 |
104 |     def test_loop(self, test_loader, return_std = False): #override parent function
105 | correct =0
106 | count = 0
107 | acc_all = []
108 |
109 | iter_num = len(test_loader)
110 | for i, (x,_) in enumerate(test_loader):
111 | self.n_query = x.size(1) - self.n_support
112 |             assert self.n_way == x.size(0), "BOIL does not support way change"
113 | correct_this, count_this = self.correct(x)
114 | acc_all.append(correct_this/ count_this *100 )
115 |
116 | acc_all = np.asarray(acc_all)
117 | acc_mean = np.mean(acc_all)
118 | acc_std = np.std(acc_all)
119 | print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num)))
120 | if return_std:
121 | return acc_mean, acc_std
122 | else:
123 | return acc_mean
124 |
125 | def get_logits(self, x):
126 | self.n_query = x.size(1) - self.n_support
127 | logits = self.set_forward(x)
128 | return logits
--------------------------------------------------------------------------------
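BOIL (Body Only update in Inner Loop) differs from MAML only at the branch on k above: the last two parameters, the classifier's weight and bias, keep weight.fast = weight, so the head stays frozen while the body adapts. Both methods rely on the fast-weight modules in backbone.py; a minimal sketch of that mechanism, written from the comments above (the repository's Linear_fw may differ in detail):

import torch.nn as nn
import torch.nn.functional as F

class LinearFW(nn.Linear):
    # A linear layer that uses its 'fast' (inner-loop) weights when present,
    # falling back to the slow meta-parameters otherwise.
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        w = self.weight.fast if self.weight.fast is not None else self.weight
        b = self.bias.fast if self.bias.fast is not None else self.bias
        return F.linear(x, w, b)
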
/methods/byol.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from functools import wraps
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 | from torchvision import transforms as T
10 | from .baselinetrain import BaselineTrain
11 |
12 | # helper functions
13 |
14 | def default(val, def_val):
15 | return def_val if val is None else val
16 |
17 | def flatten(t):
18 | return t.reshape(t.shape[0], -1)
19 |
20 | def singleton(cache_key):
21 | def inner_fn(fn):
22 | @wraps(fn)
23 | def wrapper(self, *args, **kwargs):
24 | instance = getattr(self, cache_key)
25 | if instance is not None:
26 | return instance
27 |
28 | instance = fn(self, *args, **kwargs)
29 | setattr(self, cache_key, instance)
30 | return instance
31 | return wrapper
32 | return inner_fn
33 |
34 | def get_module_device(module):
35 | return next(module.parameters()).device
36 |
37 | def set_requires_grad(model, val):
38 | for p in model.parameters():
39 | p.requires_grad = val
40 |
41 | # loss fn
42 |
43 | def loss_fn(x, y):
44 | x = F.normalize(x, dim=-1, p=2)
45 | y = F.normalize(y, dim=-1, p=2)
46 | return 2 - 2 * (x * y).sum(dim=-1)
47 |
48 | # augmentation utils
49 |
50 | class RandomApply(nn.Module):
51 | def __init__(self, fn, p):
52 | super().__init__()
53 | self.fn = fn
54 | self.p = p
55 | def forward(self, x):
56 | if random.random() > self.p:
57 | return x
58 | return self.fn(x)
59 |
60 | # exponential moving average
61 |
62 | class EMA():
63 | def __init__(self, beta):
64 | super().__init__()
65 | self.beta = beta
66 |
67 | def update_average(self, old, new):
68 | if old is None:
69 | return new
70 | return old * self.beta + (1 - self.beta) * new
71 |
72 | def update_moving_average(ema_updater, ma_model, current_model):
73 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
74 | old_weight, up_weight = ma_params.data, current_params.data
75 | ma_params.data = ema_updater.update_average(old_weight, up_weight)
76 |
77 | # MLP class for projector and predictor
78 |
79 | class MLP(nn.Module):
80 | def __init__(self, dim, projection_size, hidden_size = 4096):
81 | super().__init__()
82 | self.net = nn.Sequential(
83 | nn.Linear(dim, hidden_size),
84 | nn.BatchNorm1d(hidden_size),
85 | nn.ReLU(inplace=True),
86 | nn.Linear(hidden_size, projection_size)
87 | )
88 |
89 | def forward(self, x):
90 | return self.net(x)
91 |
92 | # a wrapper class for the base neural network
93 | # will manage the interception of the hidden layer output
94 | # and pipe it into the projector and predictor nets
95 |
96 | class NetWrapper(nn.Module):
97 |     def __init__(self, net, projection_size, projection_hidden_size, layer = -1): # upstream default is layer = -2 since that network ends with a classifier; ours has none, so we use -1
98 | super().__init__()
99 | self.net = net
100 | self.layer = layer
101 |
102 | self.projector = None
103 | self.projection_size = projection_size
104 | self.projection_hidden_size = projection_hidden_size
105 |
106 | self.hidden = {}
107 | self.hook_registered = False
108 |
109 | def _find_layer(self):
110 | if type(self.layer) == str:
111 | modules = dict([*self.net.named_modules()])
112 | return modules.get(self.layer, None)
113 | elif type(self.layer) == int:
114 | children = [*self.net.children()]
115 | return children[self.layer]
116 | return None
117 |
118 | def _hook(self, _, input, output):
119 | device = input[0].device
120 | self.hidden[device] = flatten(output)
121 |
122 | def _register_hook(self):
123 | layer = self._find_layer()
124 | assert layer is not None, f'hidden layer ({self.layer}) not found'
125 | handle = layer.register_forward_hook(self._hook)
126 | self.hook_registered = True
127 |
128 | @singleton('projector')
129 | def _get_projector(self, hidden):
130 | _, dim = hidden.shape
131 | projector = MLP(dim, self.projection_size, self.projection_hidden_size)
132 | return projector.to(hidden)
133 |
134 | def get_representation(self, x):
135 | if self.layer == -1:
136 | return self.net(x)
137 |
138 | if not self.hook_registered:
139 | self._register_hook()
140 |
141 | self.hidden.clear()
142 | _ = self.net(x)
143 | hidden = self.hidden[x.device]
144 | self.hidden.clear()
145 |
146 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
147 | return hidden
148 |
149 | def forward(self, x, return_projection = True):
150 | representation = self.get_representation(x)
151 |
152 | if not return_projection:
153 | return representation
154 |
155 | projector = self._get_projector(representation)
156 | projection = projector(representation)
157 | return projection, representation
158 |
159 | # main class
160 |
161 | class BYOL(nn.Module):
162 | def __init__(
163 | self,
164 | net,
165 | image_size,
166 | hidden_layer = -1,
167 | projection_size = 256,
168 | projection_hidden_size = 4096,
169 | augment_fn = None,
170 | augment_fn2 = None,
171 | moving_average_decay = 0.99,
172 | use_momentum = True
173 | ):
174 | super().__init__()
175 | assert isinstance(net, BaselineTrain)
176 | self.net = net
177 |
178 | # default SimCLR augmentation
179 |
180 | # DEFAULT_AUG = torch.nn.Sequential(
181 | # RandomApply(
182 | # T.ColorJitter(0.8, 0.8, 0.8, 0.2),
183 | # p = 0.3
184 | # ),
185 | # T.RandomGrayscale(p=0.2),
186 | # T.RandomHorizontalFlip(),
187 | # RandomApply(
188 | # T.GaussianBlur((3, 3), (1.0, 2.0)),
189 | # p = 0.2
190 | # ),
191 | # T.RandomResizedCrop((image_size, image_size)),
192 | # T.Normalize(
193 | # mean=torch.tensor([0.485, 0.456, 0.406]),
194 | # std=torch.tensor([0.229, 0.224, 0.225])),
195 | # )
196 |
197 | # self.augment1 = default(augment_fn, DEFAULT_AUG)
198 | # self.augment2 = default(augment_fn2, self.augment1)
199 |
200 | self.online_encoder = NetWrapper(net.feature, projection_size, projection_hidden_size, layer=hidden_layer)
201 |
202 | self.use_momentum = use_momentum
203 | self.target_encoder = None
204 | self.target_ema_updater = EMA(moving_average_decay)
205 |
206 | self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
207 |
208 | # get device of network and make wrapper same device
209 | device = get_module_device(net.feature)
210 | self.to(device)
211 |
212 | # send a mock image tensor to instantiate singleton parameters
213 | self.forward(torch.randn(2, 3, image_size, image_size, device=device), torch.randn(2, 3, image_size, image_size, device=device))
214 |
215 | @singleton('target_encoder')
216 | def _get_target_encoder(self):
217 | target_encoder = copy.deepcopy(self.online_encoder)
218 | set_requires_grad(target_encoder, False)
219 | return target_encoder
220 |
221 | def reset_moving_average(self):
222 | del self.target_encoder
223 | self.target_encoder = None
224 |
225 | def update_moving_average(self):
226 | assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
227 | assert self.target_encoder is not None, 'target encoder has not been created yet'
228 | update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
229 |
230 | def forward(
231 | self,
232 | x1, x2,
233 | return_embedding = False,
234 | return_projection = True
235 | ):
236 | assert not (self.training and x1.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'
237 |
238 | if return_embedding:
239 | return self.online_encoder(torch.cat([x1, x2], 0), return_projection = return_projection)
240 |
241 | # image_one, image_two = self.augment1(x), self.augment2(x)
242 |
243 | online_proj_one, _ = self.online_encoder(x1)
244 | online_proj_two, _ = self.online_encoder(x2)
245 |
246 | online_pred_one = self.online_predictor(online_proj_one)
247 | online_pred_two = self.online_predictor(online_proj_two)
248 |
249 | with torch.no_grad():
250 | target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
251 | target_proj_one, _ = target_encoder(x1)
252 | target_proj_two, _ = target_encoder(x2)
253 | target_proj_one.detach_()
254 | target_proj_two.detach_()
255 |
256 | loss_one = loss_fn(online_pred_one, target_proj_two.detach())
257 | loss_two = loss_fn(online_pred_two, target_proj_one.detach())
258 |
259 | loss = loss_one + loss_two
260 | return loss.mean()
261 |
--------------------------------------------------------------------------------
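The EMA target update above is plain exponential averaging, ma <- beta * ma + (1 - beta) * online. A quick numeric sketch with the decay used here (beta = 0.99):

from methods.byol import EMA

ema = EMA(beta=0.99)
target, online = 1.0, 0.0
for _ in range(100):
    target = ema.update_average(target, online)
print(round(target, 3))  # ~0.366 (= 0.99 ** 100): the target slowly trails the online network
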
/methods/maml.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 |
11 | class MAML(MetaTemplate):
12 | def __init__(self, model_func, n_way, n_support, approx = False):
13 | super(MAML, self).__init__( model_func, n_way, n_support, change_way = False)
14 |
15 | self.loss_fn = nn.CrossEntropyLoss()
16 | self.classifier = backbone.Linear_fw(self.feat_dim, n_way)
17 | self.classifier.bias.data.fill_(0)
18 |
19 | self.n_task = 4
20 | self.task_update_num = 1
21 | self.train_lr = 0.5
22 | self.approx = approx #first order approx.
23 |
24 | def forward(self,x):
25 | out = self.feature.forward(x)
26 | scores = self.classifier.forward(out)
27 | return scores
28 |
29 | def set_forward(self, x, is_feature = False):
30 |         assert is_feature == False, 'MAML does not support fixed features'
31 |
32 | x = x.cuda()
33 | x_var = Variable(x)
34 | x_a_i = x_var[:,:self.n_support,:,:,:].contiguous().view( self.n_way* self.n_support, *x.size()[2:]) #support data
35 | x_b_i = x_var[:,self.n_support:,:,:,:].contiguous().view( self.n_way* self.n_query, *x.size()[2:]) #query data
36 | y_a_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_support ) )).cuda() #label for support data
37 |
38 |         fast_parameters = list(self.parameters()) #the first gradient calculated in line 45 is based on the original weights
39 | for weight in self.parameters():
40 | weight.fast = None
41 | self.zero_grad()
42 |
43 | for task_step in range(self.task_update_num):
44 | scores = self.forward(x_a_i)
45 | set_loss = self.loss_fn(scores, y_a_i)
46 | grad = torch.autograd.grad(set_loss, fast_parameters, create_graph=True) #build full graph support gradient of gradient
47 | if self.approx:
48 | grad = [ g.detach() for g in grad ] #do not calculate gradient of gradient if using first order approximation
49 | fast_parameters = []
50 | for k, weight in enumerate(self.parameters()):
51 | #for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py
52 | if weight.fast is None:
53 | weight.fast = weight - self.train_lr * grad[k] #create weight.fast
54 | else:
55 | weight.fast = weight.fast - self.train_lr * grad[k] #create an updated weight.fast, note the '-' is not merely minus value, but to create a new weight.fast
56 | fast_parameters.append(weight.fast) #gradients calculated in line 45 are based on newest fast weight, but the graph will retain the link to old weight.fasts
57 |
58 | scores = self.forward(x_b_i)
59 | return scores
60 |
61 |     def set_forward_adaptation(self,x, is_feature = False): #override parent function
62 |         raise ValueError('MAML performs further adaptation simply by increasing task_update_num')
63 |
64 |
65 | def set_forward_loss(self, x):
66 | scores = self.set_forward(x, is_feature = False)
67 | y_b_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_query ) )).cuda()
68 | loss = self.loss_fn(scores, y_b_i)
69 |
70 | return loss
71 |
72 |     def train_loop(self, epoch, train_loader, optimizer): #override parent function
73 | print_freq = 10
74 | avg_loss=0
75 | task_count = 0
76 | loss_all = []
77 | optimizer.zero_grad()
78 |
79 | #train
80 | for i, (x,_) in enumerate(train_loader):
81 | self.n_query = x.size(1) - self.n_support
82 |             assert self.n_way == x.size(0), "MAML does not support way change"
83 |
84 | loss = self.set_forward_loss(x)
85 | avg_loss = avg_loss+loss.item()#.data[0]
86 | loss_all.append(loss)
87 |
88 | task_count += 1
89 |
90 |             if task_count == self.n_task: #MAML updates several tasks at once
91 | loss_q = torch.stack(loss_all).sum(0)
92 | loss_q.backward()
93 |
94 | optimizer.step()
95 | task_count = 0
96 | loss_all = []
97 | optimizer.zero_grad()
98 | if i % print_freq==0:
99 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1)))
100 |
101 |     def test_loop(self, test_loader, return_std = False): #override parent function
102 | correct =0
103 | count = 0
104 | acc_all = []
105 |
106 | iter_num = len(test_loader)
107 | for i, (x,_) in enumerate(test_loader):
108 | self.n_query = x.size(1) - self.n_support
109 |             assert self.n_way == x.size(0), "MAML does not support way change"
110 | correct_this, count_this = self.correct(x)
111 | acc_all.append(correct_this/ count_this *100 )
112 |
113 | acc_all = np.asarray(acc_all)
114 | acc_mean = np.mean(acc_all)
115 | acc_std = np.std(acc_all)
116 | print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num)))
117 | if return_std:
118 | return acc_mean, acc_std
119 | else:
120 | return acc_mean
121 |
122 | def get_logits(self, x):
123 | self.n_query = x.size(1) - self.n_support
124 | logits = self.set_forward(x)
125 | return logits
--------------------------------------------------------------------------------
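One MAML inner-loop step on a toy scalar parameter, assuming nothing from this repository; it only illustrates the create_graph / first-order distinction used above (with self.approx, the gradients are detached and the second-order term vanishes, i.e., first-order MAML):

import torch

w = torch.tensor(1.0, requires_grad=True)   # meta-parameter
x, y = torch.tensor(2.0), torch.tensor(3.0)

inner_loss = (w * x - y) ** 2               # support loss
grad, = torch.autograd.grad(inner_loss, w, create_graph=True)
w_fast = w - 0.5 * grad                     # fast weight (train_lr = 0.5)

outer_loss = (w_fast * x - y) ** 2          # query loss
outer_loss.backward()                       # backprops through the inner update
print(w.grad)                               # tensor(-36.); grad.detach() would drop the d(grad)/dw term
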
/methods/meta_template.py:
--------------------------------------------------------------------------------
1 | import backbone
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | import utils
8 | from abc import abstractmethod
9 |
10 | class MetaTemplate(nn.Module):
11 | def __init__(self, model_func, n_way, n_support, change_way = True):
12 | super(MetaTemplate, self).__init__()
13 | self.n_way = n_way
14 | self.n_support = n_support
15 |         self.n_query = 15 #(if -1, the value depends on the input)
16 | self.feature = model_func
17 | self.feat_dim = self.feature.final_feat_dim
18 |         self.change_way = change_way #some methods allow different-way classification during training and testing
19 |
20 | @abstractmethod
21 | def set_forward(self,x,is_feature):
22 | pass
23 |
24 | @abstractmethod
25 | def set_forward_loss(self, x):
26 | pass
27 |
28 | def forward(self,x):
29 | out = self.feature.forward(x)
30 | return out
31 |
32 | def parse_feature(self,x,is_feature):
33 | x = Variable(x.cuda())
34 | if is_feature:
35 | z_all = x
36 | else:
37 | x = x.contiguous().view( self.n_way * (self.n_support + self.n_query), *x.size()[2:])
38 | z_all = self.feature.forward(x)
39 | z_all = z_all.view( self.n_way, self.n_support + self.n_query, -1)
40 | z_support = z_all[:, :self.n_support]
41 | z_query = z_all[:, self.n_support:]
42 |
43 | return z_support, z_query
44 |
45 | def correct(self, x):
46 | scores = self.set_forward(x)
47 | y_query = np.repeat(range( self.n_way ), self.n_query )
48 |
49 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
50 | topk_ind = topk_labels.cpu().numpy()
51 | top1_correct = np.sum(topk_ind[:,0] == y_query)
52 | return float(top1_correct), len(y_query)
53 |
54 | def train_loop(self, epoch, train_loader, optimizer):
55 | print_freq = 10
56 |
57 | avg_loss=0
58 | for i, (x,_ ) in enumerate(train_loader):
59 | self.n_query = x.size(1) - self.n_support
60 | if self.change_way:
61 | self.n_way = x.size(0)
62 | optimizer.zero_grad()
63 | loss = self.set_forward_loss( x )
64 | loss.backward()
65 | optimizer.step()
66 | avg_loss = avg_loss+loss.item()
67 |
68 | if i % print_freq==0:
69 | #print(optimizer.state_dict()['param_groups'][0]['lr'])
70 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1)))
71 |
72 | def test_loop(self, test_loader, record = None):
73 | correct =0
74 | count = 0
75 | acc_all = []
76 |
77 | iter_num = len(test_loader)
78 | for i, (x,_) in enumerate(test_loader):
79 | self.n_query = x.size(1) - self.n_support
80 | if self.change_way:
81 | self.n_way = x.size(0)
82 | correct_this, count_this = self.correct(x)
83 | acc_all.append(correct_this/ count_this*100 )
84 |
85 | acc_all = np.asarray(acc_all)
86 | acc_mean = np.mean(acc_all)
87 | acc_std = np.std(acc_all)
88 | print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num)))
89 |
90 | return acc_mean
91 |
92 |     def set_forward_adaptation(self, x, is_feature = True): #further adaptation: fix the features and train a new softmax classifier
93 | assert is_feature == True, 'Feature is fixed in further adaptation'
94 | z_support, z_query = self.parse_feature(x,is_feature)
95 |
96 | z_support = z_support.contiguous().view(self.n_way* self.n_support, -1 )
97 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 )
98 |
99 | y_support = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support ))
100 | y_support = Variable(y_support.cuda())
101 |
102 | linear_clf = nn.Linear(self.feat_dim, self.n_way)
103 | linear_clf = linear_clf.cuda()
104 |
105 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001)
106 |
107 | loss_function = nn.CrossEntropyLoss()
108 | loss_function = loss_function.cuda()
109 |
110 | batch_size = 4
111 | support_size = self.n_way* self.n_support
112 | for epoch in range(100):
113 | rand_id = np.random.permutation(support_size)
114 | for i in range(0, support_size , batch_size):
115 | set_optimizer.zero_grad()
116 | selected_id = torch.from_numpy( rand_id[i: min(i+batch_size, support_size) ]).cuda()
117 | z_batch = z_support[selected_id]
118 | y_batch = y_support[selected_id]
119 | scores = linear_clf(z_batch)
120 | loss = loss_function(scores,y_batch)
121 | loss.backward()
122 | set_optimizer.step()
123 |
124 | scores = linear_clf(z_query)
125 | return scores
126 |
--------------------------------------------------------------------------------
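A shape sketch of parse_feature with is_feature=False: an episode arrives as [n_way, n_support + n_query, C, H, W], is flattened for the backbone, and reshaped so support and query split along dim 1. The backbone output below is a random stand-in:

import torch

n_way, n_support, n_query, feat_dim = 5, 5, 15, 512
x = torch.randn(n_way, n_support + n_query, 3, 224, 224)
flat = x.view(n_way * (n_support + n_query), *x.size()[2:])  # [100, 3, 224, 224]
z_all = torch.randn(flat.size(0), feat_dim)                  # backbone stand-in: [100, 512]
z_all = z_all.view(n_way, n_support + n_query, -1)           # [5, 20, 512]
z_support, z_query = z_all[:, :n_support], z_all[:, n_support:]
print(z_support.shape, z_query.shape)  # [5, 5, 512] and [5, 15, 512]
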
/methods/protonet.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/jakesnell/prototypical-networks
2 |
3 | import backbone
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from methods.meta_template import MetaTemplate
10 |
11 | class ProtoNet(MetaTemplate):
12 | def __init__(self, model_func, n_way, n_support):
13 | super(ProtoNet, self).__init__( model_func, n_way, n_support)
14 | self.loss_fn = nn.CrossEntropyLoss()
15 |
16 |
17 | def set_forward(self, x, is_feature = False):
18 | z_support, z_query = self.parse_feature(x,is_feature)
19 |
20 | z_support = z_support.contiguous()
21 | z_proto = z_support.view(self.n_way, self.n_support, -1 ).mean(1) #the shape of z is [n_data, n_dim]
22 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 )
23 |
24 | dists = euclidean_dist(z_query, z_proto)
25 | scores = -dists
26 | return scores
27 |
28 |
29 | def set_forward_loss(self, x):
30 | y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query ))
31 | y_query = Variable(y_query.cuda())
32 |
33 | scores = self.set_forward(x)
34 |
35 | return self.loss_fn(scores, y_query )
36 |
37 |
38 | def euclidean_dist( x, y):
39 | # x: N x D
40 | # y: M x D
41 | n = x.size(0)
42 | m = y.size(0)
43 | d = x.size(1)
44 | assert d == y.size(1)
45 |
46 | x = x.unsqueeze(1).expand(n, m, d)
47 | y = y.unsqueeze(0).expand(n, m, d)
48 |
49 | return torch.pow(x - y, 2).sum(2)
50 |
--------------------------------------------------------------------------------
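A worked example of euclidean_dist: it broadcasts x (N x D queries) against y (M x D prototypes) and returns N x M squared distances; ProtoNet's scores are their negation:

import torch
from methods.protonet import euclidean_dist

x = torch.tensor([[0., 0.], [1., 1.]])  # 2 queries
y = torch.tensor([[0., 1.], [2., 2.]])  # 2 prototypes
print(euclidean_dist(x, y))
# tensor([[1., 8.],
#         [1., 2.]])
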
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from model.base import BaseModel
2 | from model.byol import BYOL
3 | from model.moco import MoCo
4 | from model.simclr import SimCLR
5 | from model.simsiam import SimSiam
6 |
7 | _model_class_map = {
8 | 'base': BaseModel,
9 | 'simclr': SimCLR,
10 | 'byol': BYOL,
11 | 'moco': MoCo,
12 | 'simsiam': SimSiam,
13 | }
14 |
15 |
16 | def get_model_class(key):
17 | if key in _model_class_map:
18 | return _model_class_map[key]
19 | else:
20 | raise ValueError('Invalid model: {}'.format(key))
21 |
--------------------------------------------------------------------------------
/model/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from argparse import Namespace
3 | from typing import Tuple
4 |
5 | import torch
6 | from torch import nn
7 |
8 |
9 | class BaseModel(nn.Module):
10 | """
11 |     BaseModel subclasses are self-contained: they include all modules and losses required for pre-training.
12 |
13 | - supported_feature_selectors: Feature selectors (see `forward_features()`) are used during fine-tuning
14 | to select which features (from which layer the features should be extracted) should be used for downstream
15 | tasks. This class attribute should be set for subclasses to prevent mistakes regarding the feature_selector
16 | argument (see `params.ft_features`).
17 | """
18 | supported_feature_selectors = []
19 |
20 | def __init__(self, backbone: nn.Module, params: Namespace):
21 | super().__init__()
22 | self.backbone = backbone
23 | self.params = params
24 | self.classifier = nn.Linear(backbone.final_feat_dim, params.num_classes)
25 | self.classifier.bias.data.fill_(0)
26 | self.cls_loss_function = nn.CrossEntropyLoss()
27 | self.final_feat_dim = backbone.final_feat_dim
28 |
29 | def forward_features(self, x, feature_selector: str = None):
30 | """
31 | You'll likely need to override this method for SSL models.
32 | """
33 | return self.backbone(x)
34 |
35 | def forward(self, x):
36 | x = self.backbone(x)
37 | x = self.classifier(x)
38 | return x
39 |
40 | def compute_cls_loss_and_accuracy(self, x, y, return_predictions=False) -> Tuple:
41 | scores = self.forward(x)
42 | _, predicted = torch.max(scores.data, 1)
43 | accuracy = predicted.eq(y.data).cpu().sum() / x.shape[0]
44 | if return_predictions:
45 | return self.cls_loss_function(scores, y), accuracy, predicted
46 | else:
47 | return self.cls_loss_function(scores, y), accuracy
48 |
49 | def on_step_start(self):
50 | pass
51 |
52 | def on_step_end(self):
53 | pass
54 |
55 | def on_epoch_start(self):
56 | pass
57 |
58 | def on_epoch_end(self):
59 | pass
60 |
61 |
62 | class BaseSelfSupervisedModel(BaseModel):
63 | @abstractmethod
64 | def compute_ssl_loss(self, x1, x2=None, return_features=False):
65 | """
66 | If SSL is based on paired input:
67 | By default: x1, x2 represent the input pair.
68 | If x2=None: x1 alone contains the full concatenated input pair.
69 | Else:
70 | x1 contains the input.
71 | """
72 | raise NotImplementedError()
73 |
--------------------------------------------------------------------------------
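A hypothetical subclass illustrating the compute_ssl_loss contract documented above, in particular the convention that x2=None means x1 holds the concatenated pair; the MSE objective is a placeholder, not a real SSL loss:

import torch
import torch.nn.functional as F
from model.base import BaseSelfSupervisedModel

class ToySSL(BaseSelfSupervisedModel):
    def compute_ssl_loss(self, x1, x2=None, return_features=False):
        concatenated = x2 is None
        if concatenated:                  # x1 holds both views, stacked on dim 0
            x1, x2 = x1.chunk(2, dim=0)
        f1, f2 = self.backbone(x1), self.backbone(x2)
        loss = F.mse_loss(f1, f2)         # placeholder objective
        if return_features:
            return (loss, torch.cat([f1, f2])) if concatenated else (loss, f1, f2)
        return loss
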
/model/byol.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from argparse import Namespace
4 | from functools import wraps
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 | from model.base import BaseSelfSupervisedModel
11 |
12 |
13 | def _singleton(cache_key):
14 | def inner_fn(fn):
15 | @wraps(fn)
16 | def wrapper(self, *args, **kwargs):
17 | instance = getattr(self, cache_key)
18 | if instance is not None:
19 | return instance
20 |
21 | instance = fn(self, *args, **kwargs)
22 | setattr(self, cache_key, instance)
23 | return instance
24 |
25 | return wrapper
26 |
27 | return inner_fn
28 |
29 |
30 | def _get_module_device(module):
31 | return next(module.parameters()).device
32 |
33 |
34 | def _set_requires_grad(model, val):
35 | for p in model.parameters():
36 | p.requires_grad = val
37 |
38 |
39 | def _loss_fn(x, y):
40 | x = F.normalize(x, dim=-1, p=2)
41 | y = F.normalize(y, dim=-1, p=2)
42 | return 2 - 2 * (x * y).sum(dim=-1)
43 |
44 |
45 | class RandomApply(nn.Module):
46 | def __init__(self, fn, p):
47 | super().__init__()
48 | self.fn = fn
49 | self.p = p
50 |
51 | def forward(self, x):
52 | if random.random() > self.p:
53 | return x
54 | return self.fn(x)
55 |
56 |
57 | class EMA:
58 | def __init__(self, beta):
59 | super().__init__()
60 | self.beta = beta
61 |
62 | def update_average(self, old, new):
63 | if old is None:
64 | return new
65 | return old * self.beta + (1 - self.beta) * new
66 |
67 |
68 | def _update_moving_average(ema_updater, ma_model, current_model):
69 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
70 | old_weight, up_weight = ma_params.data, current_params.data
71 | ma_params.data = ema_updater.update_average(old_weight, up_weight)
72 |
73 |
74 | class MLP(nn.Module):
75 | def __init__(self, dim, projection_size, hidden_size=4096):
76 | super().__init__()
77 | self.net = nn.Sequential(
78 | nn.Linear(dim, hidden_size),
79 | nn.BatchNorm1d(hidden_size),
80 | nn.ReLU(inplace=True),
81 | nn.Linear(hidden_size, projection_size)
82 | )
83 |
84 | def forward(self, x):
85 | return self.net(x)
86 |
87 |
88 | class NetWrapper(nn.Module):
89 | def __init__(self, net, projection_size, projection_hidden_size,
90 |                  layer=-1):  # upstream default is layer = -2 since that network ends with a classifier; ours has none, so we use -1
91 | super().__init__()
92 | self.net = net
93 | self.layer = layer
94 |
95 | self.projector = None
96 | self.projection_size = projection_size
97 | self.projection_hidden_size = projection_hidden_size
98 |
99 | self.hidden = {}
100 | self.hook_registered = False
101 |
102 | def _find_layer(self):
103 | if type(self.layer) == str:
104 | modules = dict([*self.net.named_modules()])
105 | return modules.get(self.layer, None)
106 | elif type(self.layer) == int:
107 | children = [*self.net.children()]
108 | return children[self.layer]
109 | return None
110 |
111 | def _hook(self, _, input, output):
112 | device = input[0].device
113 | self.hidden[device] = output.reshape(output.shape[0], -1) # flatten
114 |
115 | def _register_hook(self):
116 | layer = self._find_layer()
117 | assert layer is not None, f'hidden layer ({self.layer}) not found'
118 | handle = layer.register_forward_hook(self._hook)
119 | self.hook_registered = True
120 |
121 | @_singleton('projector')
122 | def _get_projector(self, hidden):
123 | _, dim = hidden.shape
124 | projector = MLP(dim, self.projection_size, self.projection_hidden_size)
125 | return projector.to(hidden)
126 |
127 | def get_representation(self, x):
128 | if self.layer == -1:
129 | return self.net(x)
130 |
131 | if not self.hook_registered:
132 | self._register_hook()
133 |
134 | self.hidden.clear()
135 | _ = self.net(x)
136 | hidden = self.hidden[x.device]
137 | self.hidden.clear()
138 |
139 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
140 | return hidden
141 |
142 | def forward(self, x, return_projection=True):
143 | representation = self.get_representation(x)
144 |
145 | if not return_projection:
146 | return representation
147 |
148 | projector = self._get_projector(representation)
149 | projection = projector(representation)
150 | return projection, representation
151 |
152 |
153 | class BYOL(BaseSelfSupervisedModel):
154 | def __init__(self, backbone: nn.Module, params: Namespace, use_momentum=True):
155 | super().__init__(backbone, params)
156 |
157 | image_size = 224
158 | hidden_layer = -1
159 | projection_size = 256
160 | projection_hidden_size = 4096
161 | moving_average_decay = 0.99
162 |         # use_momentum is taken directly from the constructor argument
163 |
164 | self.online_encoder = NetWrapper(self.backbone, projection_size, projection_hidden_size, layer=hidden_layer)
165 |
166 | self.use_momentum = use_momentum
167 | self.target_encoder = None
168 | self.target_ema_updater = EMA(moving_average_decay)
169 |
170 | self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
171 |
172 | # get device of network and make wrapper same device
173 | device = _get_module_device(backbone)
174 | self.to(device)
175 |
176 | # send a mock image tensor to instantiate singleton parameters
177 | self.compute_ssl_loss(torch.randn(2, 3, image_size, image_size, device=device),
178 | torch.randn(2, 3, image_size, image_size, device=device))
179 |
180 | @_singleton('target_encoder')
181 | def _get_target_encoder(self):
182 | target_encoder = copy.deepcopy(self.online_encoder)
183 | _set_requires_grad(target_encoder, False)
184 | return target_encoder
185 |
186 | def _reset_moving_average(self):
187 | del self.target_encoder
188 | self.target_encoder = None
189 |
190 | def _update_moving_average(self):
191 | assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
192 | assert self.target_encoder is not None, 'target encoder has not been created yet'
193 | _update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
194 |
195 |     def compute_ssl_loss(self, x1, x2=None, return_features=False):
196 |         single_input = x2 is None  # remember the input mode; x2 is reassigned just below
197 |         if single_input:
198 |             x = x1
199 |             batch_size = x.shape[0] // 2
200 |             x1, x2 = x[:batch_size], x[batch_size:]
201 |
202 |         assert not (self.training and x1.shape[0] == 1), \
203 |             'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'
204 |
205 | online_proj_one, _ = self.online_encoder(x1)
206 | online_proj_two, _ = self.online_encoder(x2)
207 |
208 | online_pred_one = self.online_predictor(online_proj_one)
209 | online_pred_two = self.online_predictor(online_proj_two)
210 |
211 | with torch.no_grad():
212 | target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
213 | target_proj_one, _ = target_encoder(x1)
214 | target_proj_two, _ = target_encoder(x2)
215 | target_proj_one.detach_()
216 | target_proj_two.detach_()
217 |
218 |         loss_one = _loss_fn(online_pred_one, target_proj_two)  # targets already detached above
219 |         loss_two = _loss_fn(online_pred_two, target_proj_one)
220 |
221 | loss = loss_one + loss_two
222 | loss = loss.mean()
223 |
224 |         if return_features:
225 |             if single_input:  # note: x2 was reassigned above, so `x2 is None` would always be False here
226 |                 return loss, torch.cat([online_proj_one, online_proj_two])
227 |             else:
228 |                 return loss, online_proj_one, online_proj_two
229 | else:
230 | return loss
231 |
232 | def on_step_end(self):
233 | if self.use_momentum:
234 | self._update_moving_average()
235 |
236 |
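237 | # Illustrative usage sketch, kept in comments because it depends on the unseen
238 | # BaseSelfSupervisedModel internals; `get_backbone_class` and `params` are the
239 | # objects used in pretrain.py, not new API.
240 | #
241 | #   backbone = get_backbone_class('resnet18')()      # any module producing (N, D) features
242 | #   model = BYOL(backbone, params)                   # __init__ runs a mock forward to build the projectors
243 | #   loss = model.compute_ssl_loss(view1, view2)      # two augmented views of the same batch
244 | #   loss.backward(); optimizer.step(); model.on_step_end()   # EMA-update the target encoder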
--------------------------------------------------------------------------------
/model/classifier_head.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class LinearClassifier(nn.Module):
5 |     def __init__(self, in_dim, out_dim, params):  # params is unused but kept for a uniform head constructor signature
6 | super().__init__()
7 | self.fc = nn.Linear(in_dim, out_dim)
8 |
9 | def forward(self, x):
10 | x = self.fc(x)
11 | return x
12 |
13 |
14 | class TwoLayerMLPClassifier(nn.Module):
15 | def __init__(self, in_dim, out_dim, params):
16 | super().__init__()
17 | self.fc1 = nn.Linear(in_dim, in_dim)
18 | self.relu = nn.ReLU()
19 | self.fc2 = nn.Linear(in_dim, out_dim)
20 |
21 | def forward(self, x):
22 | x = self.fc1(x)
23 | x = self.relu(x)
24 | x = self.fc2(x)
25 | return x
26 |
27 |
28 | CLASSIFIER_HEAD_CLASS_MAP = {
29 | 'linear': LinearClassifier,
30 | 'two_layer_mlp': TwoLayerMLPClassifier,
31 | }
32 |
33 | def get_classifier_head_class(key):
34 | if key in CLASSIFIER_HEAD_CLASS_MAP:
35 | return CLASSIFIER_HEAD_CLASS_MAP[key]
36 | else:
37 | raise ValueError('Invalid classifier head specifier: {}'.format(key))
38 |
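39 |
40 | if __name__ == '__main__':  # manual smoke test (illustrative), following the convention in scheduler.py
41 |     import torch
42 |     for key in CLASSIFIER_HEAD_CLASS_MAP:
43 |         head = get_classifier_head_class(key)(in_dim=512, out_dim=5, params=None)
44 |         assert head(torch.randn(2, 512)).shape == (2, 5)
45 |     print('Classifier head smoke test passed!')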
--------------------------------------------------------------------------------
/model/moco.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from argparse import Namespace
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from model.base import BaseSelfSupervisedModel
9 |
10 |
11 | class MoCo(BaseSelfSupervisedModel):
12 | def __init__(self, backbone: nn.Module, params: Namespace):
13 | super().__init__(backbone, params)
14 |
15 |         dim = 128  # projection dimension
16 |         mlp = False  # set True for a MoCo v2-style MLP projector
17 |         self.K = 1024  # queue size (number of negative keys)
18 |         self.m = 0.999  # key-encoder EMA momentum
19 |         self.T = 1.0  # softmax temperature (note: the MoCo paper uses 0.07)
20 |
21 | self.encoder_q = self.backbone
22 | self.encoder_k = copy.deepcopy(self.backbone)
23 |
24 | if not mlp:
25 | self.projector_q = nn.Linear(self.encoder_q.final_feat_dim, dim)
26 | self.projector_k = nn.Linear(self.encoder_k.final_feat_dim, dim)
27 | else:
28 |             mlp_dim = self.encoder_q.final_feat_dim  # fixed: the non-MLP branch reads final_feat_dim directly off the encoder
29 | self.projector_q = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
30 | self.projector_k = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
31 |
32 | self.encoder_k.requires_grad_(False)
33 | self.projector_k.requires_grad_(False)
34 | # Just in case (copied from old code)
35 | for param_k in self.encoder_k.parameters():
36 | param_k.requires_grad = False
37 | for param_k in self.projector_k.parameters():
38 | param_k.requires_grad = False
39 |
40 | self.register_buffer("queue", torch.randn(dim, self.K))
41 | self.queue = F.normalize(self.queue, dim=0)
42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
43 |
44 | self.ce_loss = nn.CrossEntropyLoss()
45 |
46 | @torch.no_grad()
47 | def _momentum_update_key_encoder(self):
48 | """
49 | Momentum update of the key encoder
50 | """
51 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
52 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
53 | for param_q_, param_k_ in zip(self.projector_q.parameters(), self.projector_k.parameters()):
54 | param_k_.data = param_k_.data * self.m + param_q_.data * (1. - self.m)
55 |
56 | @torch.no_grad()
57 | def _dequeue_and_enqueue(self, keys):
58 | batch_size = keys.shape[0]
59 | ptr = int(self.queue_ptr)
60 | assert self.K % batch_size == 0 # for simplicity
61 |
62 | # replace the keys at ptr (dequeue and enqueue)
63 | self.queue[:, ptr:ptr + batch_size] = keys.T
64 | ptr = (ptr + batch_size) % self.K # move pointer
65 |
66 | self.queue_ptr[0] = ptr
67 |
68 | def compute_ssl_loss(self, x1, x2=None, return_features=False):
69 | if x2 is None:
70 | x = x1
71 |             batch_size = x.shape[0] // 2
72 | im_q = x[:batch_size]
73 | im_k = x[batch_size:]
74 | else:
75 | im_q = x1
76 | im_k = x2
77 |
78 | q_features = self.encoder_q(im_q)
79 | q = self.projector_q(q_features) # queries: NxC
80 | q = F.normalize(q, dim=1)
81 |
82 | # compute key features
83 | with torch.no_grad(): # no gradient to keys
84 | self._momentum_update_key_encoder() # update the key encoder
85 |
86 | k_features = self.encoder_k(im_k)
87 | k = self.projector_k(k_features) # keys: NxC
88 | k = F.normalize(k, dim=1)
89 |
90 | # compute logits (Einstein sum is more intuitive)
91 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) # positive logits: Nx1
92 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) # negative logits: NxK
93 |
94 | logits = torch.cat([l_pos, l_neg], dim=1) # logits: Nx(1+K)
95 | logits /= self.T # apply temperature
96 |         labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)  # the positive key is always at index 0
97 |
98 | self._dequeue_and_enqueue(k)
99 |
100 | loss = self.ce_loss(logits, labels)
101 |
102 | if return_features:
103 | if x2 is None:
104 | return loss, torch.cat([q_features, k_features])
105 | else:
106 | return loss, q_features, k_features
107 | else:
108 | return loss
109 |
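110 |
111 | # Illustrative usage sketch (comments only; assumes a CUDA device, since the
112 | # labels above are created on GPU, and a backbone exposing `final_feat_dim`):
113 | #
114 | #   model = MoCo(get_backbone_class('resnet18')(), params).cuda()
115 | #   loss = model.compute_ssl_loss(view1, view2)   # queries from view1, keys from view2
116 | #   loss.backward()   # only the query side gets gradients; the key encoder follows
117 | #                     # via EMA, and the queue keeps the last K keys as negatives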
--------------------------------------------------------------------------------
/model/simclr.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | from functools import lru_cache
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 |
8 | from model.base import BaseSelfSupervisedModel
9 |
10 |
11 | class ProjectionHead(nn.Module):
12 | def __init__(self, in_dim, out_dim):
13 | super(ProjectionHead, self).__init__()
14 | self.in_dim = in_dim
15 | self.out_dim = out_dim
16 |
17 | self.fc1 = nn.Linear(in_dim, in_dim)
18 | self.relu = nn.ReLU()
19 | self.fc2 = nn.Linear(in_dim, out_dim)
20 |
21 | def forward(self, x):
22 | return self.fc2(self.relu(self.fc1(x)))
23 |
24 |
25 | class NTXentLoss(nn.Module):
26 | def __init__(self, temperature, use_cosine_similarity):
27 | super(NTXentLoss, self).__init__()
28 | self.temperature = temperature
29 | self.softmax = torch.nn.Softmax(dim=-1)
30 | self.similarity_function = self._get_similarity_function(use_cosine_similarity)
31 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")
32 |
33 |     def _get_similarity_function(self, use_cosine_similarity):
34 |         if use_cosine_similarity:
35 |             self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
36 |             return self._cosine_similarity_fn
37 |         else:
38 |             return self._dot_similarity_fn
39 |
40 |     @lru_cache(maxsize=4)
41 |     def _get_correlated_mask(self, batch_size):
42 |         # mask out the diagonal (self-similarity) and the two off-diagonals at
43 |         # ±batch_size (the positive pairs), leaving only the negatives
44 |         diag = np.eye(2 * batch_size)
45 |         l1 = np.eye(2 * batch_size, 2 * batch_size, k=-batch_size)
46 |         l2 = np.eye(2 * batch_size, 2 * batch_size, k=batch_size)
47 |         mask = torch.from_numpy(diag + l1 + l2)
48 |         mask = (1 - mask).type(torch.bool)
49 |         return mask
50 |
51 |     @staticmethod
52 |     def _dot_similarity_fn(x, y):
53 |         v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
54 |         # x shape: (N, 1, C)
55 |         # y shape: (1, C, 2N)
56 |         # v shape: (N, 2N)
57 |         return v
58 |
59 |     def _cosine_similarity_fn(self, x, y):
60 | # x shape: (N, 1, C)
61 | # y shape: (1, 2N, C)
62 | # v shape: (N, 2N)
63 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
64 | return v
65 |
66 | def forward(self, zis, zjs):
67 | batch_size = zis.shape[0]
68 | representations = torch.cat([zjs, zis], dim=0)
69 | device = representations.device
70 |
71 | similarity_matrix = self.similarity_function(
72 | representations, representations)
73 |
74 | # filter out the scores from the positive samples
75 | l_pos = torch.diag(similarity_matrix, batch_size)
76 | r_pos = torch.diag(similarity_matrix, -batch_size)
77 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)
78 |
79 | mask = self._get_correlated_mask(batch_size).to(device)
80 | negatives = similarity_matrix[mask].view(2 * batch_size, -1)
81 |
82 | logits = torch.cat((positives, negatives), dim=1)
83 | logits /= self.temperature
84 |
85 | labels = torch.zeros(2 * batch_size).to(device).long()
86 | loss = self.criterion(logits, labels)
87 |
88 | return loss / (2 * batch_size)
89 |
90 |
91 | class SimCLR(BaseSelfSupervisedModel):
92 |
93 | def __init__(self, backbone: nn.Module, params: Namespace):
94 | super().__init__(backbone, params)
95 | self.head = ProjectionHead(backbone.final_feat_dim, out_dim=params.model_simclr_projection_dim)
96 | self.ssl_loss_fn = NTXentLoss(temperature=params.model_simclr_temperature, use_cosine_similarity=True)
97 | self.final_feat_dim = self.backbone.final_feat_dim
98 |
99 | def compute_ssl_loss(self, x1, x2=None, return_features=False):
100 | if x2 is None:
101 | x = x1
102 | else:
103 | x = torch.cat([x1, x2])
104 |         batch_size = x.shape[0] // 2
105 |
106 | f = self.backbone(x)
107 | f1, f2 = f[:batch_size], f[batch_size:]
108 | p1 = self.head(f1)
109 | p2 = self.head(f2)
110 | loss = self.ssl_loss_fn(p1, p2)
111 |
112 | if return_features:
113 | if x2 is None:
114 | return loss, f
115 | else:
116 | return loss, f1, f2
117 | else:
118 | return loss
119 |
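120 |
121 | if __name__ == '__main__':  # manual smoke test for NTXentLoss (illustrative)
122 |     zis, zjs = torch.randn(8, 64), torch.randn(8, 64)
123 |     loss = NTXentLoss(temperature=0.5, use_cosine_similarity=True)(zis, zjs)
124 |     assert loss.dim() == 0 and bool(torch.isfinite(loss))
125 |     print('NTXentLoss smoke test passed! loss = {:.4f}'.format(loss.item()))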
--------------------------------------------------------------------------------
/model/simsiam.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 |
3 | from torch import nn
4 |
5 | from model import BYOL
6 |
7 |
8 | class SimSiam(BYOL):
9 | def __init__(self, backbone: nn.Module, params: Namespace):
10 | super().__init__(backbone, params, use_momentum=False)
11 |
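12 |
13 | # With use_momentum=False, BYOL.compute_ssl_loss computes targets from the online
14 | # encoder itself under torch.no_grad(), i.e. SimSiam's stop-gradient branch, and
15 | # on_step_end becomes a no-op (there is no EMA target network to update).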
--------------------------------------------------------------------------------
/paths.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import re
4 | from argparse import Namespace
5 |
6 | import configs
7 |
8 | DATASET_KEYS = {
9 | 'miniImageNet': 'mini',
10 | 'miniImageNet_test': 'mini_test',
11 | 'tieredImageNet': 'tiered',
12 | 'tieredImageNet_test': 'tiered_test',
13 | 'ImageNet': 'imagenet',
14 | 'CropDisease': 'crop',
15 | 'EuroSAT': 'euro',
16 | 'ISIC': 'isic',
17 | 'ChestX': 'chest',
18 | 'cars': 'cars',
19 | 'cub': 'cub',
20 | 'places': 'places',
21 | 'plantae': 'plantae',
22 | }
23 |
24 | BACKBONE_KEYS = {
25 | 'resnet10': 'resnet10',
26 | 'resnet18': 'resnet18',
27 | 'resnet50': 'resnet50',
28 | }
29 |
30 | MODEL_KEYS = {
31 | 'base': 'base',
32 | 'simclr': 'simclr',
33 | 'simsiam': 'simsiam',
34 | 'moco': 'moco',
35 | 'swav': 'swav',
36 | 'byol': 'byol',
37 | }
38 |
39 |
40 | def get_output_directory(params: Namespace, previous=False, makedirs=True):
41 |     """
42 |     :param params: parsed experiment arguments
43 |     :param previous: get the previous-stage output directory for the pls, put, and pmsl modes
44 |     :return: the output directory path (created on disk unless makedirs=False)
45 |     """
46 | if previous and not (params.pls or params.put or params.pmsl):
47 | raise ValueError('Invalid arguments for previous=True')
48 |
49 | path = configs.save_dir
50 | path = os.path.join(path, 'output')
51 | path = os.path.join(path, DATASET_KEYS[params.source_dataset])
52 |
53 | pretrain_specifiers = [BACKBONE_KEYS[params.backbone]]
54 | if previous:
55 | if params.pls:
56 | pretrain_specifiers.append(MODEL_KEYS['base'])
57 | else:
58 | pretrain_specifiers.append(MODEL_KEYS[params.model])
59 |
60 | if params.pls:
61 | pretrain_specifiers.append('LS')
62 | elif params.put:
63 | pretrain_specifiers.append('UT')
64 | elif params.pmsl:
65 | pretrain_specifiers.append('LS_UT')
66 | else:
67 | raise AssertionError("Invalid parameters")
68 | pretrain_specifiers.append(params.previous_tag)
69 | else:
70 | pretrain_specifiers.append(MODEL_KEYS[params.model])
71 | if params.pls:
72 | pretrain_specifiers.append('PLS')
73 | if params.put:
74 | pretrain_specifiers.append('PUT')
75 | if params.pmsl:
76 | pretrain_specifiers.append('PMSL')
77 | if params.ls:
78 | pretrain_specifiers.append('LS')
79 | if params.us:
80 | pretrain_specifiers.append('US')
81 | if params.ut:
82 | pretrain_specifiers.append('UT')
83 | pretrain_specifiers.append(params.tag)
84 |
85 | path = os.path.join(path, '_'.join(pretrain_specifiers))
86 |
87 | if previous:
88 | if params.put or params.pmsl:
89 | path = os.path.join(path, DATASET_KEYS[params.target_dataset])
90 | else:
91 | if params.put or params.pmsl or params.ut:
92 | path = os.path.join(path, DATASET_KEYS[params.target_dataset])
93 |
94 | if makedirs:
95 | os.makedirs(path, exist_ok=True)
96 |
97 | return path
98 |
99 |
100 | def get_pretrain_history_path(output_directory):
101 | basename = 'pretrain_history.csv'
102 | return os.path.join(output_directory, basename)
103 |
104 |
105 | def get_pretrain_state_path(output_directory, epoch=0):
106 | """
107 | :param output_directory:
108 | :param epoch: Number of completed epochs. I.e., 0 = initial.
109 | :return:
110 | """
111 | basename = 'pretrain_state_{:04d}.pt'.format(epoch)
112 | return os.path.join(output_directory, basename)
113 |
114 |
115 | def get_final_pretrain_state_path(output_directory):
116 | glob_pattern = os.path.join(output_directory, 'pretrain_state_*.pt')
117 | paths = glob.glob(glob_pattern)
118 |
119 |     pattern = re.compile(r'pretrain_state_(\d{4})\.pt')
120 | paths_by_epoch = dict()
121 | for path in paths:
122 | match = pattern.search(path)
123 | if match:
124 | paths_by_epoch[match.group(1)] = path
125 |
126 | if len(paths_by_epoch) == 0:
127 | raise FileNotFoundError('Could not find valid pre-train state file in {}'.format(output_directory))
128 |
129 |     max_epoch = max(paths_by_epoch.keys())  # zero-padded 4-digit strings, so lexicographic max == numeric max
130 | return paths_by_epoch[max_epoch]
131 |
132 |
133 | def get_pretrain_params_path(output_directory):
134 | return os.path.join(output_directory, 'pretrain_params.json')
135 |
136 |
137 | def get_ft_output_directory(params, makedirs=True):
138 | path = get_output_directory(params, makedirs=makedirs)
139 | if not params.ut:
140 | path = os.path.join(path, params.target_dataset)
141 | ft_basename = '{:02d}way_{:03d}shot_{}_{}'.format(params.n_way, params.n_shot, params.ft_parts, params.ft_tag)
142 | path = os.path.join(path, ft_basename)
143 |
144 | if makedirs:
145 | os.makedirs(path, exist_ok=True)
146 |
147 | return path
148 |
149 |
150 | def get_ft_params_path(output_directory):
151 | return os.path.join(output_directory, 'params.json')
152 |
153 |
154 | def get_ft_train_history_path(output_directory):
155 | return os.path.join(output_directory, 'train_history.csv')
156 |
157 |
158 | def get_ft_test_history_path(output_directory):
159 | return os.path.join(output_directory, 'test_history.csv')
160 |
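161 |
162 | # Example of the resulting layout (illustrative; assumes configs.save_dir == './logs',
163 | # source miniImageNet, target CropDisease, resnet10, simclr, --ut, tag 'default'):
164 | #
165 | #   ./logs/output/mini/resnet10_simclr_UT_default/crop/
166 | #       pretrain_params.json
167 | #       pretrain_history.csv
168 | #       pretrain_state_0000.pt, pretrain_state_0010.pt, ...
169 | #       05way_005shot_<ft_parts>_<ft_tag>/   <- from get_ft_output_directory(...)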
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import torch.optim
8 | from tqdm import tqdm
9 |
10 | from backbone import get_backbone_class
11 | from datasets.dataloader import get_dataloader, get_unlabeled_dataloader
12 | from io_utils import parse_args
13 | from model import get_model_class
14 | from paths import get_output_directory, get_final_pretrain_state_path, get_pretrain_state_path, \
15 | get_pretrain_params_path, get_pretrain_history_path
16 | from scheduler import RepeatedMultiStepLR
17 |
18 |
19 | def _get_dataloaders(params):
20 | labeled_source_bs = params.ls_batch_size
21 | batch_size = params.batch_size
22 | unlabeled_source_bs = batch_size
23 | unlabeled_target_bs = batch_size
24 |
25 | if params.us and params.ut:
26 | unlabeled_source_bs //= 2
27 | unlabeled_target_bs //= 2
28 |
29 | ls, us, ut = None, None, None
30 | if params.ls:
31 | print('Using source data {} (labeled)'.format(params.source_dataset))
32 | ls = get_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation,
33 | batch_size=labeled_source_bs, num_workers=params.num_workers)
34 |
35 | if params.us:
36 | print('Using source data {} (unlabeled)'.format(params.source_dataset))
37 | us = get_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation,
38 | batch_size=unlabeled_source_bs, num_workers=params.num_workers,
39 | siamese=True) # important
40 |
41 | if params.ut:
42 | print('Using target data {} (unlabeled)'.format(params.target_dataset))
43 | ut = get_unlabeled_dataloader(dataset_name=params.target_dataset, augmentation=params.augmentation,
44 | batch_size=unlabeled_target_bs, num_workers=params.num_workers, siamese=True,
45 | unlabeled_ratio=params.unlabeled_ratio)
46 |
47 | return ls, us, ut
48 |
49 |
50 | def main(params):
51 | backbone = get_backbone_class(params.backbone)()
52 | model = get_model_class(params.model)(backbone, params)
53 | output_dir = get_output_directory(params)
54 | labeled_source_loader, unlabeled_source_loader, unlabeled_target_loader = _get_dataloaders(params)
55 |
56 | params_path = get_pretrain_params_path(output_dir)
57 | with open(params_path, 'w') as f:
58 | json.dump(vars(params), f, indent=4)
59 | pretrain_history_path = get_pretrain_history_path(output_dir)
60 | print('Saving pretrain params to {}'.format(params_path))
61 | print('Saving pretrain history to {}'.format(pretrain_history_path))
62 |
63 | if params.pls or params.put or params.pmsl:
64 | # Load previous pre-trained weights for second-step pre-training
65 | previous_base_output_dir = get_output_directory(params, previous=True)
66 | state_path = get_final_pretrain_state_path(previous_base_output_dir)
67 | print('Loading previous state for second-step pre-training:')
68 | print(state_path)
69 |
70 | # Note, override model.load_state_dict to change this behavior.
71 | state = torch.load(state_path)
72 | # del state["classifier.weight"] # hotfixes for using weights pre-trained on different source datasets w/o LS
73 | # del state["classifier.bias"]
74 | missing, unexpected = model.load_state_dict(state, strict=False)
75 | if len(unexpected):
76 | raise Exception("Unexpected keys from previous state: {}".format(unexpected))
77 | elif params.imagenet_pretrained:
78 | print("Loading ImageNet pretrained weights")
79 | backbone.load_imagenet_weights()
80 |
81 | model.train()
82 | model.cuda()
83 |
84 | if params.optimizer == 'sgd':
85 | optimizer = torch.optim.SGD(model.parameters(),
86 | lr=params.lr, momentum=0.9,
87 | weight_decay=1e-4,
88 | nesterov=False)
89 | elif params.optimizer == 'adam':
90 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
91 | else:
92 | raise ValueError('Invalid value for params.optimizer: {}'.format(params.optimizer))
93 |
94 | if params.scheduler == "MultiStepLR":
95 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params.scheduler_milestones, gamma=0.1)
96 | elif params.scheduler == "RepeatedMultiStepLR":
97 | scheduler = RepeatedMultiStepLR(optimizer, milestones=params.scheduler_milestones, interval=1000, gamma=0.1)
98 | else:
99 | raise ValueError("Invalid value for params.scheduler: {}".format(params.scheduler))
100 |
101 | pretrain_history = {
102 | 'loss': [0] * params.epochs,
103 | 'source_loss': [0] * params.epochs,
104 | 'target_loss': [0] * params.epochs,
105 | }
106 |
107 | for epoch in range(params.epochs):
108 | print('EPOCH {}'.format(epoch).center(40).center(80, '#'))
109 |
110 | epoch_loss = 0
111 | epoch_source_loss = 0
112 | epoch_target_loss = 0
113 | steps = 0
114 |
115 | if epoch == 0:
116 | state_path = get_pretrain_state_path(output_dir, epoch=0)
117 | print('Saving pre-train state to:')
118 | print(state_path)
119 | torch.save(model.state_dict(), state_path)
120 |
121 | model.on_epoch_start()
122 | model.train()
123 |
124 | if params.ls and not params.us and not params.ut: # only ls (type 1)
125 | for x, y in tqdm(labeled_source_loader):
126 | model.on_step_start()
127 | optimizer.zero_grad()
128 | loss, _ = model.compute_cls_loss_and_accuracy(x.cuda(), y.cuda())
129 | loss.backward()
130 | optimizer.step()
131 | model.on_step_end()
132 |
133 | epoch_loss += loss.item()
134 | epoch_source_loss += loss.item()
135 | steps += 1
136 | elif not params.ls and params.us and not params.ut: # only us (type 2)
137 | for x, _ in tqdm(unlabeled_source_loader):
138 | model.on_step_start()
139 | optimizer.zero_grad()
140 | loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda())
141 | loss.backward()
142 | optimizer.step()
143 | model.on_step_end()
144 |
145 | epoch_loss += loss.item()
146 | epoch_source_loss += loss.item()
147 | steps += 1
148 | elif params.ut: # ut (epoch is based on unlabeled target)
149 | max_epochs = params.epochs
150 | max_steps = len(unlabeled_target_loader) * max_epochs
151 | for i, (x, _) in enumerate(tqdm(unlabeled_target_loader)):
152 | current_step = epoch * len(unlabeled_target_loader) + i
153 | model.on_step_start()
154 | optimizer.zero_grad()
155 | target_loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda()) # UT loss
156 | epoch_target_loss += target_loss.item()
157 | source_loss = None
158 | if params.ls: # type 4, 7
159 |                     try:  # NameError on the first step creates the iterator; StopIteration restarts an exhausted one
160 |                         sx, sy = next(labeled_source_loader_iter)
161 |                     except (StopIteration, NameError):
162 |                         labeled_source_loader_iter = iter(labeled_source_loader)
163 |                         sx, sy = next(labeled_source_loader_iter)
164 | source_loss = model.compute_cls_loss_and_accuracy(sx.cuda(), sy.cuda())[0] # LS loss
165 | epoch_source_loss += source_loss.item()
166 | if params.us: # type 5, 8
167 |                     try:
168 |                         sx, sy = next(unlabeled_source_loader_iter)
169 |                     except (StopIteration, NameError):
170 |                         unlabeled_source_loader_iter = iter(unlabeled_source_loader)
171 |                         sx, sy = next(unlabeled_source_loader_iter)
172 | source_loss = model.compute_ssl_loss(sx[0].cuda(), sx[1].cuda()) # US loss
173 | epoch_source_loss += source_loss.item()
174 |                 if source_loss is not None:  # truthiness would misfire on a zero-valued loss tensor
175 | if params.gamma_schedule is None:
176 | gamma = params.gamma
177 | elif params.gamma_schedule == "linear":
178 | gamma = current_step / (max_steps - 1) # gamma \in [0, 1]
179 |                         assert 0 <= gamma <= 1  # sanity check on the linear schedule
180 | else:
181 | raise AssertionError("Invalid params.gamma_schedule (should be checked during argparse)")
182 | loss = source_loss * (1 - gamma) + target_loss * gamma
183 | else:
184 | loss = target_loss
185 | loss.backward()
186 | optimizer.step()
187 | model.on_step_end()
188 |
189 | epoch_loss += loss.item()
190 | steps += 1
191 | else:
192 | raise AssertionError('Unknown training combination.')
193 |
194 | if scheduler is not None:
195 | scheduler.step()
196 | model.on_epoch_end()
197 |
198 | mean_loss = epoch_loss / steps
199 | mean_source_loss = epoch_source_loss / steps
200 | mean_target_loss = epoch_target_loss / steps
201 | fmt = 'Epoch {:04d}: loss={:6.4f} source_loss={:6.4f} target_loss={:6.4f}'
202 | print(fmt.format(epoch, mean_loss, mean_source_loss, mean_target_loss))
203 |
204 | pretrain_history['loss'][epoch] = mean_loss
205 | pretrain_history['source_loss'][epoch] = mean_source_loss
206 | pretrain_history['target_loss'][epoch] = mean_target_loss
207 |
208 | pd.DataFrame(pretrain_history).to_csv(pretrain_history_path)
209 |
210 |         epoch += 1  # switch to "completed epochs" for checkpoint naming below
211 | if epoch % params.model_save_interval == 0 or epoch == params.epochs:
212 | state_path = get_pretrain_state_path(output_dir, epoch=epoch)
213 | print('Saving pre-train state to:')
214 | print(state_path)
215 | torch.save(model.state_dict(), state_path)
216 |
217 |
218 | if __name__ == '__main__':
219 | np.random.seed(10)
220 | params = parse_args()
221 |
222 | targets = params.target_dataset
223 | if targets is None:
224 |         targets = [targets]  # a single None entry so main() runs once without a target dataset
225 | elif len(targets) > 1:
226 | print('#' * 80)
227 | print("Running pretrain iteratively for multiple target datasets: {}".format(targets))
228 | print('#' * 80)
229 |
230 | for target in targets:
231 | params.target_dataset = target
232 | main(params)
233 |
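234 |
235 | # MSL loss mixing above (types 4/5/7/8): loss = (1 - gamma) * source_loss + gamma * target_loss.
236 | # With --gamma_schedule linear, gamma = current_step / (max_steps - 1) ramps from 0
237 | # to 1 over training, gradually shifting weight from the source to the target task.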
--------------------------------------------------------------------------------
/pretrain.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 |
4 | SOURCES=("miniImageNet" "tieredImageNet" "ImageNet")
5 | SOURCE=${SOURCES[2]}
6 |
7 | TARGETS=("CropDisease" "ISIC" "EuroSAT" "ChestX" "places" "cub" "plantae" "cars")
8 | TARGET=${TARGETS[0]}
9 |
10 | # BACKBONE=resnet10 # for mini
11 | BACKBONE=resnet18 # for tiered and full imagenet
12 |
13 |
14 | # Source SL (note, we adapt the torchvision pre-trained model for ResNet18 + ImageNet. Do not use this command as-is.)
15 | python pretrain.py --ls --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "base" --tag "default"
16 |
17 | # Target SSL
18 | python pretrain.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default"
19 | python pretrain.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default"
20 |
21 | # MSL (Source SL + Target SSL)
22 | python pretrain.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78"
23 | python pretrain.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78"
24 |
25 | # Two-Stage SSL (Source SL -> Target SSL)
26 | python pretrain.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default" --previous_tag "default"
27 | python pretrain.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default" --previous_tag "default"
28 |
29 | # Two-Stage MSL (Source SL -> Source SL + Target SSL)
30 | python pretrain.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78" --previous_tag "default"
31 | python pretrain.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78" --previous_tag "default"
--------------------------------------------------------------------------------
/pretrain_new_lu20.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | import torch.optim
8 | from tqdm import tqdm
9 |
10 | from backbone import get_backbone_class
11 | from datasets.dataloader import get_dataloader, get_unlabeled_dataloader
12 | from io_utils import parse_args
13 | from model import get_model_class
14 | from paths import get_output_directory, get_final_pretrain_state_path, get_pretrain_state_path, \
15 | get_pretrain_params_path, get_pretrain_history_path
16 |
17 |
18 | def _get_dataloaders(params):
19 | batch_size = params.batch_size
20 | labeled_source_bs = batch_size
21 | unlabeled_source_bs = batch_size
22 | unlabeled_target_bs = batch_size
23 |
24 | if params.us and params.ut:
25 | unlabeled_source_bs //= 2
26 | unlabeled_target_bs //= 2
27 |
28 | ls, us, ut = None, None, None
29 | if params.ls:
30 | print('Using source data {} (labeled)'.format(params.source_dataset))
31 | ls = get_unlabeled_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation,
32 | batch_size=labeled_source_bs, siamese=False, unlabeled_ratio=params.unlabeled_ratio,
33 | num_workers=params.num_workers, split_seed=params.split_seed)
34 |
35 | if params.us:
36 |             raise NotImplementedError  # unlabeled-source mode is unsupported in this variant; the lines below are unreachable
37 | print('Using source data {} (unlabeled)'.format(params.source_dataset))
38 | us = get_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation,
39 | batch_size=unlabeled_source_bs, num_workers=params.num_workers,
40 | siamese=True) # important
41 |
42 | if params.ut:
43 | print('Using target data {} (unlabeled)'.format(params.target_dataset))
44 | ut = get_unlabeled_dataloader(dataset_name=params.target_dataset, augmentation=params.augmentation,
45 | batch_size=unlabeled_target_bs, num_workers=params.num_workers, siamese=True,
46 | unlabeled_ratio=params.unlabeled_ratio)
47 |
48 | return ls, us, ut
49 |
50 |
51 | def main(params):
52 | backbone = get_backbone_class(params.backbone)()
53 | model = get_model_class(params.model)(backbone, params)
54 | output_dir = get_output_directory(params)
55 | labeled_source_loader, unlabeled_source_loader, unlabeled_target_loader = _get_dataloaders(params)
56 |
57 | params_path = get_pretrain_params_path(output_dir)
58 | with open(params_path, 'w') as f:
59 | json.dump(vars(params), f, indent=4)
60 | pretrain_history_path = get_pretrain_history_path(output_dir)
61 | print('Saving pretrain params to {}'.format(params_path))
62 | print('Saving pretrain history to {}'.format(pretrain_history_path))
63 |
64 | if params.pls:
65 | # Load previous pre-trained weights for second-step pre-training
66 |         previous_base_output_dir = get_output_directory(params, previous=True)  # fixed: the keyword is `previous`, not `pls_previous`
67 | state_path = get_final_pretrain_state_path(previous_base_output_dir)
68 | print('Loading previous state for second-step pre-training:')
69 | print(state_path)
70 |
71 | # Note, override model.load_state_dict to change this behavior.
72 | state = torch.load(state_path)
73 | missing, unexpected = model.load_state_dict(state, strict=False)
74 | if len(unexpected):
75 | raise Exception("Unexpected keys from previous state: {}".format(unexpected))
76 |
77 | model.train()
78 | model.cuda()
79 |
80 | if params.optimizer == 'sgd':
81 | optimizer = torch.optim.SGD(model.parameters(),
82 | lr=params.lr, momentum=0.9,
83 | weight_decay=1e-4,
84 | nesterov=False)
85 | elif params.optimizer == 'adam':
86 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
87 | else:
88 | raise ValueError('Invalid value for params.optimizer: {}'.format(params.optimizer))
89 |
90 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
91 | milestones=[400, 600, 800],
92 | gamma=0.1)
93 |
94 | pretrain_history = {
95 | 'loss': [0] * params.epochs,
96 | 'source_loss': [0] * params.epochs,
97 | 'target_loss': [0] * params.epochs,
98 | }
99 |
100 | for epoch in range(params.epochs):
101 | print('EPOCH {}'.format(epoch).center(40).center(80, '#'))
102 |
103 | epoch_loss = 0
104 | epoch_source_loss = 0
105 | epoch_target_loss = 0
106 | steps = 0
107 |
108 | if epoch == 0:
109 | state_path = get_pretrain_state_path(output_dir, epoch=0)
110 | print('Saving pre-train state to:')
111 | print(state_path)
112 | torch.save(model.state_dict(), state_path)
113 |
114 | model.on_epoch_start()
115 | model.train()
116 |
117 | if params.ls and not params.us and not params.ut: # only ls (type 1)
118 | for x, y in tqdm(labeled_source_loader):
119 | model.on_step_start()
120 | optimizer.zero_grad()
121 | loss, _ = model.compute_cls_loss_and_accuracy(x.cuda(), y.cuda())
122 | loss.backward()
123 | optimizer.step()
124 | model.on_step_end()
125 |
126 | epoch_loss += loss.item()
127 | epoch_source_loss += loss.item()
128 | steps += 1
129 | elif not params.ls and params.us and not params.ut: # only us (type 2)
130 | for x, _ in tqdm(unlabeled_source_loader):
131 | model.on_step_start()
132 | optimizer.zero_grad()
133 | loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda())
134 | loss.backward()
135 | optimizer.step()
136 | model.on_step_end()
137 |
138 | epoch_loss += loss.item()
139 | epoch_source_loss += loss.item()
140 | steps += 1
141 | elif params.ut: # ut (epoch is based on unlabeled target)
142 | for x, _ in tqdm(unlabeled_target_loader):
143 | model.on_step_start()
144 | optimizer.zero_grad()
145 | target_loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda()) # UT loss
146 | epoch_target_loss += target_loss.item()
147 | source_loss = None
148 | if params.ls: # type 4, 7
149 | try:
150 |                         sx, sy = next(labeled_source_loader_iter)
151 |                     except (StopIteration, NameError):
152 |                         labeled_source_loader_iter = iter(labeled_source_loader)
153 |                         sx, sy = next(labeled_source_loader_iter)
154 | source_loss = model.compute_cls_loss_and_accuracy(sx.cuda(), sy.cuda())[0] # LS loss
155 | epoch_source_loss += source_loss.item()
156 | if params.us: # type 5, 8
157 | try:
158 |                         sx, sy = next(unlabeled_source_loader_iter)
159 |                     except (StopIteration, NameError):
160 |                         unlabeled_source_loader_iter = iter(unlabeled_source_loader)
161 |                         sx, sy = next(unlabeled_source_loader_iter)
162 | source_loss = model.compute_ssl_loss(sx[0].cuda(), sx[1].cuda()) # US loss
163 | epoch_source_loss += source_loss.item()
164 |
165 |                 if source_loss is not None:  # truthiness would misfire on a zero-valued loss tensor
166 | loss = source_loss * (1 - params.gamma) + target_loss * params.gamma
167 | else:
168 | loss = target_loss
169 | loss.backward()
170 | optimizer.step()
171 | model.on_step_end()
172 |
173 | epoch_loss += loss.item()
174 | steps += 1
175 | else:
176 | raise AssertionError('Unknown training combination.')
177 |
178 | if scheduler is not None:
179 | scheduler.step()
180 | model.on_epoch_end()
181 |
182 | mean_loss = epoch_loss / steps
183 | mean_source_loss = epoch_source_loss / steps
184 | mean_target_loss = epoch_target_loss / steps
185 | fmt = 'Epoch {:04d}: loss={:6.4f} source_loss={:6.4f} target_loss={:6.4f}'
186 | print(fmt.format(epoch, mean_loss, mean_source_loss, mean_target_loss))
187 |
188 | pretrain_history['loss'][epoch] = mean_loss
189 | pretrain_history['source_loss'][epoch] = mean_source_loss
190 | pretrain_history['target_loss'][epoch] = mean_target_loss
191 |
192 | pd.DataFrame(pretrain_history).to_csv(pretrain_history_path)
193 |
194 |         epoch += 1  # switch to "completed epochs" for checkpoint naming below
195 | if epoch % params.model_save_interval == 0 or epoch == params.epochs:
196 | state_path = get_pretrain_state_path(output_dir, epoch=epoch)
197 | print('Saving pre-train state to:')
198 | print(state_path)
199 | torch.save(model.state_dict(), state_path)
200 |
201 |
202 | if __name__ == '__main__':
203 | np.random.seed(10)
204 | params = parse_args('pretrain')
205 |
206 | targets = params.target_dataset
207 | if targets is None:
208 | targets = [targets]
209 | elif len(targets) > 1:
210 | print('#' * 80)
211 | print("Running pretrain iteratively for multiple target datasets: {}".format(targets))
212 | print('#' * 80)
213 |
214 | for target in targets:
215 | params.target_dataset = target
216 | main(params)
217 |
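218 |
219 | # Note: unlike pretrain.py, the "labeled" source loader above is built via
220 | # get_unlabeled_dataloader with a split_seed, i.e. it draws from the seeded 80/20
221 | # splits under datasets/split_seed_1 -- consistent with the "lu20" in this file's
222 | # name (our reading of the code; not documented anywhere we can see).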
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py
2 | numpy==1.22.0
3 | pandas==1.3.5
4 | Pillow==10.2.0
5 | torch==1.13.1
6 | torchvision==0.14.1
7 | tqdm==4.62.3
8 |
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import Adam
3 | from torch.optim.lr_scheduler import LambdaLR
4 | from torchvision.models import resnet18
5 |
6 |
7 | class RepeatedMultiStepLR(LambdaLR):  # MultiStepLR whose decay pattern restarts every `interval` epochs
8 | def __init__(self, optimizer, milestones=(400, 600, 800), gamma=0.1, interval=1000, **kwargs):
9 | self.milestones = milestones
10 | self.interval = interval
11 | self.gamma = gamma
12 | super().__init__(optimizer, self._lambda, **kwargs)
13 |
14 | def _lambda(self, epoch):
15 | factor = 1
16 | for milestone in self.milestones:
17 | if epoch % self.interval >= milestone:
18 | factor *= self.gamma
19 | return factor
20 |
21 |
22 | def main():
23 | resnet = resnet18()
24 |
25 | optimizer1 = Adam(resnet.parameters(), lr=0.1)
26 | optimizer2 = Adam(resnet.parameters(), lr=0.1)
27 |
28 | s1 = torch.optim.lr_scheduler.MultiStepLR(optimizer1, milestones=[400, 600, 800], gamma=0.1)
29 | s2 = RepeatedMultiStepLR(optimizer2, milestones=[400, 600, 800])
30 | s1_history = []
31 | s2_history = []
32 |
33 | for i in range(2000):
34 | # print("Epoch {:04d}: {:.6f} / {:.6f}".format(i, s1.get_last_lr()[0], s2.get_last_lr()[0]))
35 | s1_history.append(s1.get_last_lr()[0])
36 | s2_history.append(s2.get_last_lr()[0])
37 | s1.step()
38 | s2.step()
39 |
40 |     assert s1_history[:1000] == s2_history[:1000]  # identical within the first interval
41 |     assert s1_history[:1000] == s2_history[1000:]  # the repeated schedule restarts after `interval` epochs
42 |
43 | print("Manual test passed!")
44 |
45 |
46 | if __name__ == "__main__": # manual unit test
47 | main()
48 |
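49 |
50 | # Worked example of _lambda with the defaults (interval=1000, milestones=(400, 600, 800), gamma=0.1):
51 | #   epoch 399  -> factor 1.0
52 | #   epoch 400  -> factor 0.1      (one milestone passed)
53 | #   epoch 800  -> factor 0.001    (all three passed)
54 | #   epoch 1000 -> factor 1.0 again, since epoch % interval restarts the schedule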
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='CD-FSL',
5 | version='',
6 |     packages=['model', 'methods', 'datasets'],  # dropped 'data', which does not exist in the repo tree
7 | url='',
8 | license='',
9 | author='',
10 | author_email='',
11 | description=''
12 | )
13 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def adjust_learning_rate(optimizer, epoch, lr=0.01, step1=30, step2=60, step3=90):
5 |     """Decays the initial LR by a factor of 10 at each of step1, step2 and step3"""
6 |     if epoch >= step3:
7 |         lr = lr * 0.001
8 |     elif epoch >= step2:
9 |         lr = lr * 0.01
10 |     elif epoch >= step1:
11 |         lr = lr * 0.1
12 |     # below step1, keep the initial lr
13 |     for param_group in optimizer.param_groups:
14 |         param_group['lr'] = lr
15 |
16 |
17 | class AverageMeter(object):
18 | """Computes and stores the average and current value"""
19 | def __init__(self):
20 | self.reset()
21 |
22 | def reset(self):
23 | self.val = 0
24 | self.avg = 0
25 | self.sum = 0
26 | self.count = 0
27 |
28 | def update(self, val, n=1):
29 | self.val = val
30 | self.sum += val * n
31 | self.count += n
32 | self.avg = self.sum / self.count
33 |
34 |
35 | def one_hot(y, num_class):
36 | return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1)
37 |
38 | def sparsity(cl_data_file):
39 |     class_list = cl_data_file.keys()
40 |     cl_sparsity = []
41 |     for cl in class_list:
42 |         cl_sparsity.append(np.mean([np.sum(x != 0) for x in cl_data_file[cl]]))
43 |
44 |     return np.mean(cl_sparsity)
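45 |
46 |
47 | if __name__ == '__main__':  # manual smoke test (illustrative)
48 |     meter = AverageMeter()
49 |     meter.update(1.0)
50 |     meter.update(3.0)
51 |     assert meter.avg == 2.0 and meter.count == 2
52 |     assert one_hot(torch.tensor([0, 2]), 3).tolist() == [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
53 |     print('utils smoke test passed!')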
--------------------------------------------------------------------------------