├── .gitignore
├── LICENSE
├── README.md
├── backbone.py
├── configs.py
├── data_preparation
│   ├── cars.py
│   ├── input
│   │   └── .gitignore
│   ├── output
│   │   └── .gitignore
│   ├── places.py
│   ├── places_cdfsl_subset_16_class_1715_sample_seed_0.json
│   ├── places_plantae_subset_sampler.ipynb
│   ├── plantae.py
│   └── plantae_cdfsl_subset_69_class.json
├── datasets
│   ├── __init__.py
│   ├── dataloader.py
│   ├── datasets.py
│   ├── sampler.py
│   ├── split.py
│   ├── split_seed_1
│   │   ├── ChestX_labeled_80.csv
│   │   ├── ChestX_unlabeled_20.csv
│   │   ├── CropDisease_labeled_80.csv
│   │   ├── CropDisease_unlabeled_20.csv
│   │   ├── EuroSAT_labeled_80.csv
│   │   ├── EuroSAT_unlabeled_20.csv
│   │   ├── ISIC_labeled_80.csv
│   │   ├── ISIC_unlabeled_20.csv
│   │   ├── cars_labeled_80.csv
│   │   ├── cars_unlabeled_20.csv
│   │   ├── cub_labeled_80.csv
│   │   ├── cub_unlabeled_20.csv
│   │   ├── miniImageNet_test_labeled_80.csv
│   │   ├── miniImageNet_test_unlabeled_20.csv
│   │   ├── places_labeled_80.csv
│   │   ├── places_unlabeled_20.csv
│   │   ├── plantae_labeled_80.csv
│   │   ├── plantae_unlabeled_20.csv
│   │   ├── tieredImageNet_test_labeled_80.csv
│   │   └── tieredImageNet_test_unlabeled_20.csv
│   └── transforms.py
├── finetune.py
├── finetune.sh
├── io_utils.py
├── methods
│   ├── __init__.py
│   ├── baselinefinetune.py
│   ├── baselinetrain.py
│   ├── boil.py
│   ├── byol.py
│   ├── maml.py
│   ├── meta_template.py
│   └── protonet.py
├── model
│   ├── __init__.py
│   ├── base.py
│   ├── byol.py
│   ├── classifier_head.py
│   ├── moco.py
│   ├── simclr.py
│   └── simsiam.py
├── paths.py
├── pretrain.py
├── pretrain.sh
├── pretrain_new_lu20.py
├── requirements.txt
├── scheduler.py
├── setup.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.pt
3 | *.json
4 |
5 | scripts/
6 |
7 | # user-defined files
8 | *.log
9 | *.tar
10 | *.pkl
11 | *.csv
12 | *.JPEG
13 | wandb/
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 | db.sqlite3-journal
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109 | __pypackages__/
110 |
111 | # Celery stuff
112 | celerybeat-schedule
113 | celerybeat.pid
114 |
115 | # SageMath parsed files
116 | *.sage.py
117 |
118 | # Environments
119 | .env
120 | .venv
121 | env/
122 | venv/
123 | ENV/
124 | env.bak/
125 | venv.bak/
126 |
127 | # Spyder project settings
128 | .spyderproject
129 | .spyproject
130 |
131 | # Rope project settings
132 | .ropeproject
133 |
134 | # mkdocs documentation
135 | /site
136 |
137 | # mypy
138 | .mypy_cache/
139 | .dmypy.json
140 | dmypy.json
141 |
142 | # Pyre type checker
143 | .pyre/
144 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Official] Understanding Cross-Domain Few-Shot Learning Based on Domain Similarity and Few-shot Difficulty
2 |
3 | This repo contains the implementation of our [paper](https://arxiv.org/abs/2202.01339) accepted at NeurIPS 2022.
4 |
5 | ## Abstract
6 | Cross-domain few-shot learning (CD-FSL) has drawn increasing attention for handling large differences between the source and target domains--an important concern in real-world scenarios. To overcome these large differences, recent works have considered exploiting small-scale unlabeled data from the target domain during the pre-training stage. This data enables self-supervised pre-training on the target domain, in addition to supervised pre-training on the source domain. In this paper, we empirically investigate which pre-training is preferred based on domain similarity and few-shot difficulty of the target domain. We discover that the performance gain of self-supervised pre-training over supervised pre-training becomes large when the target domain is dissimilar to the source domain, or the target domain itself has low few-shot difficulty. We further design two pre-training schemes, mixed-supervised and two-stage learning, that improve performance. In this light, we present six findings for CD-FSL, which are supported by extensive experiments and analyses on three source and eight target benchmark datasets with varying levels of domain similarity and few-shot difficulty.
7 |
8 | ## Table of Contents
9 |
10 | * [Prerequisites](#prerequisites)
11 | * [Data Preparation](#data-preparation)
12 |   * [BSCD-FSL](#bscd-fsl)
13 |   * [Cars](#cars)
14 |   * [CUB](#cub-caltech-ucsd-birds-200-2011)
15 |   * [Places](#places)
16 |   * [Plantae](#plantae)
17 | * [Usage](#usage)
18 | * [Model Checkpoints](#model-checkpoints)
19 | * [Attribution](#attribution)
20 | * [License](#license)
21 | * [Citation](#citation)
22 | * [Contact](#contact)
23 |
24 | ## Prerequisites
25 |
26 | Our code works on `torch>=1.8`. Install the required Python packages via
27 |
28 | ```sh
29 | pip install -r requirements.txt
30 | ```
31 |
32 | ## Data Preparation
33 |
34 | Prepare and place all dataset folders in `/data/cdfsl/`. You may specify custom locations in `configs.py`.
35 |
36 | ### BSCD-FSL
37 |
38 | Refer to the original BSCD-FSL [repository](https://github.com/IBM/cdfsl-benchmark).
39 |
40 | The dataset folders should be organized in `/data/cdfsl/` as follows:
41 |
42 | ```
43 | CropDiseases/train
44 | ├── Apple___Apple_scab
45 | │   ├── 00075aa8-d81a-4184-8541-b692b78d398a___FREC_Scab 3335.JPG
46 | │   ├── 0208f4eb-45a4-4399-904e-989ac2c6257c___FREC_Scab 3037.JPG
47 |
48 | EuroSAT
49 | ├── AnnualCrop
50 | │   ├── AnnualCrop_1.jpg
51 | │   ├── AnnualCrop_2.jpg
52 |
53 | ISIC
54 | ├── ATTRIBUTION.txt
55 | ├── ISIC2018_Task3_Training_GroundTruth.csv
56 | ├── ISIC2018_Task3_Training_LesionGroupings.csv
57 | ├── ISIC_0024306.jpg
58 | ├── ISIC_0024307.jpg
59 |
60 | chestX/images
61 | ├── 00000001_000.png
62 | ├── 00000001_001.png
63 | ```
64 |
65 | Note: `chestX/images/` should contain **all** images from the ChestX dataset (the dataset archive provided online will
66 | typically split these images across multiple folders).
67 |
68 | ### Cars
69 |
70 | https://ai.stanford.edu/~jkrause/cars/car_dataset.html
71 |
72 | We use all images from both training and test sets for CD-FSL experiments.
73 | For convenience, we pre-process the data such that each image goes into its respective class folder.
74 |
75 | 1. Download [`car_ims.tgz`](http://ai.stanford.edu/~jkrause/car196/car_ims.tgz) (the tar of all images) and [`cars_annos.mat`](http://ai.stanford.edu/~jkrause/car196/cars_annos.mat) (all bounding boxes and labels for both training and test).
76 | 2. Copy `cars_annos.mat` and unzip `car_ims.tgz` into `./data_preparation/input/`. The directory should contain the following:
77 | ```
78 | data_preparation/input
79 | ├── cars_annos.mat
80 | ├── car_ims
81 | │   ├── 000001.jpg
82 | │   ├── 000002.jpg
83 | ```
84 | 3. Run `./data_preparation/cars.py` to generate the cars dataset folder at `./data_preparation/output/cars_cdfsl/`.
85 | 4. Move the `cars_cdfsl` directory to `/data/cdfsl/`. You may specify a custom location in `configs.py`. (A sanity-check snippet for the generated folder follows below.)
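
After step 3, it is worth sanity-checking the generated folder before moving it. Below is a minimal sketch, assuming `torchvision` is installed (it is only used here for counting); the expected totals come from the Stanford Cars annotations:

```python
# Optional sanity check (a sketch, not part of the repo's scripts):
# the generated folder should contain 196 class directories and
# 16,185 images in total (training + test sets combined).
from torchvision.datasets import ImageFolder

dataset = ImageFolder("./data_preparation/output/cars_cdfsl")
print("{} classes, {} images".format(len(dataset.classes), len(dataset)))
```

The same check applies to the `places_cdfsl` (16 classes, 27,440 images) and `plantae_cdfsl` (69 classes, 26,650 images) folders generated in the sections below.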
86 |
87 | ### CUB (Caltech-UCSD Birds-200-2011)
88 |
89 | http://www.vision.caltech.edu/datasets/cub_200_2011/
90 |
91 | We use all images for CD-FSL experiments.
92 |
93 | 1. Download [`CUB_200_2011.tgz`](https://data.caltech.edu/records/20098).
94 | 2. Unzip the archive and copy the *enclosed* `CUB_200_2011/` folder to `/data/cdfsl/`. You may specify a custom location in `configs.py`. The directory should contain the following:
95 | ```
96 | CUB_200_2011/
97 | ├── attributes/
98 | ├── bounding_boxes.txt
99 | ├── classes.txt
100 | ├── image_class_labels.txt
101 | ├── images/
102 | ├── images.txt
103 | ├── parts/
104 | ├── README
105 | └── train_test_split.txt
106 | ```
107 |
108 | ### Places (Places 205)
109 |
110 | http://places.csail.mit.edu/user/
111 |
112 | Due to the size of the original dataset, we only use a subset of the training set for CD-FSL experiments.
113 | We use 27,440 images from 16 classes. Please refer to the paper for details or refer to the subset sampling code at
114 | `data_preparation/places_plantae_subset_sampler.ipynb`.
115 |
116 | 1. Download [places365standard_easyformat.tar](http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar).
117 | 2. Unzip the archive into `./data_preparation/input/`. The directory should contain the following:
118 | ```
119 | data_preparation/input/
120 | ├── places365_standard/
121 | │   ├── train/
122 | │   │   ├── airfield/
123 | │   │   ├── airplane_cabin/
124 | │   │   ├── airport_terminal/
125 | ```
126 | 3. Run `./data_preparation/places.py` to generate the places dataset folder at `./data_preparation/output/places_cdfsl/`.
127 | 4. Move the `places_cdfsl` directory to `/data/cdfsl/`. You may specify a custom location in `configs.py`.
128 |
129 | ### Plantae (from iNaturalist 2018)
130 |
131 | https://github.com/visipedia/inat_comp/tree/master/2018#Data
132 |
133 | Due to the size of the original dataset, we only use a subset of the training set (of the Plantae super category) for CD-FSL experiments.
134 | We use 26,650 images from 69 classes. Please refer to the paper for details or refer to the subset sampling code at
135 | `data_preparation/places_plantae_subset_sampler.ipynb`.
136 |
137 | 1. Download [`train_val2018.tar.gz`](https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz) (~120GB).
138 | 2. Unzip the archive and copy the enclosed `Plantae/` folder (~43GB) to `./data_preparation/input/`. The directory should contain the following:
139 | ```
140 | data_preparation/input
141 | └── Plantae/
142 |     ├── 5221/
143 |     ├── 5222/
144 |     ├── 5223/
145 | ```
146 | 3. Run `./data_preparation/plantae.py` to generate the plantae dataset folder at `./data_preparation/output/plantae_cdfsl`.
147 | 4. Move the `plantae_cdfsl` directory to `/data/cdfsl/`. You may specify a custom location in `configs.py`.
148 |
149 | ## Usage
150 |
151 | The main training scripts are `pretrain.py` and `finetune.py`. Refer to `pretrain.sh` and `finetune.sh` for example
152 | usages covering the main results in our paper, e.g., SL, SSL, MSL, two-stage SSL, and two-stage MSL.
153 | To see all CLI arguments, refer to `io_utils.py`.
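
For quick evaluation or feature extraction, a downloaded checkpoint (see [Model Checkpoints](#model-checkpoints) below) can be restored roughly as follows. This is a minimal sketch rather than the repo's own loading code: the file name is hypothetical, and the assumption that the weights sit under a `"state"` key (or form a raw `state_dict`) should be verified against how `pretrain.py` saves models.

```python
# Minimal sketch: load a downloaded checkpoint into the ResNet10 backbone
# defined in backbone.py. The file name below is hypothetical, and the
# "state" key is an assumption; the checkpoint may also be a raw state_dict.
import torch

from backbone import ResNet10

model = ResNet10()
checkpoint = torch.load("resnet10_mini_sl_base.pt", map_location="cpu")
state_dict = checkpoint.get("state", checkpoint)  # unwrap if nested
# strict=False tolerates extra keys from classifier/projection heads that
# are not part of the bare backbone.
model.load_state_dict(state_dict, strict=False)
model.eval()
```

Note that the ImageNet (SL) entry in the table below points to stock torchvision weights; `ResNet18.load_imagenet_weights()` in `backbone.py` downloads those directly.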
154 |
155 | ## Model Checkpoints
156 |
157 | | Backbone | Pretraining | Augmentation | Model Checkpoints |
158 | | :-------- | :------------: |:---------: |:--------------:|
159 | | ResNet10 | miniImageNet (SL) | default (strong) | [google drive](https://drive.google.com/file/d/1J4weUMgMhdjYe0sbPBNavaf5D7aRkAog/view?usp=sharing) |
160 | | ResNet10 | miniImageNet (SL) | base | [google drive](https://drive.google.com/file/d/11HSAg85vlS67sVsEgd-RYgksnlX61WOj/view?usp=sharing) |
161 | | ResNet18 | tieredImageNet (SL) | default (strong) | [google drive](https://drive.google.com/file/d/1hRbE5VwDvgsKV6E7okOgqJaNOVsSfitP/view?usp=share_link) |
162 | | ResNet18 | tieredImageNet (SL) | base | [google drive](https://drive.google.com/file/d/1-UOzOG-NkqRFXng-Zu3RAbuMhZJZfEhN/view?usp=share_link) |
163 | | ResNet18 | ImageNet (SL) | base | [torchvision](https://pytorch.org/vision/stable/models.html) |
164 |
165 | ## Attribution
166 |
167 | Below, we provide the licenses, attribution, citations, and URLs of the datasets considered in our paper (if applicable).
168 |
169 | - **ImageNet** is available at https://image-net.org/. miniImageNet and tieredImageNet are subsets of ImageNet.
170 |   1. *Deng, Jia, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. "Imagenet: A large-scale hierarchical image database." In 2009 IEEE conference on computer vision and pattern recognition, pp. 248-255. IEEE, 2009.*
171 | - **CropDisease** refers to the Plant Disease dataset on Kaggle, licensed under GPL 2. It is available at https://www.kaggle.com/saroz014/plant-disease/.
172 | - **EuroSAT** is available at https://github.com/phelber/eurosat.
173 |   1. *Helber, Patrick, Benjamin Bischke, Andreas Dengel, and Damian Borth. "Eurosat: A novel dataset and deep learning benchmark for land use and land cover classification." IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing 12, no. 7 (2019): 2217-2226.*
174 |   2. *Helber, Patrick, Benjamin Bischke, Andreas Dengel, and Damian Borth. "Introducing eurosat: A novel dataset and deep learning benchmark for land use and land cover classification." In IGARSS 2018-2018 IEEE international geoscience and remote sensing symposium, pp. 204-207. IEEE, 2018.*
175 | - **ISIC** refers to the ISIC 2018 Challenge Task 3 dataset. It is available at https://challenge.isic-archive.com.
176 |   1. *Codella, Noel, Veronica Rotemberg, Philipp Tschandl, M. Emre Celebi, Stephen Dusza, David Gutman, Brian Helba et al. "Skin lesion analysis toward melanoma detection 2018: A challenge hosted by the international skin imaging collaboration (isic)." arXiv preprint arXiv:1902.03368 (2019).*
177 |   2. *Tschandl, Philipp, Cliff Rosendahl, and Harald Kittler. "The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions." Scientific data 5, no. 1 (2018): 1-9.*
178 | - **ChestX** refers to the ChestX-ray8 dataset provided by the NIH Clinical Center. The dataset is available at https://nihcc.app.box.com/v/ChestXray-NIHCC.
179 |   1. *Wang, Xiaosong, Yifan Peng, Le Lu, Zhiyong Lu, Mohammadhadi Bagheri, and Ronald M. Summers. "Chestx-ray8: Hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2097-2106. 2017.*
180 | - **Places** is licensed under CC BY. It is available at http://places.csail.mit.edu/user.
181 |   1. *Zhou, Bolei, Agata Lapedriza, Aditya Khosla, Aude Oliva, and Antonio Torralba. "Places: A 10 million image database for scene recognition." IEEE transactions on pattern analysis and machine intelligence 40, no. 6 (2017): 1452-1464.*
182 | - **Plantae** refers to the Plantae super category of the iNaturalist 2018 dataset. It is available at https://github.com/visipedia/inat_comp/tree/master/2018.
183 |   1. *Van Horn, Grant, Oisin Mac Aodha, Yang Song, Yin Cui, Chen Sun, Alex Shepard, Hartwig Adam, Pietro Perona, and Serge Belongie. "The inaturalist species classification and detection dataset." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 8769-8778. 2018.*
184 | - **CUB** refers to the Caltech-UCSD Birds-200-2011 dataset, licensed under CC0 (public domain). It is available at https://www.kaggle.com/datasets/veeralakrishna/200-bird-species-with-11788-images.
185 |   1. *Wah, Catherine, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. "The caltech-ucsd birds-200-2011 dataset." (2011).*
186 | - **Cars** is available at https://ai.stanford.edu/~jkrause/cars/car_dataset.html.
187 |   1. *Krause, Jonathan, Michael Stark, Jia Deng, and Li Fei-Fei. "3d object representations for fine-grained categorization." In Proceedings of the IEEE international conference on computer vision workshops, pp. 554-561. 2013.*
188 |
189 | ## License
190 |
191 | Distributed under the MIT License.
192 |
193 | ## Citation
194 |
195 | If you find this repo useful for your research, please consider citing our paper:
196 |
197 | ```
198 | @inproceedings{oh2022understanding,
199 |     title={Understanding Cross-Domain Few-Shot Learning Based on Domain Similarity and Few-Shot Difficulty},
200 |     author={Oh, Jaehoon and Kim, Sungnyun and Ho, Namgyu and Kim, Jin-Hwa and Song, Hwanjun and Yun, Se-Young},
201 |     booktitle={Advances in Neural Information Processing Systems},
202 |     year={2022}
203 | }
204 | ```
205 |
206 | ## Contact
207 | * Jaehoon Oh: jhoon.oh@kaist.ac.kr
208 | * Sungnyun Kim: ksn4397@kaist.ac.kr
209 | * Namgyu Ho: itsnamgyu@kaist.ac.kr
210 |
--------------------------------------------------------------------------------
/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torch import Tensor
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.hub import load_state_dict_from_url
10 | from torch.distributions import Bernoulli
11 | from torch.nn.utils.weight_norm import WeightNorm
12 |
13 |
14 | def init_layer(L):
15 |     # Initialization using fan-in
16 |     if isinstance(L, nn.Conv2d):
17 |         n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels
18 |         L.weight.data.normal_(0,math.sqrt(2.0/float(n)))
19 |     elif isinstance(L, nn.BatchNorm2d):
20 |         L.weight.data.fill_(1)
21 |         L.bias.data.fill_(0)
22 |
23 | class distLinear(nn.Module):
24 |     def __init__(self, indim, outdim):
25 |         super(distLinear, self).__init__()
26 |         self.L = nn.Linear( indim, outdim, bias = False)
27 |         self.class_wise_learnable_norm = True #See the issue#4&8 in the github
28 |         if self.class_wise_learnable_norm:
29 |             WeightNorm.apply(self.L, 'weight', dim=0) #split the weight update component to direction and norm
30 |
31 |         if outdim <=200:
32 |             self.scale_factor = 2; #a fixed scale factor to scale the output of cos value into a reasonably large input for softmax
33 |         else:
34 |             self.scale_factor = 10; #in omniglot, a
larger scale factor is required to handle >1000 output classes. 35 | 36 | def forward(self, x): 37 | x_norm = torch.norm(x, p=2, dim =1).unsqueeze(1).expand_as(x) 38 | x_normalized = x.div(x_norm+ 0.00001) 39 | if not self.class_wise_learnable_norm: 40 | L_norm = torch.norm(self.L.weight.data, p=2, dim =1).unsqueeze(1).expand_as(self.L.weight.data) 41 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 42 | cos_dist = self.L(x_normalized) #matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 43 | scores = self.scale_factor* (cos_dist) 44 | 45 | return scores 46 | 47 | class Flatten(nn.Module): 48 | def __init__(self): 49 | super(Flatten, self).__init__() 50 | 51 | def forward(self, x): 52 | return x.view(x.size(0), -1) 53 | 54 | # For meta-learning based algorithms (task-specific weight) 55 | class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight 56 | def __init__(self, in_features, out_features): 57 | super(Linear_fw, self).__init__(in_features, out_features) 58 | self.weight.fast = None #Lazy hack to add fast weight link 59 | self.bias.fast = None 60 | 61 | def forward(self, x): 62 | if self.weight.fast is not None and self.bias.fast is not None: 63 | out = F.linear(x, self.weight.fast, self.bias.fast) #weight.fast (fast weight) is the temporaily adapted weight 64 | else: 65 | out = super(Linear_fw, self).forward(x) 66 | return out 67 | 68 | class Conv2d_fw(nn.Conv2d): #used in MAML to forward input with fast weight 69 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=0, bias = True): 70 | super(Conv2d_fw, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias) 71 | self.weight.fast = None 72 | if not self.bias is None: 73 | self.bias.fast = None 74 | 75 | def forward(self, x): 76 | if self.bias is None: 77 | if self.weight.fast is not None: 78 | out = F.conv2d(x, self.weight.fast, None, stride= self.stride, padding=self.padding) 79 | else: 80 | out = super(Conv2d_fw, self).forward(x) 81 | else: 82 | if self.weight.fast is not None and self.bias.fast is not None: 83 | out = F.conv2d(x, self.weight.fast, self.bias.fast, stride= self.stride, padding=self.padding) 84 | else: 85 | out = super(Conv2d_fw, self).forward(x) 86 | 87 | return out 88 | 89 | class BatchNorm2d_fw(nn.BatchNorm2d): #used in MAML to forward input with fast weight 90 | def __init__(self, num_features): 91 | super(BatchNorm2d_fw, self).__init__(num_features) 92 | self.weight.fast = None 93 | self.bias.fast = None 94 | 95 | def forward(self, x): 96 | running_mean = torch.zeros(x.data.size()[1]).cuda() 97 | running_var = torch.ones(x.data.size()[1]).cuda() 98 | if self.weight.fast is not None and self.bias.fast is not None: 99 | out = F.batch_norm(x, running_mean, running_var, self.weight.fast, self.bias.fast, training = True, momentum = 1) 100 | #batch_norm momentum hack: follow hack of Kate Rakelly in pytorch-maml/src/layers.py 101 | else: 102 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, training = True, momentum = 1) 103 | return out 104 | 105 | # Simple ResNet Block 106 | class SimpleBlock(nn.Module): 107 | def __init__(self, method, indim, outdim, half_res, track_bn): 108 | super(SimpleBlock, self).__init__() 109 | self.indim = indim 110 | self.outdim = outdim 111 | 112 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 113 | 
self.BN1 = nn.BatchNorm2d(outdim, track_running_stats=track_bn) 114 | self.C2 = nn.Conv2d(outdim, outdim,kernel_size=3, padding=1, bias=False) 115 | self.BN2 = nn.BatchNorm2d(outdim, track_running_stats=track_bn) 116 | 117 | self.relu1 = nn.ReLU(inplace=True) 118 | self.relu2 = nn.ReLU(inplace=True) 119 | 120 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] 121 | 122 | self.half_res = half_res 123 | 124 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 125 | if indim!=outdim: 126 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False) 127 | self.BNshortcut = nn.BatchNorm2d(outdim, track_running_stats=track_bn) 128 | 129 | self.parametrized_layers.append(self.shortcut) 130 | self.parametrized_layers.append(self.BNshortcut) 131 | self.shortcut_type = '1x1' 132 | else: 133 | self.shortcut_type = 'identity' 134 | 135 | for layer in self.parametrized_layers: 136 | init_layer(layer) 137 | 138 | def forward(self, x): 139 | out = self.C1(x) 140 | out = self.BN1(out) 141 | out = self.relu1(out) 142 | out = self.C2(out) 143 | out = self.BN2(out) 144 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x)) 145 | out = out + short_out 146 | out = self.relu2(out) 147 | 148 | return out 149 | 150 | class ResNet(nn.Module): 151 | def __init__(self, method, block, list_of_num_layers, list_of_out_dims, flatten, track_bn, reinit_bn_stats): 152 | # list_of_num_layers specifies number of layers in each stage 153 | # list_of_out_dims specifies number of output channel for each stage 154 | super(ResNet,self).__init__() 155 | assert len(list_of_num_layers)==4, 'Can have only four stages' 156 | 157 | self.reinit_bn_stats = reinit_bn_stats 158 | 159 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 160 | bn1 = nn.BatchNorm2d(64, track_running_stats=track_bn) 161 | 162 | relu = nn.ReLU() 163 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 164 | 165 | init_layer(conv1) 166 | init_layer(bn1) 167 | 168 | trunk = [conv1, bn1, relu, pool1] 169 | 170 | indim = 64 171 | for i in range(4): 172 | for j in range(list_of_num_layers[i]): 173 | half_res = (i>=1) and (j==0) 174 | B = block(method, indim, list_of_out_dims[i], half_res, track_bn) 175 | trunk.append(B) 176 | indim = list_of_out_dims[i] 177 | 178 | if flatten: 179 | avgpool = nn.AvgPool2d(7) 180 | trunk.append(avgpool) 181 | trunk.append(Flatten()) 182 | self.final_feat_dim = indim 183 | else: 184 | self.final_feat_dim = [indim, 7, 7] 185 | 186 | self.trunk = nn.Sequential(*trunk) 187 | 188 | def forward(self,x): 189 | if self.reinit_bn_stats: 190 | self._reinit_running_batch_statistics() 191 | out = self.trunk(x) 192 | return out 193 | 194 | def _reinit_running_batch_statistics(self): 195 | with torch.no_grad(): 196 | self.trunk[1].running_mean.data.fill_(0.) 197 | self.trunk[1].running_var.data.fill_(1.) 198 | 199 | self.trunk[4].BN1.running_mean.data.fill_(0.) 200 | self.trunk[4].BN1.running_var.data.fill_(1.) 201 | self.trunk[4].BN2.running_mean.data.fill_(0.) 202 | self.trunk[4].BN2.running_var.data.fill_(1.) 203 | 204 | self.trunk[5].BN1.running_mean.data.fill_(0.) 205 | self.trunk[5].BN1.running_var.data.fill_(1.) 206 | self.trunk[5].BN2.running_mean.data.fill_(0.) 207 | self.trunk[5].BN2.running_var.data.fill_(1.) 208 | self.trunk[5].BNshortcut.running_mean.data.fill_(0.) 209 | self.trunk[5].BNshortcut.running_var.data.fill_(1.) 210 | 211 | self.trunk[6].BN1.running_mean.data.fill_(0.) 
212 | self.trunk[6].BN1.running_var.data.fill_(1.) 213 | self.trunk[6].BN2.running_mean.data.fill_(0.) 214 | self.trunk[6].BN2.running_var.data.fill_(1.) 215 | self.trunk[6].BNshortcut.running_mean.data.fill_(0.) 216 | self.trunk[6].BNshortcut.running_var.data.fill_(1.) 217 | 218 | self.trunk[7].BN1.running_mean.data.fill_(0.) 219 | self.trunk[7].BN1.running_var.data.fill_(1.) 220 | self.trunk[7].BN2.running_mean.data.fill_(0.) 221 | self.trunk[7].BN2.running_var.data.fill_(1.) 222 | self.trunk[7].BNshortcut.running_mean.data.fill_(0.) 223 | self.trunk[7].BNshortcut.running_var.data.fill_(1.) 224 | 225 | 226 | def return_features(self, x, return_avg=False): 227 | flat = Flatten() 228 | m = nn.AdaptiveAvgPool2d((1,1)) 229 | 230 | with torch.no_grad(): 231 | block1_out = self.trunk[4](self.trunk[3](self.trunk[2](self.trunk[1](self.trunk[0](x))))) 232 | block2_out = self.trunk[5](block1_out) 233 | block3_out = self.trunk[6](block2_out) 234 | block4_out = self.trunk[7](block3_out) 235 | 236 | if return_avg: 237 | return flat(m(block1_out)), flat(m(block2_out)), flat(m(block3_out)), flat(m(block4_out)) 238 | else: 239 | return flat(block1_out), flat(block2_out), flat(block3_out), flat(block4_out) 240 | 241 | def forward_bodyfreeze(self,x): 242 | flat = Flatten() 243 | m = nn.AdaptiveAvgPool2d((1,1)) 244 | 245 | with torch.no_grad(): 246 | block1_out = self.trunk[4](self.trunk[3](self.trunk[2](self.trunk[1](self.trunk[0](x))))) 247 | block2_out = self.trunk[5](block1_out) 248 | block3_out = self.trunk[6](block2_out) 249 | 250 | out = self.trunk[7].C1(block3_out) 251 | out = self.trunk[7].BN1(out) 252 | out = self.trunk[7].relu1(out) 253 | 254 | out = self.trunk[7].C2(out) 255 | out = self.trunk[7].BN2(out) 256 | short_out = self.trunk[7].BNshortcut(self.trunk[7].shortcut(block3_out)) 257 | out = out + short_out 258 | out = self.trunk[7].relu2(out) 259 | 260 | return flat(m(out)) 261 | 262 | def ResNet10(method='baseline', track_bn=True, reinit_bn_stats=False): 263 | return ResNet(method, block=SimpleBlock, list_of_num_layers=[1,1,1,1], list_of_out_dims=[64,128,256,512], flatten=True, track_bn=track_bn, reinit_bn_stats=reinit_bn_stats) 264 | 265 | # -*- coding: utf-8 -*- 266 | # https://github.com/ElementAI/embedding-propagation/blob/master/src/models/backbones/resnet12.py 267 | 268 | class Block(torch.nn.Module): 269 | def __init__(self, ni, no, stride, dropout, track_bn, reinit_bn_stats): 270 | super().__init__() 271 | self.reinit_bn_stats = reinit_bn_stats 272 | 273 | self.dropout = nn.Dropout2d(dropout) if dropout > 0 else lambda x: x 274 | self.C0 = nn.Conv2d(ni, no, 3, stride, padding=1, bias=False) 275 | self.BN0 = nn.BatchNorm2d(no, track_running_stats=track_bn) 276 | self.C1 = nn.Conv2d(no, no, 3, 1, padding=1, bias=False) 277 | self.BN1 = nn.BatchNorm2d(no, track_running_stats=track_bn) 278 | self.C2 = nn.Conv2d(no, no, 3, 1, padding=1, bias=False) 279 | self.BN2 = nn.BatchNorm2d(no, track_running_stats=track_bn) 280 | if stride == 2 or ni != no: 281 | self.shortcut = nn.Conv2d(ni, no, 1, stride=1, padding=0, bias=False) 282 | self.BNshortcut = nn.BatchNorm2d(no, track_running_stats=track_bn) 283 | 284 | def get_parameters(self): 285 | return self.parameters() 286 | 287 | def forward(self, x): 288 | if self.reinit_bn_stats: 289 | self._reinit_running_batch_statistics() 290 | 291 | out = self.C0(x) 292 | out = self.BN0(out) 293 | out = F.relu(out) 294 | out = self.dropout(out) 295 | out = self.C1(out) 296 | out = self.BN1(out) 297 | out = F.relu(out) 298 | out = self.dropout(out) 299 | 
out = self.C2(out) 300 | out = self.BN2(out) 301 | out += self.BNshortcut(self.shortcut(x)) 302 | out = F.relu(out) 303 | 304 | return out 305 | 306 | def _reinit_running_batch_statistics(self): 307 | with torch.no_grad(): 308 | self.BN0.running_mean.data.fill_(0.) 309 | self.BN0.running_var.data.fill_(1.) 310 | self.BN1.running_mean.data.fill_(0.) 311 | self.BN1.running_var.data.fill_(1.) 312 | self.BN2.running_mean.data.fill_(0.) 313 | self.BN2.running_var.data.fill_(1.) 314 | self.BNshortcut.running_mean.data.fill_(0.) 315 | self.BNshortcut.running_var.data.fill_(1.) 316 | 317 | class ResNet12(torch.nn.Module): 318 | def __init__(self, track_bn, reinit_bn_stats, width=1, dropout=0): 319 | super().__init__() 320 | self.final_feat_dim = 512 321 | assert(width == 1) # Comment for different variants of this model 322 | self.widths = [x * int(width) for x in [64, 128, 256]] 323 | self.widths.append(self.final_feat_dim * width) 324 | # self.bn_out = nn.BatchNorm1d(self.final_feat_dim) 325 | 326 | start_width = 3 327 | for i in range(len(self.widths)): 328 | setattr(self, "group_%d" %i, Block(start_width, self.widths[i], 1, dropout, track_bn, reinit_bn_stats)) 329 | start_width = self.widths[i] 330 | 331 | def add_classifier(self, nclasses, name="classifier", modalities=None): 332 | setattr(self, name, torch.nn.Linear(self.final_feat_dim, nclasses)) 333 | 334 | def up_to_embedding(self, x): 335 | """ Applies the four residual groups 336 | Args: 337 | x: input images 338 | n: number of few-shot classes 339 | k: number of images per few-shot class 340 | """ 341 | for i in range(len(self.widths)): 342 | x = getattr(self, "group_%d" % i)(x) 343 | x = F.max_pool2d(x, 3, 2, 1) 344 | return x 345 | 346 | def forward(self, x): 347 | """Main Pytorch forward function 348 | Returns: class logits 349 | Args: 350 | x: input mages 351 | """ 352 | *args, c, h, w = x.size() 353 | x = x.view(-1, c, h, w) 354 | x = self.up_to_embedding(x) 355 | # return F.relu(self.bn_out(x.mean(3).mean(2)), True) 356 | return F.relu(x.mean(3).mean(2), True) 357 | 358 | 359 | class ResNet18(torchvision.models.resnet.ResNet): 360 | def __init__(self, track_bn=True): 361 | def norm_layer(*args, **kwargs): 362 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn) 363 | super().__init__(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], norm_layer=norm_layer) 364 | del self.fc 365 | self.final_feat_dim = 512 366 | 367 | def load_imagenet_weights(self, progress=True): 368 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet18'], 369 | progress=progress) 370 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 371 | if len(missing) > 0: 372 | raise AssertionError('Model code may be incorrect') 373 | 374 | def _forward_impl(self, x: Tensor) -> Tensor: 375 | # See note [TorchScript super()] 376 | x = self.conv1(x) 377 | x = self.bn1(x) 378 | x = self.relu(x) 379 | x = self.maxpool(x) 380 | 381 | x = self.layer1(x) 382 | x = self.layer2(x) 383 | x = self.layer3(x) 384 | x = self.layer4(x) 385 | 386 | x = self.avgpool(x) 387 | x = torch.flatten(x, 1) 388 | # x = self.fc(x) 389 | 390 | return x 391 | 392 | class ResNet50(torchvision.models.resnet.ResNet): 393 | def __init__(self, track_bn=True): 394 | def norm_layer(*args, **kwargs): 395 | return nn.BatchNorm2d(*args, **kwargs, track_running_stats=track_bn) 396 | super().__init__(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], norm_layer=norm_layer) 397 | del self.fc 398 | self.final_feat_dim = 2048 399 | 400 | def 
load_imagenet_weights(self, progress=True): 401 | state_dict = load_state_dict_from_url(torchvision.models.resnet.model_urls['resnet50'], 402 | progress=progress) 403 | missing, unexpected = self.load_state_dict(state_dict, strict=False) 404 | if len(missing) > 0: 405 | raise AssertionError('Model code may be incorrect') 406 | 407 | def _forward_impl(self, x: Tensor) -> Tensor: 408 | # See note [TorchScript super()] 409 | x = self.conv1(x) 410 | x = self.bn1(x) 411 | x = self.relu(x) 412 | x = self.maxpool(x) 413 | 414 | x = self.layer1(x) 415 | x = self.layer2(x) 416 | x = self.layer3(x) 417 | x = self.layer4(x) 418 | 419 | x = self.avgpool(x) 420 | x = torch.flatten(x, 1) 421 | # x = self.fc(x) 422 | 423 | return x 424 | 425 | 426 | ########################################################################################################## 427 | # code from https://github.com/WangYueFt/rfs/blob/f8c837ba93c62dd0ac68a2f4019c619aa86b8421/models/resnet.py#L88 428 | 429 | def conv3x3(in_planes, out_planes, stride=1): 430 | """3x3 convolution with padding""" 431 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 432 | padding=1, bias=False) 433 | 434 | 435 | class SELayer(nn.Module): 436 | def __init__(self, channel, reduction=16): 437 | super(SELayer, self).__init__() 438 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 439 | self.fc = nn.Sequential( 440 | nn.Linear(channel, channel // reduction), 441 | nn.ReLU(inplace=True), 442 | nn.Linear(channel // reduction, channel), 443 | nn.Sigmoid() 444 | ) 445 | 446 | def forward(self, x): 447 | b, c, _, _ = x.size() 448 | y = self.avg_pool(x).view(b, c) 449 | y = self.fc(y).view(b, c, 1, 1) 450 | return x * y 451 | 452 | 453 | class DropBlock(nn.Module): 454 | def __init__(self, block_size): 455 | super(DropBlock, self).__init__() 456 | 457 | self.block_size = block_size 458 | #self.gamma = gamma 459 | #self.bernouli = Bernoulli(gamma) 460 | 461 | def forward(self, x, gamma): 462 | # shape: (bsize, channels, height, width) 463 | 464 | if self.training: 465 | batch_size, channels, height, width = x.shape 466 | 467 | bernoulli = Bernoulli(gamma) 468 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda() 469 | block_mask = self._compute_block_mask(mask) 470 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 471 | count_ones = block_mask.sum() 472 | 473 | return block_mask * x * (countM / count_ones) 474 | else: 475 | return x 476 | 477 | def _compute_block_mask(self, mask): 478 | left_padding = int((self.block_size-1) / 2) 479 | right_padding = int(self.block_size / 2) 480 | 481 | batch_size, channels, height, width = mask.shape 482 | #print ("mask", mask[0][0]) 483 | non_zero_idxs = mask.nonzero() 484 | nr_blocks = non_zero_idxs.shape[0] 485 | 486 | offsets = torch.stack( 487 | [ 488 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 489 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 490 | ] 491 | ).t().cuda() 492 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1) 493 | 494 | if nr_blocks > 0: 495 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 496 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 497 | offsets = offsets.long() 498 | 499 | block_idxs = non_zero_idxs + offsets 500 | #block_idxs += left_padding 501 | padded_mask = F.pad(mask, (left_padding, right_padding, 
left_padding, right_padding)) 502 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 503 | else: 504 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 505 | 506 | block_mask = 1 - padded_mask#[:height, :width] 507 | return block_mask 508 | 509 | 510 | class BasicBlock(torch.nn.Module): 511 | expansion = 1 512 | 513 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, 514 | block_size=1, use_se=False): 515 | super(BasicBlock, self).__init__() 516 | self.conv1 = conv3x3(inplanes, planes) 517 | self.bn1 = nn.BatchNorm2d(planes) 518 | self.relu = nn.LeakyReLU(0.1) 519 | self.conv2 = conv3x3(planes, planes) 520 | self.bn2 = nn.BatchNorm2d(planes) 521 | self.conv3 = conv3x3(planes, planes) 522 | self.bn3 = nn.BatchNorm2d(planes) 523 | self.maxpool = nn.MaxPool2d(stride) 524 | self.downsample = downsample 525 | self.stride = stride 526 | self.drop_rate = drop_rate 527 | self.num_batches_tracked = 0 528 | self.drop_block = drop_block 529 | self.block_size = block_size 530 | self.DropBlock = DropBlock(block_size=self.block_size) 531 | self.use_se = use_se 532 | if self.use_se: 533 | self.se = SELayer(planes, 4) 534 | 535 | def forward(self, x): 536 | self.num_batches_tracked += 1 537 | 538 | residual = x 539 | 540 | out = self.conv1(x) 541 | out = self.bn1(out) 542 | out = self.relu(out) 543 | 544 | out = self.conv2(out) 545 | out = self.bn2(out) 546 | out = self.relu(out) 547 | 548 | out = self.conv3(out) 549 | out = self.bn3(out) 550 | if self.use_se: 551 | out = self.se(out) 552 | 553 | if self.downsample is not None: 554 | residual = self.downsample(x) 555 | out += residual 556 | out = self.relu(out) 557 | out = self.maxpool(out) 558 | 559 | if self.drop_rate > 0: 560 | if self.drop_block == True: 561 | feat_size = out.size()[2] 562 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 563 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 564 | out = self.DropBlock(out, gamma=gamma) 565 | else: 566 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 567 | 568 | return out 569 | 570 | 571 | class ResNet18_84x84(torch.nn.Module): 572 | def __init__(self, track_bn=True, block=BasicBlock, n_blocks=[1,1,2,2], keep_prob=1.0, avg_pool=True, drop_rate=0.1, 573 | dropblock_size=5, num_classes=-1, use_se=False): 574 | super(ResNet18_84x84, self).__init__() 575 | self.final_feat_dim = 640 576 | 577 | self.inplanes = 3 578 | self.use_se = use_se 579 | self.layer1 = self._make_layer(block, n_blocks[0], 64, 580 | stride=2, drop_rate=drop_rate) 581 | self.layer2 = self._make_layer(block, n_blocks[1], 160, 582 | stride=2, drop_rate=drop_rate) 583 | self.layer3 = self._make_layer(block, n_blocks[2], 320, 584 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 585 | self.layer4 = self._make_layer(block, n_blocks[3], 640, 586 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 587 | if avg_pool: 588 | # self.avgpool = nn.AvgPool2d(5, stride=1) 589 | self.avgpool = nn.AdaptiveAvgPool2d(1) 590 | self.keep_prob = keep_prob 591 | self.keep_avg_pool = avg_pool 592 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 593 | self.drop_rate = drop_rate 594 | 595 | for m in self.modules(): 596 | if isinstance(m, nn.Conv2d): 597 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 598 | elif 
isinstance(m, nn.BatchNorm2d):
599 |                 nn.init.constant_(m.weight, 1)
600 |                 nn.init.constant_(m.bias, 0)
601 |                 if not track_bn:
602 |                     m.track_running_stats = False
603 |
604 |         # self.num_classes = num_classes
605 |         # if self.num_classes > 0:
606 |         #     self.classifier = nn.Linear(640, self.num_classes)
607 |
608 |     def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
609 |         downsample = None
610 |         if stride != 1 or self.inplanes != planes * block.expansion:
611 |             downsample = nn.Sequential(
612 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
613 |                           kernel_size=1, stride=1, bias=False),
614 |                 nn.BatchNorm2d(planes * block.expansion),
615 |             )
616 |
617 |         layers = []
618 |         if n_block == 1:
619 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se)
620 |         else:
621 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se)  # use_se passed by keyword; a positional arg here would land in the drop_block slot
622 |         layers.append(layer)
623 |         self.inplanes = planes * block.expansion
624 |
625 |         for i in range(1, n_block):
626 |             if i == n_block - 1:
627 |                 layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block,
628 |                               block_size=block_size, use_se=self.use_se)
629 |             else:
630 |                 layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se)
631 |             layers.append(layer)
632 |
633 |         return nn.Sequential(*layers)
634 |
635 |     def forward(self, x, is_feat=False):
636 |         x = self.layer1(x)
637 |         # f0 = x
638 |         x = self.layer2(x)
639 |         # f1 = x
640 |         x = self.layer3(x)
641 |         # f2 = x
642 |         x = self.layer4(x)
643 |         # f3 = x
644 |         if self.keep_avg_pool:
645 |             x = self.avgpool(x)
646 |         x = x.view(x.size(0), -1)
647 |         # feat = x
648 |         # if self.num_classes > 0:
649 |         #     x = self.classifier(x)
650 |
651 |         # if is_feat:
652 |         #     return [f0, f1, f2, f3, feat], x
653 |         # else:
654 |         #     return x
655 |         return x
656 |
657 |
658 | _backbone_class_map = {
659 |     'resnet10': ResNet10,
660 |     'resnet18': ResNet18,
661 |     'resnet50': ResNet50,
662 | }
663 |
664 |
665 | def get_backbone_class(key):
666 |     if key in _backbone_class_map:
667 |         return _backbone_class_map[key]
668 |     else:
669 |         raise ValueError('Invalid backbone: {}'.format(key))
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | save_dir = './logs'
2 |
3 | miniImageNet_path = '/data/cdfsl/miniImagenet'
4 | miniImageNet_test_path = '/data/cdfsl/miniImagenet_test'
5 | tieredImageNet_path = '/data/cdfsl/tieredImagenet/train'
6 | tieredImageNet_test_path = '/data/cdfsl/tieredImagenet/test'
7 | ImageNet_path = '/data/cdfsl/Imagenet/train'
8 | DTD_path = '/ssd/dtd/images/'
9 |
10 | ISIC_path = "/data/cdfsl/ISIC"
11 | ChestX_path = "/data/cdfsl/chestX"
12 | CropDisease_path = "/data/cdfsl/CropDiseases"
13 | EuroSAT_path = "/data/cdfsl/EuroSAT"
14 |
15 | cars_path = '/data/cdfsl/cars_cdfsl'
16 | cub_path = '/data/cdfsl/CUB_200_2011/images'
17 | places_path = '/data/cdfsl/places_cdfsl'
18 | plantae_path = '/data/cdfsl/plantae_cdfsl'
19 |
--------------------------------------------------------------------------------
/data_preparation/cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from collections import defaultdict
4 |
5 | import scipy.io
6 | from tqdm import tqdm
7 |
8 | ROOT = os.path.dirname(os.path.abspath(__file__))
9 | METADATA_PATH = os.path.join(ROOT, "input", "cars_annos.mat")
10 |
SOURCE_DIR = os.path.join(ROOT, "input", "car_ims") 11 | TARGET_DIR = os.path.join(ROOT, "output", "cars_cdfsl") 12 | 13 | if not os.path.isfile(METADATA_PATH): 14 | raise Exception("Could not find metadata file at `{}`".format(METADATA_PATH)) 15 | if not os.path.isdir(SOURCE_DIR): 16 | raise Exception("could not find image folder at `{}`".format(SOURCE_DIR)) 17 | 18 | metadata = scipy.io.loadmat(METADATA_PATH) 19 | metadata = metadata["annotations"][0] 20 | 21 | paths_by_class = defaultdict(list) 22 | total = len(metadata) 23 | for m in metadata: 24 | path, _, _, _, _, cls, test = m 25 | cls = cls.item() 26 | path = path.item() 27 | path = os.path.basename(path) 28 | paths_by_class[str(cls)].append(path) 29 | 30 | print("Copying images to `{}`...".format(TARGET_DIR)) 31 | with tqdm(total=total) as pbar: 32 | for cls, paths in paths_by_class.items(): 33 | for path in paths: 34 | source_path = os.path.join(SOURCE_DIR, path) 35 | target_path = os.path.join(TARGET_DIR, cls, path) 36 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 37 | shutil.copy(source_path, target_path) 38 | pbar.update() 39 | print("Complete") 40 | -------------------------------------------------------------------------------- /data_preparation/input/.gitignore: -------------------------------------------------------------------------------- 1 | * -------------------------------------------------------------------------------- /data_preparation/output/.gitignore: -------------------------------------------------------------------------------- 1 | * -------------------------------------------------------------------------------- /data_preparation/places.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | 5 | from tqdm import tqdm 6 | 7 | ROOT = os.path.dirname(os.path.abspath(__file__)) 8 | SUBSET_PATH = os.path.join(ROOT, "places_cdfsl_subset_16_class_1715_sample_seed_0.json") 9 | SOURCE_DIR = os.path.join(ROOT, "input", "places365_standard/train") 10 | TARGET_DIR = os.path.join(ROOT, "output", "places_cdfsl") 11 | 12 | if not os.path.isfile(SUBSET_PATH): 13 | raise Exception("Could not find subset file at `{}`".format(SUBSET_PATH)) 14 | if not os.path.isdir(SOURCE_DIR): 15 | raise Exception("could not find image folder at `{}`".format(SOURCE_DIR)) 16 | 17 | with open(SUBSET_PATH) as f: 18 | subset_data = json.load(f) 19 | all_paths = [] 20 | for paths in subset_data.values(): 21 | all_paths.extend(paths) 22 | 23 | print("Copying images to {}...".format(TARGET_DIR)) 24 | for p in tqdm(all_paths): 25 | source_path = os.path.join(SOURCE_DIR, p) 26 | target_path = os.path.join(TARGET_DIR, p) 27 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 28 | shutil.copy(source_path, target_path) 29 | print("Complete") 30 | -------------------------------------------------------------------------------- /data_preparation/plantae.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | 5 | from tqdm import tqdm 6 | 7 | ROOT = os.path.dirname(os.path.abspath(__file__)) 8 | SUBSET_PATH = os.path.join(ROOT, "plantae_cdfsl_subset_69_class.json") 9 | SOURCE_DIR = os.path.join(ROOT, "input", "Plantae") 10 | TARGET_DIR = os.path.join(ROOT, "output", "plantae_cdfsl") 11 | 12 | if not os.path.isfile(SUBSET_PATH): 13 | raise Exception("Could not find subset file at `{}`".format(SUBSET_PATH)) 14 | if not os.path.isdir(SOURCE_DIR): 15 | raise Exception("could not 
find image folder at `{}`".format(SOURCE_DIR)) 16 | 17 | with open(SUBSET_PATH) as f: 18 | subset_data = json.load(f) 19 | all_paths = [] 20 | for paths in subset_data.values(): 21 | all_paths.extend(paths) 22 | 23 | print("Copying images to {}...".format(TARGET_DIR)) 24 | for p in tqdm(all_paths): 25 | source_path = os.path.join(SOURCE_DIR, p) 26 | target_path = os.path.join(TARGET_DIR, p) 27 | os.makedirs(os.path.dirname(target_path), exist_ok=True) 28 | shutil.copy(source_path, target_path) 29 | print("Complete") 30 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sungnyun/understanding-cdfsl/fdd45d8e5af0cb35bdb0b2b4046ed8089abdf304/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, MutableMapping 2 | from weakref import WeakValueDictionary 3 | 4 | import torch 5 | import torch.utils.data 6 | from torch.utils.data import Dataset 7 | 8 | from datasets.datasets import dataset_class_map 9 | from datasets.sampler import EpisodicBatchSampler 10 | from datasets.split import split_dataset 11 | from datasets.transforms import get_composed_transform 12 | 13 | _unlabeled_dataset_cache: MutableMapping[Tuple[str, str, int, bool, int], Dataset] = WeakValueDictionary() 14 | 15 | DEFAULT_IMAGE_SIZE = 224 16 | 17 | 18 | class ToSiamese: 19 | """ 20 | A wrapper for torchvision transform. The transform is applied twice for 21 | SimCLR training 22 | """ 23 | 24 | def __init__(self, transform, transform2=None): 25 | self.transform = transform 26 | 27 | if transform2 is not None: 28 | self.transform2 = transform2 29 | else: 30 | self.transform2 = transform 31 | 32 | def __call__(self, img): 33 | return self.transform(img), self.transform2(img) 34 | 35 | 36 | def get_default_dataset(dataset_name: str, augmentation: str, image_size: int = None, siamese=False): 37 | """ 38 | :param augmentation: One of {'base', 'strong', None, 'none'} 39 | """ 40 | if image_size is None: 41 | print('Using default image size: {}'.format(DEFAULT_IMAGE_SIZE)) 42 | image_size = DEFAULT_IMAGE_SIZE 43 | 44 | try: 45 | dataset_cls = dataset_class_map[dataset_name] 46 | except KeyError as e: 47 | raise ValueError('Unsupported dataset: {}'.format(dataset_name)) 48 | 49 | transform = get_composed_transform(augmentation, image_size=image_size) 50 | if siamese: 51 | transform = ToSiamese(transform) 52 | return dataset_cls(transform=transform) 53 | 54 | 55 | def get_split_dataset(dataset_name: str, augmentation: str, image_size: int = None, siamese=False, 56 | unlabeled_ratio: int = 20, seed=1): 57 | # If cache details change, just remove the cache – it's not worth the maintenance TBH. 
58 | cache_key = (dataset_name, augmentation, image_size, siamese, unlabeled_ratio) 59 | if cache_key not in _unlabeled_dataset_cache: 60 | dataset = get_default_dataset(dataset_name=dataset_name, augmentation=augmentation, image_size=image_size, 61 | siamese=siamese) 62 | unlabeled, labeled = split_dataset(dataset, ratio=unlabeled_ratio, seed=seed) 63 | # Cross-reference so that strong ref persists if either split is currently referenced 64 | unlabeled.counterpart = labeled 65 | labeled.counterpart = unlabeled 66 | _unlabeled_dataset_cache[cache_key] = unlabeled 67 | 68 | unlabeled = _unlabeled_dataset_cache[cache_key] 69 | labeled = unlabeled.counterpart 70 | 71 | return unlabeled, labeled 72 | 73 | 74 | def get_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None, siamese=False, 75 | num_workers=2, shuffle=True, drop_last=False): 76 | dataset = get_default_dataset(dataset_name=dataset_name, augmentation=augmentation, image_size=image_size, 77 | siamese=siamese) 78 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, 79 | shuffle=shuffle, drop_last=drop_last) 80 | 81 | 82 | def get_split_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None, siamese=False, 83 | unlabeled_ratio: int = 20, num_workers=2, shuffle=True, drop_last=False, seed=1): 84 | unlabeled, labeled = get_split_dataset(dataset_name, augmentation, image_size=image_size, siamese=siamese, 85 | unlabeled_ratio=unlabeled_ratio, seed=seed) 86 | dataloaders = [] 87 | for dataset in [unlabeled, labeled]: 88 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, 89 | shuffle=shuffle, drop_last=drop_last) 90 | dataloaders.append(dataloader) 91 | return dataloaders 92 | 93 | 94 | def get_labeled_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None, siamese=False, 95 | unlabeled_ratio: int = 20, num_workers=2, shuffle=True, drop_last=False, split_seed=1): 96 | unlabeled, labeled = get_split_dataloader(dataset_name, augmentation, batch_size, image_size, siamese=siamese, 97 | unlabeled_ratio=unlabeled_ratio, 98 | num_workers=num_workers, shuffle=shuffle, drop_last=drop_last, 99 | seed=split_seed) 100 | return labeled 101 | 102 | 103 | def get_unlabeled_dataloader(dataset_name: str, augmentation: str, batch_size: int, image_size: int = None, 104 | siamese=False, unlabeled_ratio: int = 20, num_workers=2, 105 | shuffle=True, drop_last=True, split_seed=1): 106 | unlabeled, labeled = get_split_dataloader(dataset_name, augmentation, batch_size, image_size, siamese=siamese, 107 | unlabeled_ratio=unlabeled_ratio, 108 | num_workers=num_workers, shuffle=shuffle, drop_last=drop_last, 109 | seed=split_seed) 110 | return unlabeled 111 | 112 | 113 | def get_episodic_dataloader(dataset_name: str, n_way: int, n_shot: int, support: bool, n_episodes=600, n_query_shot=15, 114 | augmentation: str = None, image_size: int = None, num_workers=2, n_epochs=1, 115 | episode_seed=0): 116 | dataset = get_default_dataset(dataset_name=dataset_name, augmentation=augmentation, image_size=image_size, 117 | siamese=False) 118 | sampler = EpisodicBatchSampler(dataset, n_way=n_way, n_shot=n_shot, n_query_shot=n_query_shot, 119 | n_episodes=n_episodes, support=support, n_epochs=n_epochs, seed=episode_seed) 120 | return torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_sampler=sampler) 121 | 122 | 123 | def get_labeled_episodic_dataloader(dataset_name: str, n_way: 
int, n_shot: int, support: bool, n_episodes=600, 124 | n_query_shot=15, n_epochs=1, augmentation: str = None, image_size: int = None, 125 | unlabeled_ratio: int = 20, num_workers=2, split_seed=1, episode_seed=0): 126 | unlabeled, labeled = get_split_dataset(dataset_name, augmentation, image_size=image_size, siamese=False, 127 | unlabeled_ratio=unlabeled_ratio, seed=split_seed) 128 | sampler = EpisodicBatchSampler(labeled, n_way=n_way, n_shot=n_shot, n_query_shot=n_query_shot, 129 | n_episodes=n_episodes, support=support, n_epochs=n_epochs, seed=episode_seed) 130 | return torch.utils.data.DataLoader(labeled, num_workers=num_workers, batch_sampler=sampler) 131 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | All dataset classes in unified `torchvision.datasets.ImageFolder` format! 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from torchvision.datasets import ImageFolder 10 | 11 | from configs import * 12 | 13 | 14 | class MiniImageNetDataset(ImageFolder): 15 | name = "miniImageNet" 16 | 17 | def __init__(self, root=miniImageNet_path, *args, **kwargs): 18 | super().__init__(root=root, *args, **kwargs) 19 | 20 | 21 | class MiniImageNetTestDataset(ImageFolder): 22 | name = "miniImageNet_test" 23 | 24 | def __init__(self, root=miniImageNet_test_path, *args, **kwargs): 25 | super().__init__(root=root, *args, **kwargs) 26 | 27 | 28 | class TieredImageNetDataset(ImageFolder): 29 | name = "tieredImageNet" 30 | 31 | def __init__(self, root=tieredImageNet_path, *args, **kwargs): 32 | super().__init__(root=root, *args, **kwargs) 33 | 34 | 35 | class TieredImageNetTestDataset(ImageFolder): 36 | name = "tieredImageNet_test" 37 | 38 | def __init__(self, root=tieredImageNet_test_path, *args, **kwargs): 39 | super().__init__(root=root, *args, **kwargs) 40 | 41 | 42 | class ImageNetDataset(ImageFolder): 43 | name = "ImageNet" 44 | 45 | def __init__(self, root=ImageNet_path, *args, **kwargs): 46 | super().__init__(root=root, *args, **kwargs) 47 | 48 | 49 | class CropDiseaseDataset(ImageFolder): 50 | name = "CropDisease" 51 | 52 | def __init__(self, root=CropDisease_path, *args, **kwargs): 53 | super().__init__(root=os.path.join(root, "dataset", "train"), *args, **kwargs) 54 | 55 | 56 | class EuroSATDataset(ImageFolder): 57 | name = "EuroSAT" 58 | 59 | def __init__(self, root=EuroSAT_path, *args, **kwargs): 60 | super().__init__(root, *args, **kwargs) 61 | 62 | 63 | class ISICDataset(ImageFolder): 64 | name = "ISIC" 65 | """ 66 | Implementation note: functions for finding data files have been customized so that data is selected based on 67 | the given CSV file. 
68 | """ 69 | 70 | def __init__(self, root=ISIC_path, *args, **kwargs): 71 | csv_path = os.path.join(root, "ISIC2018_Task3_Training_GroundTruth.csv") 72 | self.metadata = pd.read_csv(csv_path) 73 | super().__init__(root, *args, **kwargs) 74 | 75 | def make_dataset(self, root, *args, **kwargs): 76 | paths = np.asarray(self.metadata.iloc[:, 0]) 77 | labels = np.asarray(self.metadata.iloc[:, 1:]) 78 | labels = (labels != 0).argmax(axis=1) 79 | 80 | samples = [] 81 | for path, label in zip(paths, labels): 82 | path = os.path.join(root, path + ".jpg") 83 | samples.append((path, label)) 84 | samples.sort() 85 | 86 | return samples 87 | 88 | def find_classes(self, _): 89 | classes = self.metadata.columns[1:].tolist() 90 | classes.sort() 91 | class_to_idx = dict() 92 | for i, cls in enumerate(classes): 93 | class_to_idx[cls] = i 94 | return classes, class_to_idx 95 | 96 | _find_classes = find_classes # compatibility with earlier versions 97 | 98 | 99 | class ChestXDataset(ImageFolder): 100 | name = "ChestX" 101 | """ 102 | Implementation note: functions for finding data files have been customized so that data is selected based on 103 | the given CSV file. 104 | """ 105 | 106 | def __init__(self, root=ChestX_path, *args, **kwargs): 107 | csv_path = os.path.join(root, "Data_Entry_2017.csv") 108 | images_root = os.path.join(root, "images") 109 | # self.used_labels = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumonia", 110 | # "Pneumothorax"] 111 | self.used_labels = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", 112 | "Pneumothorax"] 113 | self.labels_maps = {"Atelectasis": 0, "Cardiomegaly": 1, "Effusion": 2, "Infiltration": 3, "Mass": 4, 114 | "Nodule": 5, "Pneumothorax": 6} 115 | self.metadata = pd.read_csv(csv_path) 116 | super().__init__(images_root, *args, **kwargs) 117 | 118 | def make_dataset(self, root, *args, **kwargs): 119 | samples = [] 120 | paths = np.asarray(self.metadata.iloc[:, 0]) 121 | labels = np.asarray(self.metadata.iloc[:, 1]) 122 | for path, label in zip(paths, labels): 123 | label = label.split("|") 124 | if len(label) == 1 and label[0] != "No Finding" and label[0] != "Pneumonia" and label[ 125 | 0] in self.used_labels: 126 | path = os.path.join(root, path) 127 | label = self.labels_maps[label[0]] 128 | samples.append((path, label)) 129 | samples.sort() 130 | return samples 131 | 132 | def find_classes(self, _): 133 | return self.used_labels, self.labels_maps 134 | 135 | _find_classes = find_classes # compatibility with earlier versions 136 | 137 | 138 | class CarsDataset(ImageFolder): 139 | name = "cars" 140 | 141 | def __init__(self, root=cars_path, *args, **kwargs): 142 | super().__init__(root=root, *args, **kwargs) 143 | 144 | 145 | class CUBDataset(ImageFolder): 146 | name = "cub" 147 | 148 | def __init__(self, root=cub_path, *args, **kwargs): 149 | super().__init__(root=root, *args, **kwargs) 150 | 151 | 152 | class PlacesDataset(ImageFolder): 153 | name = "places" 154 | 155 | def __init__(self, root=places_path, *args, **kwargs): 156 | super().__init__(root=root, *args, **kwargs) 157 | 158 | 159 | class PlantaeDataset(ImageFolder): 160 | name = "plantae" 161 | 162 | def __init__(self, root=plantae_path, *args, **kwargs): 163 | super().__init__(root=root, *args, **kwargs) 164 | 165 | 166 | dataset_classes = [ 167 | MiniImageNetDataset, 168 | MiniImageNetTestDataset, 169 | TieredImageNetDataset, 170 | TieredImageNetTestDataset, 171 | ImageNetDataset, 172 | CropDiseaseDataset, 173 | EuroSATDataset, 174 | 
    ISICDataset,
175 |     ChestXDataset,
176 |     CarsDataset,
177 |     CUBDataset,
178 |     PlacesDataset,
179 |     PlantaeDataset,
180 | ]
181 | 
182 | dataset_class_map = {
183 |     cls.name: cls for cls in dataset_classes
184 | }
185 | -------------------------------------------------------------------------------- /datasets/sampler.py: --------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | 
3 | import numpy as np
4 | from torch.utils.data import Sampler
5 | from torchvision.datasets import ImageFolder
6 | 
7 | 
8 | class EpisodeSampler:
9 |     """
10 |     Stable sampler for support and query indices. Used by the episodic batch sampler so that the support and query
11 |     sets can be sampled from independent data loaders sharing the same episodes, i.e., support and query never overlap.
12 |     """
13 | 
14 |     def __init__(self, dataset: ImageFolder, n_way: int, n_shot: int, n_query_shot: int, n_episodes: int,
15 |                  seed: int = 0):
16 |         self.dataset = dataset
17 |         self.n_classes = len(dataset.classes)
18 |         self.w = n_way
19 |         self.s = n_shot
20 |         self.q = n_query_shot
21 |         self.n_episodes = n_episodes
22 |         self.seed = seed
23 | 
24 |         rs = np.random.RandomState(seed)
25 |         self.episode_seeds = []
26 |         for i in range(n_episodes):
27 |             self.episode_seeds.append(rs.randint(2 ** 32 - 1))
28 | 
29 |         self.indices_by_class = defaultdict(list)
30 |         for index, (path, label) in enumerate(dataset.samples):
31 |             self.indices_by_class[label].append(index)
32 | 
33 |     def __getitem__(self, index):
34 |         """
35 |         :param index: episode index in [0, n_episodes)
36 |         :return: support: ndarray[w, s], query: ndarray[w, q]
37 |         """
38 |         rs = np.random.RandomState(self.episode_seeds[index])
39 |         selected_classes = rs.permutation(self.n_classes)[:self.w]
40 |         indices = []
41 |         for cls in selected_classes:
42 |             candidates: int = len(self.indices_by_class[cls])
43 |             choices: int = self.s + self.q
44 | 
45 |             if candidates >= choices:  # enough samples in this class: draw without replacement
46 |                 indices.append(
47 |                     rs.choice(self.indices_by_class[cls], choices, replace=False))
48 |             else:  # too few samples: draw with replacement (duplicates are unavoidable)
49 |                 indices.append(
50 |                     rs.choice(self.indices_by_class[cls], choices, replace=True))
51 | 
52 |         episode = np.stack(indices)
53 |         support = episode[:, :self.s]
54 |         query = episode[:, self.s:]
55 |         return support, query
56 | 
57 |     def __len__(self):
58 |         return self.n_episodes
59 | 
60 | 
61 | class EpisodicBatchSampler(Sampler):
62 |     """
63 |     For each episode, the same batch of indices is yielded once per epoch. For batch-training within an episode, you
64 |     need to divide the sampled data (from the dataloader) into smaller batches yourself.
65 | 
66 |     For classification-based training, note that you need to reset the class indices to [0, ..., 0, 1, ..., w-1]. This
67 |     is also why inter-episode batches are not supported by the sampler: resetting the class indices would be harder.
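 
    A minimal usage sketch (illustrative only; assumes `dataset` is any ImageFolder and a 5-way 5-shot setup):
 
        sampler = EpisodicBatchSampler(dataset, n_way=5, n_shot=5, n_query_shot=15,
                                       n_episodes=600, support=True)
        loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
        for x, y in loader:  # x: [5 * 5, C, H, W]; y holds the original dataset labels
            y = torch.arange(5).repeat_interleave(5)  # reset class indices, as done in finetune.py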
68 | """ 69 | 70 | def __init__(self, dataset: ImageFolder, n_way: int, n_shot: int, n_query_shot: int, n_episodes: int, support: bool, 71 | n_epochs=1, seed=0): 72 | super().__init__(dataset) 73 | self.dataset = dataset 74 | 75 | self.w = n_way 76 | self.s = n_shot 77 | self.q = n_query_shot 78 | self.episode_sampler = EpisodeSampler(dataset, n_way, n_shot, n_query_shot, n_episodes, seed) 79 | 80 | self.n_episodes = n_episodes 81 | self.n_epochs = n_epochs 82 | self.support = support 83 | 84 | def __len__(self): 85 | return self.n_episodes * self.n_epochs 86 | 87 | def __iter__(self): 88 | for i in range(self.n_episodes): 89 | support, query = self.episode_sampler[i] 90 | indices = support if self.support else query 91 | indices = indices.flatten() 92 | for j in range(self.n_epochs): 93 | yield indices 94 | -------------------------------------------------------------------------------- /datasets/split.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from typing import List, Tuple 4 | 5 | import pandas as pd 6 | from numpy.random import RandomState 7 | from torchvision.datasets import ImageFolder 8 | 9 | DIRNAME = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | DATASETS_WITH_DEFAULT_SPLITS = [ 12 | "miniImageNet", 13 | "miniImageNet_test", 14 | "tieredImageNet", 15 | "tieredImageNet_test", 16 | "CropDisease", 17 | "EuroSAT", 18 | "ISIC", 19 | "ChestX", 20 | ] 21 | 22 | 23 | def split_dataset(dataset: ImageFolder, ratio=20, seed=1): 24 | """ 25 | :param dataset: 26 | :param ratio: Ratio of unlabeled portion 27 | :param seed: 28 | :return: unlabeled_dataset, labeled_dataset 29 | """ 30 | assert (0 <= ratio <= 100) 31 | 32 | # Check default splits 33 | unlabeled_path = _get_split_path(dataset, ratio, seed, True) 34 | labeled_path = _get_split_path(dataset, ratio, seed, False) 35 | for path in [unlabeled_path, labeled_path]: 36 | if ratio == 20 and seed == 1 and dataset.name in DATASETS_WITH_DEFAULT_SPLITS and not os.path.exists(path): 37 | raise Exception("Default split file missing: {}".format(path)) 38 | 39 | if os.path.exists(unlabeled_path) and os.path.exists(labeled_path): 40 | print("Loading unlabeled split from {}".format(unlabeled_path)) 41 | print("Loading labeled split from {}".format(labeled_path)) 42 | unlabeled = _load_split(unlabeled_path) 43 | labeled = _load_split(labeled_path) 44 | else: 45 | unlabeled, labeled = _get_split(dataset, ratio, seed) 46 | print("Generating unlabeled split to {}".format(unlabeled_path)) 47 | print("Generating labeled split to {}".format(labeled_path)) 48 | _save_split(unlabeled, unlabeled_path) 49 | _save_split(labeled, labeled_path) 50 | 51 | ud = copy.deepcopy(dataset) 52 | ld = copy.deepcopy(dataset) 53 | 54 | _apply_split(ud, unlabeled) 55 | _apply_split(ld, labeled) 56 | 57 | return ud, ld 58 | 59 | 60 | def _get_split(dataset: ImageFolder, ratio: int, seed: int) -> Tuple[List[str], List[str]]: 61 | img_paths = [] 62 | for path, label in dataset.samples: 63 | root_with_slash = os.path.join(dataset.root, "") 64 | img_paths.append(path.replace(root_with_slash, "")) 65 | img_paths.sort() 66 | # Assert uniqueness 67 | assert (len(img_paths) == len(set(img_paths))) 68 | 69 | rs = RandomState(seed) 70 | unlabeled_count = len(img_paths) * ratio // 100 71 | unlabeled_paths = set(rs.choice(img_paths, unlabeled_count, replace=False)) 72 | labeled_paths = set(img_paths) - unlabeled_paths 73 | 74 | return sorted(list(unlabeled_paths)), sorted(list(labeled_paths)) 75 | 76 | 77 | 
def _save_split(split: List, path):
78 |     df = pd.DataFrame({
79 |         "img_path": split
80 |     })
81 |     df.to_csv(path, index=False)  # no index column, matching the bundled split CSVs
82 | 
83 | 
84 | def _load_split(path) -> List[str]:
85 |     df = pd.read_csv(path)
86 |     return df["img_path"].tolist()  # plain list, as the return annotation promises
87 | 
88 | 
89 | def _get_split_path(dataset: ImageFolder, ratio: int, seed=1, unlabeled=True, makedirs=True):
90 |     if unlabeled:
91 |         basename = '{}_unlabeled_{}.csv'.format(dataset.name, ratio)
92 |     else:
93 |         basename = '{}_labeled_{}.csv'.format(dataset.name, 100 - ratio)
94 |     path = os.path.join(DIRNAME, 'split_seed_{}'.format(seed), basename)
95 |     if makedirs:
96 |         os.makedirs(os.path.dirname(path), exist_ok=True)
97 |     return path
98 | 
99 | 
100 | def _apply_split(dataset: ImageFolder, split: List[str]):
101 |     img_paths = []
102 |     for path, label in dataset.samples:
103 |         root_with_slash = os.path.join(dataset.root, "")
104 |         img_paths.append(path.replace(root_with_slash, ""))
105 | 
106 |     split_set = set(split)
107 |     samples = []
108 |     for path, sample in zip(img_paths, dataset.samples):
109 |         if len(split) > 0 and '.jpg' not in split[0] and dataset.name == 'ISIC':  # HOTFIX (paths in ISIC's default split file don't have ".jpg")
110 |             path = path.replace('.jpg', '')
111 |         if path in split_set:
112 |             samples.append(sample)
113 | 
114 |     dataset.samples = samples
115 |     dataset.imgs = samples
116 |     dataset.targets = [s[1] for s in samples]
117 | -------------------------------------------------------------------------------- /datasets/split_seed_1/ISIC_unlabeled_20.csv: --------------------------------------------------------------------------------
1 | img_path 2 | ISIC_0028058 3 | ISIC_0029999 4 | ISIC_0029695 5 | ISIC_0033303 6 | ISIC_0025777 7 | ISIC_0025563 8 | ISIC_0028185 9 | ISIC_0024758 10 | ISIC_0032674 11 | ISIC_0030458 12 | ISIC_0030532 13 | ISIC_0030370 14 | ISIC_0025134 15 | ISIC_0033986 16 | ISIC_0027242 17 | ISIC_0033035 18 | ISIC_0034007 19 | ISIC_0025028 20 | ISIC_0034198 21 | ISIC_0031865 22 | ISIC_0032023 23 | ISIC_0033227 24 | ISIC_0031642 25 | ISIC_0027194 26 | ISIC_0025901 27 | ISIC_0031278 28 | ISIC_0032601 29 | ISIC_0030270 30 | ISIC_0029166 31 | ISIC_0033398 32 | ISIC_0033490 33 | ISIC_0030927 34 | ISIC_0029994 35 | ISIC_0028277 36 | ISIC_0027402 37 | ISIC_0028054 38 | ISIC_0028736 39 | ISIC_0029823 40 | ISIC_0031883 41 | ISIC_0029056 42 | ISIC_0032406 43 | ISIC_0028685 44 | ISIC_0028383 45 | ISIC_0024656 46 | ISIC_0029851 47 | ISIC_0027628 48 | ISIC_0026565 49 | ISIC_0032105 50 | ISIC_0026332 51 | ISIC_0034214 52 | ISIC_0027658 53 | ISIC_0032894 54 | ISIC_0033933 55 | ISIC_0025751 56 | ISIC_0029383 57 | ISIC_0034151 58 | ISIC_0029555 59 | ISIC_0032044 60 | ISIC_0032496 61 | ISIC_0029076 62 | ISIC_0028870 63 | ISIC_0029719 64 | ISIC_0025917 65 | ISIC_0024735 66 | ISIC_0034079 67 | ISIC_0028256 68 | ISIC_0029957 69 | ISIC_0029531 70 | ISIC_0032029 71 | ISIC_0025516 72 | ISIC_0025222 73 | ISIC_0028568 74 | ISIC_0031461 75 | ISIC_0024546 76 | ISIC_0030856 77 | ISIC_0030926 78 | ISIC_0026919 79 | ISIC_0026800 80 | ISIC_0025306 81 | ISIC_0029580 82 | ISIC_0024362 83 | ISIC_0027326 84 | ISIC_0034305 85 | ISIC_0024514 86 | ISIC_0026417 87 | ISIC_0028969 88 | ISIC_0032987 89 | ISIC_0029709 90 | ISIC_0025746 91 | ISIC_0024731 92 | ISIC_0032367 93 | ISIC_0025003 94 | ISIC_0031006 95 | ISIC_0025571 96 | ISIC_0025929 97 | ISIC_0029139 98 | ISIC_0031981 99 | ISIC_0025899 100 | ISIC_0031511 101 | ISIC_0026607 102 | ISIC_0028894 103 | ISIC_0029363 104 | ISIC_0030262 105 | ISIC_0032579 106 | ISIC_0026506 107 | ISIC_0031437 108 | ISIC_0033205 109 | ISIC_0025284
110 | ISIC_0027076 111 | ISIC_0028631 112 | ISIC_0025125 113 | ISIC_0033050 114 | ISIC_0030228 115 | ISIC_0031949 116 | ISIC_0033924 117 | ISIC_0033123 118 | ISIC_0025773 119 | ISIC_0033536 120 | ISIC_0024927 121 | ISIC_0032942 122 | ISIC_0029769 123 | ISIC_0032124 124 | ISIC_0028312 125 | ISIC_0027354 126 | ISIC_0026084 127 | ISIC_0030798 128 | ISIC_0028372 129 | ISIC_0024642 130 | ISIC_0026948 131 | ISIC_0031567 132 | ISIC_0025389 133 | ISIC_0031802 134 | ISIC_0026913 135 | ISIC_0031267 136 | ISIC_0027875 137 | ISIC_0028330 138 | ISIC_0032849 139 | ISIC_0024696 140 | ISIC_0025043 141 | ISIC_0024419 142 | ISIC_0025374 143 | ISIC_0028229 144 | ISIC_0025243 145 | ISIC_0026605 146 | ISIC_0030275 147 | ISIC_0033915 148 | ISIC_0032097 149 | ISIC_0029417 150 | ISIC_0031285 151 | ISIC_0031693 152 | ISIC_0024563 153 | ISIC_0026544 154 | ISIC_0025645 155 | ISIC_0025458 156 | ISIC_0033732 157 | ISIC_0030666 158 | ISIC_0025298 159 | ISIC_0032511 160 | ISIC_0026152 161 | ISIC_0026864 162 | ISIC_0029411 163 | ISIC_0026278 164 | ISIC_0028916 165 | ISIC_0028928 166 | ISIC_0027970 167 | ISIC_0031966 168 | ISIC_0027709 169 | ISIC_0030830 170 | ISIC_0024438 171 | ISIC_0026768 172 | ISIC_0028437 173 | ISIC_0024747 174 | ISIC_0028466 175 | ISIC_0026335 176 | ISIC_0025974 177 | ISIC_0030171 178 | ISIC_0029684 179 | ISIC_0031986 180 | ISIC_0026481 181 | ISIC_0029040 182 | ISIC_0029687 183 | ISIC_0029626 184 | ISIC_0031891 185 | ISIC_0029620 186 | ISIC_0027031 187 | ISIC_0034137 188 | ISIC_0033096 189 | ISIC_0029169 190 | ISIC_0033586 191 | ISIC_0025600 192 | ISIC_0032813 193 | ISIC_0025965 194 | ISIC_0032954 195 | ISIC_0031361 196 | ISIC_0027736 197 | ISIC_0027852 198 | ISIC_0024333 199 | ISIC_0030614 200 | ISIC_0024797 201 | ISIC_0032253 202 | ISIC_0034257 203 | ISIC_0028349 204 | ISIC_0033004 205 | ISIC_0026431 206 | ISIC_0026772 207 | ISIC_0033093 208 | ISIC_0030149 209 | ISIC_0030356 210 | ISIC_0028763 211 | ISIC_0029336 212 | ISIC_0031927 213 | ISIC_0024645 214 | ISIC_0028507 215 | ISIC_0032431 216 | ISIC_0027957 217 | ISIC_0024430 218 | ISIC_0030950 219 | ISIC_0027331 220 | ISIC_0026692 221 | ISIC_0026271 222 | ISIC_0027266 223 | ISIC_0031563 224 | ISIC_0025275 225 | ISIC_0028614 226 | ISIC_0032879 227 | ISIC_0025894 228 | ISIC_0025241 229 | ISIC_0030869 230 | ISIC_0024829 231 | ISIC_0025659 232 | ISIC_0030699 233 | ISIC_0027444 234 | ISIC_0025494 235 | ISIC_0032327 236 | ISIC_0032018 237 | ISIC_0031051 238 | ISIC_0033683 239 | ISIC_0025183 240 | ISIC_0027600 241 | ISIC_0032450 242 | ISIC_0033221 243 | ISIC_0029138 244 | ISIC_0030248 245 | ISIC_0027735 246 | ISIC_0031185 247 | ISIC_0032151 248 | ISIC_0030605 249 | ISIC_0029062 250 | ISIC_0034243 251 | ISIC_0031203 252 | ISIC_0032959 253 | ISIC_0028690 254 | ISIC_0027848 255 | ISIC_0030178 256 | ISIC_0032928 257 | ISIC_0025088 258 | ISIC_0026469 259 | ISIC_0030025 260 | ISIC_0025066 261 | ISIC_0024820 262 | ISIC_0025911 263 | ISIC_0028585 264 | ISIC_0031319 265 | ISIC_0027197 266 | ISIC_0028236 267 | ISIC_0026458 268 | ISIC_0024499 269 | ISIC_0025048 270 | ISIC_0032609 271 | ISIC_0026176 272 | ISIC_0026517 273 | ISIC_0032113 274 | ISIC_0027944 275 | ISIC_0029344 276 | ISIC_0030502 277 | ISIC_0027506 278 | ISIC_0033183 279 | ISIC_0030325 280 | ISIC_0034106 281 | ISIC_0033241 282 | ISIC_0024715 283 | ISIC_0024884 284 | ISIC_0024998 285 | ISIC_0030580 286 | ISIC_0032715 287 | ISIC_0029233 288 | ISIC_0030231 289 | ISIC_0032739 290 | ISIC_0033069 291 | ISIC_0025463 292 | ISIC_0031787 293 | ISIC_0033796 294 | ISIC_0027439 295 | ISIC_0024896 296 | ISIC_0025881 
297 | ISIC_0028880 298 | ISIC_0027435 299 | ISIC_0033633 300 | ISIC_0030556 301 | ISIC_0031058 302 | ISIC_0026702 303 | ISIC_0033264 304 | ISIC_0027430 305 | ISIC_0029346 306 | ISIC_0025049 307 | ISIC_0026254 308 | ISIC_0024488 309 | ISIC_0027853 310 | ISIC_0028644 311 | ISIC_0027022 312 | ISIC_0029796 313 | ISIC_0025359 314 | ISIC_0032663 315 | ISIC_0024959 316 | ISIC_0024454 317 | ISIC_0025014 318 | ISIC_0024675 319 | ISIC_0024329 320 | ISIC_0025591 321 | ISIC_0026263 322 | ISIC_0032331 323 | ISIC_0026566 324 | ISIC_0029934 325 | ISIC_0025505 326 | ISIC_0028814 327 | ISIC_0026589 328 | ISIC_0025498 329 | ISIC_0032604 330 | ISIC_0033407 331 | ISIC_0032238 332 | ISIC_0029911 333 | ISIC_0032546 334 | ISIC_0032068 335 | ISIC_0029114 336 | ISIC_0026999 337 | ISIC_0028557 338 | ISIC_0032988 339 | ISIC_0027104 340 | ISIC_0033471 341 | ISIC_0026817 342 | ISIC_0032847 343 | ISIC_0026214 344 | ISIC_0032136 345 | ISIC_0032374 346 | ISIC_0026012 347 | ISIC_0025770 348 | ISIC_0028106 349 | ISIC_0027339 350 | ISIC_0027077 351 | ISIC_0030724 352 | ISIC_0034120 353 | ISIC_0024336 354 | ISIC_0026612 355 | ISIC_0031608 356 | ISIC_0025532 357 | ISIC_0026635 358 | ISIC_0027719 359 | ISIC_0033424 360 | ISIC_0024479 361 | ISIC_0033901 362 | ISIC_0027547 363 | ISIC_0029261 364 | ISIC_0030374 365 | ISIC_0028048 366 | ISIC_0030226 367 | ISIC_0032699 368 | ISIC_0032372 369 | ISIC_0025113 370 | ISIC_0025775 371 | ISIC_0028208 372 | ISIC_0031533 373 | ISIC_0024799 374 | ISIC_0031053 375 | ISIC_0025307 376 | ISIC_0026623 377 | ISIC_0033061 378 | ISIC_0033299 379 | ISIC_0025278 380 | ISIC_0029710 381 | ISIC_0026686 382 | ISIC_0032110 383 | ISIC_0030891 384 | ISIC_0030097 385 | ISIC_0029013 386 | ISIC_0031167 387 | ISIC_0034165 388 | ISIC_0032219 389 | ISIC_0028416 390 | ISIC_0031044 391 | ISIC_0025969 392 | ISIC_0031640 393 | ISIC_0027651 394 | ISIC_0024504 395 | ISIC_0025364 396 | ISIC_0024946 397 | ISIC_0031615 398 | ISIC_0024625 399 | ISIC_0031737 400 | ISIC_0026968 401 | ISIC_0033661 402 | ISIC_0033283 403 | ISIC_0026226 404 | ISIC_0026258 405 | ISIC_0024677 406 | ISIC_0024891 407 | ISIC_0027136 408 | ISIC_0032035 409 | ISIC_0034024 410 | ISIC_0034153 411 | ISIC_0024382 412 | ISIC_0030640 413 | ISIC_0031168 414 | ISIC_0029399 415 | ISIC_0026674 416 | ISIC_0031950 417 | ISIC_0030451 418 | ISIC_0029041 419 | ISIC_0024681 420 | ISIC_0024386 421 | ISIC_0031956 422 | ISIC_0032337 423 | ISIC_0026476 424 | ISIC_0034164 425 | ISIC_0027092 426 | ISIC_0031695 427 | ISIC_0027204 428 | ISIC_0031880 429 | ISIC_0031092 430 | ISIC_0029786 431 | ISIC_0024495 432 | ISIC_0029917 433 | ISIC_0027198 434 | ISIC_0034236 435 | ISIC_0033106 436 | ISIC_0027476 437 | ISIC_0026365 438 | ISIC_0028755 439 | ISIC_0031483 440 | ISIC_0025254 441 | ISIC_0030030 442 | ISIC_0032500 443 | ISIC_0034105 444 | ISIC_0030540 445 | ISIC_0033670 446 | ISIC_0032313 447 | ISIC_0027412 448 | ISIC_0029872 449 | ISIC_0031851 450 | ISIC_0026416 451 | ISIC_0032738 452 | ISIC_0025823 453 | ISIC_0031736 454 | ISIC_0028008 455 | ISIC_0030285 456 | ISIC_0033860 457 | ISIC_0025179 458 | ISIC_0025511 459 | ISIC_0031321 460 | ISIC_0025115 461 | ISIC_0029421 462 | ISIC_0026122 463 | ISIC_0031507 464 | ISIC_0033662 465 | ISIC_0029101 466 | ISIC_0028271 467 | ISIC_0029885 468 | ISIC_0027376 469 | ISIC_0031544 470 | ISIC_0033225 471 | ISIC_0031025 472 | ISIC_0033814 473 | ISIC_0027456 474 | ISIC_0030019 475 | ISIC_0027132 476 | ISIC_0028158 477 | ISIC_0026920 478 | ISIC_0033869 479 | ISIC_0031269 480 | ISIC_0027923 481 | ISIC_0029762 482 | ISIC_0032638 483 | ISIC_0031911 
484 | ISIC_0030887 485 | ISIC_0026802 486 | ISIC_0031635 487 | ISIC_0030839 488 | ISIC_0031681 489 | ISIC_0028343 490 | ISIC_0029222 491 | ISIC_0025382 492 | ISIC_0033576 493 | ISIC_0025375 494 | ISIC_0033671 495 | ISIC_0025442 496 | ISIC_0025305 497 | ISIC_0027021 498 | ISIC_0033079 499 | ISIC_0031496 500 | ISIC_0030982 501 | ISIC_0033894 502 | ISIC_0032522 503 | ISIC_0024567 504 | ISIC_0028216 505 | ISIC_0027603 506 | ISIC_0030955 507 | ISIC_0026001 508 | ISIC_0032506 509 | ISIC_0030813 510 | ISIC_0026019 511 | ISIC_0027367 512 | ISIC_0024706 513 | ISIC_0027556 514 | ISIC_0029583 515 | ISIC_0029933 516 | ISIC_0028440 517 | ISIC_0030694 518 | ISIC_0025170 519 | ISIC_0026130 520 | ISIC_0030795 521 | ISIC_0033103 522 | ISIC_0029836 523 | ISIC_0026739 524 | ISIC_0032488 525 | ISIC_0028609 526 | ISIC_0027634 527 | ISIC_0027668 528 | ISIC_0029471 529 | ISIC_0025196 530 | ISIC_0030056 531 | ISIC_0031864 532 | ISIC_0032791 533 | ISIC_0028361 534 | ISIC_0029296 535 | ISIC_0026603 536 | ISIC_0029385 537 | ISIC_0028347 538 | ISIC_0025418 539 | ISIC_0024937 540 | ISIC_0031315 541 | ISIC_0032710 542 | ISIC_0028360 543 | ISIC_0026144 544 | ISIC_0030276 545 | ISIC_0031406 546 | ISIC_0029689 547 | ISIC_0031392 548 | ISIC_0032696 549 | ISIC_0033827 550 | ISIC_0030663 551 | ISIC_0029676 552 | ISIC_0028087 553 | ISIC_0031068 554 | ISIC_0034239 555 | ISIC_0029844 556 | ISIC_0031971 557 | ISIC_0027488 558 | ISIC_0028925 559 | ISIC_0030679 560 | ISIC_0033596 561 | ISIC_0026240 562 | ISIC_0027841 563 | ISIC_0032101 564 | ISIC_0032472 565 | ISIC_0027447 566 | ISIC_0030201 567 | ISIC_0026937 568 | ISIC_0024601 569 | ISIC_0034033 570 | ISIC_0027182 571 | ISIC_0033260 572 | ISIC_0033085 573 | ISIC_0028654 574 | ISIC_0030382 575 | ISIC_0031699 576 | ISIC_0026752 577 | ISIC_0031901 578 | ISIC_0026745 579 | ISIC_0032288 580 | ISIC_0031341 581 | ISIC_0033057 582 | ISIC_0024962 583 | ISIC_0026391 584 | ISIC_0026815 585 | ISIC_0030709 586 | ISIC_0032230 587 | ISIC_0033137 588 | ISIC_0033491 589 | ISIC_0030505 590 | ISIC_0028307 591 | ISIC_0030289 592 | ISIC_0026380 593 | ISIC_0024583 594 | ISIC_0028560 595 | ISIC_0033966 596 | ISIC_0031541 597 | ISIC_0024600 598 | ISIC_0031154 599 | ISIC_0030608 600 | ISIC_0033235 601 | ISIC_0031041 602 | ISIC_0027823 603 | ISIC_0031295 604 | ISIC_0027630 605 | ISIC_0027851 606 | ISIC_0031245 607 | ISIC_0025675 608 | ISIC_0028915 609 | ISIC_0028469 610 | ISIC_0033315 611 | ISIC_0025620 612 | ISIC_0028391 613 | ISIC_0026032 614 | ISIC_0026663 615 | ISIC_0027701 616 | ISIC_0032047 617 | ISIC_0031712 618 | ISIC_0030192 619 | ISIC_0030357 620 | ISIC_0033278 621 | ISIC_0030575 622 | ISIC_0033766 623 | ISIC_0025703 624 | ISIC_0025550 625 | ISIC_0031926 626 | ISIC_0025139 627 | ISIC_0024455 628 | ISIC_0027953 629 | ISIC_0029574 630 | ISIC_0024490 631 | ISIC_0025226 632 | ISIC_0025793 633 | ISIC_0027232 634 | ISIC_0029916 635 | ISIC_0030274 636 | ISIC_0027940 637 | ISIC_0030048 638 | ISIC_0026341 639 | ISIC_0030083 640 | ISIC_0024570 641 | ISIC_0027605 642 | ISIC_0029725 643 | ISIC_0026810 644 | ISIC_0028998 645 | ISIC_0031872 646 | ISIC_0027336 647 | ISIC_0028811 648 | ISIC_0032707 649 | ISIC_0031098 650 | ISIC_0026556 651 | ISIC_0026701 652 | ISIC_0025684 653 | ISIC_0025806 654 | ISIC_0026597 655 | ISIC_0033269 656 | ISIC_0031380 657 | ISIC_0027760 658 | ISIC_0032067 659 | ISIC_0033678 660 | ISIC_0027743 661 | ISIC_0028272 662 | ISIC_0031509 663 | ISIC_0027560 664 | ISIC_0031734 665 | ISIC_0034194 666 | ISIC_0033927 667 | ISIC_0024357 668 | ISIC_0030755 669 | ISIC_0032098 670 | ISIC_0031939 
671 | ISIC_0032232 672 | ISIC_0033367 673 | ISIC_0031176 674 | ISIC_0026275 675 | ISIC_0027019 676 | ISIC_0032154 677 | ISIC_0029884 678 | ISIC_0025820 679 | ISIC_0031937 680 | ISIC_0033686 681 | ISIC_0032610 682 | ISIC_0027639 683 | ISIC_0024474 684 | ISIC_0026628 685 | ISIC_0028977 686 | ISIC_0024730 687 | ISIC_0026908 688 | ISIC_0027800 689 | ISIC_0028562 690 | ISIC_0025660 691 | ISIC_0028064 692 | ISIC_0031543 693 | ISIC_0027843 694 | ISIC_0032213 695 | ISIC_0027462 696 | ISIC_0027632 697 | ISIC_0028306 698 | ISIC_0034276 699 | ISIC_0025159 700 | ISIC_0032498 701 | ISIC_0033863 702 | ISIC_0031485 703 | ISIC_0026564 704 | ISIC_0024359 705 | ISIC_0031229 706 | ISIC_0033798 707 | ISIC_0026483 708 | ISIC_0030596 709 | ISIC_0030948 710 | ISIC_0033829 711 | ISIC_0028301 712 | ISIC_0029021 713 | ISIC_0028006 714 | ISIC_0029397 715 | ISIC_0033452 716 | ISIC_0030003 717 | ISIC_0027147 718 | ISIC_0032104 719 | ISIC_0025518 720 | ISIC_0032070 721 | ISIC_0033224 722 | ISIC_0033730 723 | ISIC_0028154 724 | ISIC_0028300 725 | ISIC_0034051 726 | ISIC_0033331 727 | ISIC_0027897 728 | ISIC_0032355 729 | ISIC_0031230 730 | ISIC_0024821 731 | ISIC_0032569 732 | ISIC_0027015 733 | ISIC_0024662 734 | ISIC_0031268 735 | ISIC_0032424 736 | ISIC_0029807 737 | ISIC_0033027 738 | ISIC_0027626 739 | ISIC_0024365 740 | ISIC_0029745 741 | ISIC_0024587 742 | ISIC_0033952 743 | ISIC_0027654 744 | ISIC_0025935 745 | ISIC_0027088 746 | ISIC_0033109 747 | ISIC_0029454 748 | ISIC_0032949 749 | ISIC_0027881 750 | ISIC_0031631 751 | ISIC_0034310 752 | ISIC_0025582 753 | ISIC_0028353 754 | ISIC_0025836 755 | ISIC_0031297 756 | ISIC_0025111 757 | ISIC_0025835 758 | ISIC_0031160 759 | ISIC_0033232 760 | ISIC_0032547 761 | ISIC_0029972 762 | ISIC_0026997 763 | ISIC_0029960 764 | ISIC_0027360 765 | ISIC_0025673 766 | ISIC_0029712 767 | ISIC_0032922 768 | ISIC_0026770 769 | ISIC_0024306 770 | ISIC_0033689 771 | ISIC_0028874 772 | ISIC_0030350 773 | ISIC_0029378 774 | ISIC_0024482 775 | ISIC_0025328 776 | ISIC_0033130 777 | ISIC_0033010 778 | ISIC_0025592 779 | ISIC_0031369 780 | ISIC_0026715 781 | ISIC_0030214 782 | ISIC_0027314 783 | ISIC_0030257 784 | ISIC_0030849 785 | ISIC_0026690 786 | ISIC_0025654 787 | ISIC_0029722 788 | ISIC_0028581 789 | ISIC_0030980 790 | ISIC_0030746 791 | ISIC_0024838 792 | ISIC_0033724 793 | ISIC_0024686 794 | ISIC_0033825 795 | ISIC_0027652 796 | ISIC_0031439 797 | ISIC_0026987 798 | ISIC_0025236 799 | ISIC_0026932 800 | ISIC_0026832 801 | ISIC_0034270 802 | ISIC_0034316 803 | ISIC_0032917 804 | ISIC_0025365 805 | ISIC_0028482 806 | ISIC_0031396 807 | ISIC_0024379 808 | ISIC_0030278 809 | ISIC_0027479 810 | ISIC_0026910 811 | ISIC_0031542 812 | ISIC_0025558 813 | ISIC_0028700 814 | ISIC_0033140 815 | ISIC_0026170 816 | ISIC_0028909 817 | ISIC_0032986 818 | ISIC_0032886 819 | ISIC_0033615 820 | ISIC_0033684 821 | ISIC_0025596 822 | ISIC_0027301 823 | ISIC_0029263 824 | ISIC_0030954 825 | ISIC_0024761 826 | ISIC_0030733 827 | ISIC_0025121 828 | ISIC_0031920 829 | ISIC_0027226 830 | ISIC_0026944 831 | ISIC_0033475 832 | ISIC_0029978 833 | ISIC_0027436 834 | ISIC_0034195 835 | ISIC_0032342 836 | ISIC_0029006 837 | ISIC_0027083 838 | ISIC_0026528 839 | ISIC_0032850 840 | ISIC_0031403 841 | ISIC_0031850 842 | ISIC_0032465 843 | ISIC_0029586 844 | ISIC_0028025 845 | ISIC_0033630 846 | ISIC_0031829 847 | ISIC_0030057 848 | ISIC_0028179 849 | ISIC_0025509 850 | ISIC_0025711 851 | ISIC_0026005 852 | ISIC_0032129 853 | ISIC_0030469 854 | ISIC_0029024 855 | ISIC_0033883 856 | ISIC_0029184 857 | ISIC_0030233 
858 | ISIC_0024469 859 | ISIC_0026353 860 | ISIC_0031272 861 | ISIC_0030648 862 | ISIC_0027666 863 | ISIC_0029600 864 | ISIC_0032906 865 | ISIC_0028424 866 | ISIC_0030172 867 | ISIC_0027718 868 | ISIC_0029407 869 | ISIC_0024947 870 | ISIC_0031056 871 | ISIC_0025649 872 | ISIC_0028786 873 | ISIC_0025194 874 | ISIC_0028233 875 | ISIC_0024837 876 | ISIC_0030337 877 | ISIC_0032308 878 | ISIC_0031789 879 | ISIC_0031555 880 | ISIC_0032181 881 | ISIC_0026101 882 | ISIC_0031090 883 | ISIC_0033538 884 | ISIC_0030334 885 | ISIC_0026415 886 | ISIC_0029987 887 | ISIC_0027472 888 | ISIC_0024572 889 | ISIC_0028069 890 | ISIC_0025117 891 | ISIC_0024666 892 | ISIC_0025766 893 | ISIC_0025258 894 | ISIC_0026466 895 | ISIC_0026105 896 | ISIC_0031774 897 | ISIC_0032430 898 | ISIC_0031306 899 | ISIC_0027643 900 | ISIC_0033978 901 | ISIC_0029733 902 | ISIC_0032017 903 | ISIC_0030870 904 | ISIC_0026406 905 | ISIC_0032784 906 | ISIC_0031030 907 | ISIC_0029481 908 | ISIC_0033578 909 | ISIC_0028537 910 | ISIC_0033279 911 | ISIC_0026497 912 | ISIC_0032709 913 | ISIC_0027470 914 | ISIC_0032772 915 | ISIC_0033495 916 | ISIC_0024394 917 | ISIC_0026630 918 | ISIC_0032796 919 | ISIC_0032318 920 | ISIC_0028082 921 | ISIC_0024621 922 | ISIC_0032064 923 | ISIC_0027423 924 | ISIC_0028924 925 | ISIC_0026326 926 | ISIC_0026308 927 | ISIC_0027280 928 | ISIC_0025581 929 | ISIC_0031583 930 | ISIC_0024834 931 | ISIC_0030494 932 | ISIC_0028710 933 | ISIC_0027904 934 | ISIC_0029536 935 | ISIC_0027987 936 | ISIC_0025624 937 | ISIC_0031049 938 | ISIC_0031391 939 | ISIC_0030084 940 | ISIC_0027340 941 | ISIC_0030689 942 | ISIC_0029117 943 | ISIC_0033642 944 | ISIC_0025472 945 | ISIC_0029031 946 | ISIC_0034072 947 | ISIC_0033652 948 | ISIC_0030542 949 | ISIC_0025528 950 | ISIC_0030910 951 | ISIC_0029249 952 | ISIC_0028561 953 | ISIC_0031189 954 | ISIC_0026288 955 | ISIC_0033648 956 | ISIC_0028326 957 | ISIC_0032869 958 | ISIC_0027788 959 | ISIC_0031372 960 | ISIC_0025844 961 | ISIC_0034155 962 | ISIC_0027050 963 | ISIC_0025859 964 | ISIC_0031647 965 | ISIC_0027123 966 | ISIC_0033528 967 | ISIC_0032266 968 | ISIC_0033116 969 | ISIC_0026687 970 | ISIC_0029492 971 | ISIC_0028832 972 | ISIC_0027830 973 | ISIC_0026935 974 | ISIC_0031753 975 | ISIC_0028429 976 | ISIC_0033787 977 | ISIC_0027072 978 | ISIC_0025962 979 | ISIC_0029124 980 | ISIC_0032202 981 | ISIC_0032049 982 | ISIC_0026149 983 | ISIC_0031148 984 | ISIC_0027457 985 | ISIC_0034082 986 | ISIC_0027692 987 | ISIC_0029352 988 | ISIC_0028331 989 | ISIC_0033926 990 | ISIC_0024732 991 | ISIC_0028400 992 | ISIC_0029880 993 | ISIC_0026229 994 | ISIC_0030529 995 | ISIC_0034304 996 | ISIC_0031360 997 | ISIC_0032902 998 | ISIC_0032989 999 | ISIC_0030483 1000 | ISIC_0031204 1001 | ISIC_0026848 1002 | ISIC_0032491 1003 | ISIC_0027821 1004 | ISIC_0031337 1005 | ISIC_0034163 1006 | ISIC_0030113 1007 | ISIC_0029275 1008 | ISIC_0031325 1009 | ISIC_0030693 1010 | ISIC_0030462 1011 | ISIC_0033672 1012 | ISIC_0030854 1013 | ISIC_0031547 1014 | ISIC_0031669 1015 | ISIC_0031782 1016 | ISIC_0026147 1017 | ISIC_0029706 1018 | ISIC_0025515 1019 | ISIC_0024780 1020 | ISIC_0027617 1021 | ISIC_0032713 1022 | ISIC_0033154 1023 | ISIC_0028902 1024 | ISIC_0027541 1025 | ISIC_0031962 1026 | ISIC_0033714 1027 | ISIC_0031754 1028 | ISIC_0029366 1029 | ISIC_0028713 1030 | ISIC_0024437 1031 | ISIC_0029716 1032 | ISIC_0025647 1033 | ISIC_0026490 1034 | ISIC_0033597 1035 | ISIC_0033092 1036 | ISIC_0031034 1037 | ISIC_0033308 1038 | ISIC_0032252 1039 | ISIC_0024553 1040 | ISIC_0033468 1041 | ISIC_0024335 1042 | 
ISIC_0025425 1043 | ISIC_0030133 1044 | ISIC_0031876 1045 | ISIC_0031713 1046 | ISIC_0031879 1047 | ISIC_0029148 1048 | ISIC_0030477 1049 | ISIC_0030279 1050 | ISIC_0034280 1051 | ISIC_0034167 1052 | ISIC_0028972 1053 | ISIC_0024939 1054 | ISIC_0032544 1055 | ISIC_0032458 1056 | ISIC_0029051 1057 | ISIC_0033606 1058 | ISIC_0025174 1059 | ISIC_0029229 1060 | ISIC_0031471 1061 | ISIC_0031919 1062 | ISIC_0030425 1063 | ISIC_0028435 1064 | ISIC_0027865 1065 | ISIC_0026809 1066 | ISIC_0027873 1067 | ISIC_0028999 1068 | ISIC_0025166 1069 | ISIC_0032091 1070 | ISIC_0028411 1071 | ISIC_0031524 1072 | ISIC_0029191 1073 | ISIC_0026616 1074 | ISIC_0027065 1075 | ISIC_0031281 1076 | ISIC_0025154 1077 | ISIC_0025549 1078 | ISIC_0029019 1079 | ISIC_0029631 1080 | ISIC_0030808 1081 | ISIC_0030485 1082 | ISIC_0025486 1083 | ISIC_0027766 1084 | ISIC_0030799 1085 | ISIC_0026677 1086 | ISIC_0024308 1087 | ISIC_0031480 1088 | ISIC_0028317 1089 | ISIC_0029878 1090 | ISIC_0031265 1091 | ISIC_0025354 1092 | ISIC_0029340 1093 | ISIC_0028797 1094 | ISIC_0034038 1095 | ISIC_0033360 1096 | ISIC_0028169 1097 | ISIC_0029116 1098 | ISIC_0026857 1099 | ISIC_0034264 1100 | ISIC_0026737 1101 | ISIC_0028103 1102 | ISIC_0032490 1103 | ISIC_0024616 1104 | ISIC_0025924 1105 | ISIC_0025398 1106 | ISIC_0026246 1107 | ISIC_0030970 1108 | ISIC_0028131 1109 | ISIC_0030181 1110 | ISIC_0025535 1111 | ISIC_0031126 1112 | ISIC_0029161 1113 | ISIC_0025780 1114 | ISIC_0026039 1115 | ISIC_0026934 1116 | ISIC_0026928 1117 | ISIC_0032103 1118 | ISIC_0030765 1119 | ISIC_0028868 1120 | ISIC_0028655 1121 | ISIC_0033868 1122 | ISIC_0027140 1123 | ISIC_0027315 1124 | ISIC_0025068 1125 | ISIC_0031750 1126 | ISIC_0029416 1127 | ISIC_0028196 1128 | ISIC_0031116 1129 | ISIC_0033709 1130 | ISIC_0033747 1131 | ISIC_0027620 1132 | ISIC_0027747 1133 | ISIC_0033180 1134 | ISIC_0031138 1135 | ISIC_0029265 1136 | ISIC_0025297 1137 | ISIC_0025220 1138 | ISIC_0033504 1139 | ISIC_0032683 1140 | ISIC_0033420 1141 | ISIC_0027960 1142 | ISIC_0034096 1143 | ISIC_0031170 1144 | ISIC_0031724 1145 | ISIC_0026790 1146 | ISIC_0027454 1147 | ISIC_0025046 1148 | ISIC_0024337 1149 | ISIC_0025277 1150 | ISIC_0030145 1151 | ISIC_0028268 1152 | ISIC_0029866 1153 | ISIC_0025422 1154 | ISIC_0030111 1155 | ISIC_0024462 1156 | ISIC_0031913 1157 | ISIC_0026773 1158 | ISIC_0025137 1159 | ISIC_0029760 1160 | ISIC_0033333 1161 | ISIC_0027131 1162 | ISIC_0033000 1163 | ISIC_0033976 1164 | ISIC_0030766 1165 | ISIC_0031000 1166 | ISIC_0033421 1167 | ISIC_0031674 1168 | ISIC_0026350 1169 | ISIC_0026286 1170 | ISIC_0032747 1171 | ISIC_0026989 1172 | ISIC_0025652 1173 | ISIC_0026383 1174 | ISIC_0030993 1175 | ISIC_0027959 1176 | ISIC_0024661 1177 | ISIC_0025050 1178 | ISIC_0031499 1179 | ISIC_0034277 1180 | ISIC_0033342 1181 | ISIC_0025029 1182 | ISIC_0026282 1183 | ISIC_0027261 1184 | ISIC_0028646 1185 | ISIC_0026520 1186 | ISIC_0029700 1187 | ISIC_0034088 1188 | ISIC_0028176 1189 | ISIC_0030148 1190 | ISIC_0025564 1191 | ISIC_0032583 1192 | ISIC_0025052 1193 | ISIC_0029582 1194 | ISIC_0026173 1195 | ISIC_0030492 1196 | ISIC_0033460 1197 | ISIC_0024571 1198 | ISIC_0034084 1199 | ISIC_0026734 1200 | ISIC_0027445 1201 | ISIC_0032004 1202 | ISIC_0028839 1203 | ISIC_0027368 1204 | ISIC_0034177 1205 | ISIC_0024519 1206 | ISIC_0032132 1207 | ISIC_0029529 1208 | ISIC_0028253 1209 | ISIC_0026364 1210 | ISIC_0032392 1211 | ISIC_0030757 1212 | ISIC_0028017 1213 | ISIC_0027210 1214 | ISIC_0029228 1215 | ISIC_0028144 1216 | ISIC_0031688 1217 | ISIC_0032612 1218 | ISIC_0027442 1219 | ISIC_0024590 
1220 | ISIC_0024401 1221 | ISIC_0026186 1222 | ISIC_0031333 1223 | ISIC_0032156 1224 | ISIC_0032165 1225 | ISIC_0026586 1226 | ISIC_0029842 1227 | ISIC_0026827 1228 | ISIC_0027919 1229 | ISIC_0031639 1230 | ISIC_0033917 1231 | ISIC_0026309 1232 | ISIC_0029652 1233 | ISIC_0028238 1234 | ISIC_0028884 1235 | ISIC_0032541 1236 | ISIC_0030223 1237 | ISIC_0028722 1238 | ISIC_0030659 1239 | ISIC_0027523 1240 | ISIC_0033876 1241 | ISIC_0028639 1242 | ISIC_0025309 1243 | ISIC_0031159 1244 | ISIC_0033677 1245 | ISIC_0030907 1246 | ISIC_0029601 1247 | ISIC_0032798 1248 | ISIC_0026813 1249 | ISIC_0033440 1250 | ISIC_0025490 1251 | ISIC_0025565 1252 | ISIC_0026828 1253 | ISIC_0031505 1254 | ISIC_0033634 1255 | ISIC_0028865 1256 | ISIC_0028407 1257 | ISIC_0033339 1258 | ISIC_0032824 1259 | ISIC_0026742 1260 | ISIC_0033665 1261 | ISIC_0033946 1262 | ISIC_0029294 1263 | ISIC_0031603 1264 | ISIC_0029096 1265 | ISIC_0028789 1266 | ISIC_0027409 1267 | ISIC_0031500 1268 | ISIC_0024909 1269 | ISIC_0028724 1270 | ISIC_0030269 1271 | ISIC_0033363 1272 | ISIC_0028348 1273 | ISIC_0033895 1274 | ISIC_0029867 1275 | ISIC_0030965 1276 | ISIC_0027673 1277 | ISIC_0025586 1278 | ISIC_0032034 1279 | ISIC_0033328 1280 | ISIC_0027455 1281 | ISIC_0028105 1282 | ISIC_0030522 1283 | ISIC_0030707 1284 | ISIC_0030664 1285 | ISIC_0024634 1286 | ISIC_0026010 1287 | ISIC_0026799 1288 | ISIC_0031379 1289 | ISIC_0032977 1290 | ISIC_0027935 1291 | ISIC_0033873 1292 | ISIC_0028316 1293 | ISIC_0034061 1294 | ISIC_0032615 1295 | ISIC_0030373 1296 | ISIC_0033377 1297 | ISIC_0033527 1298 | ISIC_0032964 1299 | ISIC_0030667 1300 | ISIC_0026505 1301 | ISIC_0032316 1302 | ISIC_0030895 1303 | ISIC_0027389 1304 | ISIC_0030371 1305 | ISIC_0029758 1306 | ISIC_0026207 1307 | ISIC_0033274 1308 | ISIC_0028970 1309 | ISIC_0026092 1310 | ISIC_0033785 1311 | ISIC_0033660 1312 | ISIC_0029494 1313 | ISIC_0029659 1314 | ISIC_0030400 1315 | ISIC_0024390 1316 | ISIC_0027793 1317 | ISIC_0027095 1318 | ISIC_0031661 1319 | ISIC_0024817 1320 | ISIC_0034233 1321 | ISIC_0030838 1322 | ISIC_0025037 1323 | ISIC_0029088 1324 | ISIC_0029167 1325 | ISIC_0028776 1326 | ISIC_0033789 1327 | ISIC_0027063 1328 | ISIC_0027165 1329 | ISIC_0033948 1330 | ISIC_0027428 1331 | ISIC_0027490 1332 | ISIC_0026361 1333 | ISIC_0031120 1334 | ISIC_0028640 1335 | ISIC_0024433 1336 | ISIC_0028756 1337 | ISIC_0024470 1338 | ISIC_0034054 1339 | ISIC_0031767 1340 | ISIC_0030721 1341 | ISIC_0027505 1342 | ISIC_0030292 1343 | ISIC_0025625 1344 | ISIC_0029893 1345 | ISIC_0031948 1346 | ISIC_0024778 1347 | ISIC_0026303 1348 | ISIC_0024324 1349 | ISIC_0029464 1350 | ISIC_0026443 1351 | ISIC_0025331 1352 | ISIC_0032595 1353 | ISIC_0030241 1354 | ISIC_0027420 1355 | ISIC_0032050 1356 | ISIC_0030416 1357 | ISIC_0027467 1358 | ISIC_0030246 1359 | ISIC_0034020 1360 | ISIC_0032194 1361 | ISIC_0024755 1362 | ISIC_0033048 1363 | ISIC_0034060 1364 | ISIC_0028678 1365 | ISIC_0030668 1366 | ISIC_0028050 1367 | ISIC_0024484 1368 | ISIC_0028835 1369 | ISIC_0025646 1370 | ISIC_0027854 1371 | ISIC_0025897 1372 | ISIC_0034030 1373 | ISIC_0031859 1374 | ISIC_0033239 1375 | ISIC_0026075 1376 | ISIC_0027902 1377 | ISIC_0024789 1378 | ISIC_0028273 1379 | ISIC_0031300 1380 | ISIC_0034267 1381 | ISIC_0026791 1382 | ISIC_0025086 1383 | ISIC_0028950 1384 | ISIC_0031303 1385 | ISIC_0031535 1386 | ISIC_0030365 1387 | ISIC_0024383 1388 | ISIC_0030995 1389 | ISIC_0031353 1390 | ISIC_0033916 1391 | ISIC_0032054 1392 | ISIC_0033171 1393 | ISIC_0026509 1394 | ISIC_0032956 1395 | ISIC_0030391 1396 | ISIC_0030312 1397 | 
ISIC_0025058 1398 | ISIC_0032619 1399 | ISIC_0027943 1400 | ISIC_0025005 1401 | ISIC_0032448 1402 | ISIC_0033741 1403 | ISIC_0031248 1404 | ISIC_0031490 1405 | ISIC_0025798 1406 | ISIC_0030634 1407 | ISIC_0026392 1408 | ISIC_0027671 1409 | ISIC_0032529 1410 | ISIC_0031513 1411 | ISIC_0026460 1412 | ISIC_0024375 1413 | ISIC_0025755 1414 | ISIC_0027481 1415 | ISIC_0028183 1416 | ISIC_0028265 1417 | ISIC_0027527 1418 | ISIC_0025837 1419 | ISIC_0028807 1420 | ISIC_0031984 1421 | ISIC_0031514 1422 | ISIC_0024810 1423 | ISIC_0031259 1424 | ISIC_0032837 1425 | ISIC_0027886 1426 | ISIC_0031395 1427 | ISIC_0030635 1428 | ISIC_0025104 1429 | ISIC_0029146 1430 | ISIC_0029380 1431 | ISIC_0034186 1432 | ISIC_0028101 1433 | ISIC_0024367 1434 | ISIC_0026081 1435 | ISIC_0030004 1436 | ISIC_0034255 1437 | ISIC_0026766 1438 | ISIC_0029813 1439 | ISIC_0034056 1440 | ISIC_0030841 1441 | ISIC_0025873 1442 | ISIC_0029048 1443 | ISIC_0028762 1444 | ISIC_0027154 1445 | ISIC_0027292 1446 | ISIC_0028760 1447 | ISIC_0031519 1448 | ISIC_0025874 1449 | ISIC_0033767 1450 | ISIC_0030518 1451 | ISIC_0031792 1452 | ISIC_0029325 1453 | ISIC_0028679 1454 | ISIC_0024687 1455 | ISIC_0030594 1456 | ISIC_0026718 1457 | ISIC_0028218 1458 | ISIC_0029426 1459 | ISIC_0025234 1460 | ISIC_0028468 1461 | ISIC_0033779 1462 | ISIC_0030478 1463 | ISIC_0032100 1464 | ISIC_0024551 1465 | ISIC_0025745 1466 | ISIC_0033520 1467 | ISIC_0032243 1468 | ISIC_0032152 1469 | ISIC_0032909 1470 | ISIC_0033560 1471 | ISIC_0028518 1472 | ISIC_0032767 1473 | ISIC_0029427 1474 | ISIC_0028649 1475 | ISIC_0027867 1476 | ISIC_0033163 1477 | ISIC_0033230 1478 | ISIC_0024420 1479 | ISIC_0026608 1480 | ISIC_0032616 1481 | ISIC_0026486 1482 | ISIC_0032660 1483 | ISIC_0032761 1484 | ISIC_0029715 1485 | ISIC_0025081 1486 | ISIC_0029177 1487 | ISIC_0031666 1488 | ISIC_0028982 1489 | ISIC_0033149 1490 | ISIC_0033920 1491 | ISIC_0033668 1492 | ISIC_0031592 1493 | ISIC_0028066 1494 | ISIC_0034269 1495 | ISIC_0030983 1496 | ISIC_0031312 1497 | ISIC_0028180 1498 | ISIC_0033544 1499 | ISIC_0025674 1500 | ISIC_0033395 1501 | ISIC_0033426 1502 | ISIC_0032512 1503 | ISIC_0033465 1504 | ISIC_0032361 1505 | ISIC_0025412 1506 | ISIC_0027731 1507 | ISIC_0028642 1508 | ISIC_0028274 1509 | ISIC_0025560 1510 | ISIC_0026954 1511 | ISIC_0027229 1512 | ISIC_0029086 1513 | ISIC_0025945 1514 | ISIC_0025439 1515 | ISIC_0027825 1516 | ISIC_0030821 1517 | ISIC_0032031 1518 | ISIC_0024577 1519 | ISIC_0029655 1520 | ISIC_0025853 1521 | ISIC_0028616 1522 | ISIC_0025324 1523 | ISIC_0028746 1524 | ISIC_0033336 1525 | ISIC_0030017 1526 | ISIC_0029197 1527 | ISIC_0032197 1528 | ISIC_0032910 1529 | ISIC_0028744 1530 | ISIC_0032875 1531 | ISIC_0025797 1532 | ISIC_0033236 1533 | ISIC_0027921 1534 | ISIC_0024358 1535 | ISIC_0029711 1536 | ISIC_0029400 1537 | ISIC_0027400 1538 | ISIC_0025695 1539 | ISIC_0032190 1540 | ISIC_0027784 1541 | ISIC_0028497 1542 | ISIC_0029901 1543 | ISIC_0027528 1544 | ISIC_0025938 1545 | ISIC_0030000 1546 | ISIC_0032037 1547 | ISIC_0024773 1548 | ISIC_0034091 1549 | ISIC_0033401 1550 | ISIC_0024377 1551 | ISIC_0033942 1552 | ISIC_0029358 1553 | ISIC_0025513 1554 | ISIC_0025896 1555 | ISIC_0029465 1556 | ISIC_0029930 1557 | ISIC_0024964 1558 | ISIC_0032207 1559 | ISIC_0032561 1560 | ISIC_0030844 1561 | ISIC_0029478 1562 | ISIC_0027342 1563 | ISIC_0032952 1564 | ISIC_0029785 1565 | ISIC_0030784 1566 | ISIC_0024708 1567 | ISIC_0029391 1568 | ISIC_0029847 1569 | ISIC_0025778 1570 | ISIC_0033175 1571 | ISIC_0027662 1572 | ISIC_0033749 1573 | ISIC_0026289 1574 | ISIC_0029224 
1575 | ISIC_0029365 1576 | ISIC_0028149 1577 | ISIC_0025339 1578 | ISIC_0024480 1579 | ISIC_0033437 1580 | ISIC_0027408 1581 | ISIC_0032019 1582 | ISIC_0032832 1583 | ISIC_0033371 1584 | ISIC_0027159 1585 | ISIC_0025792 1586 | ISIC_0030457 1587 | ISIC_0028599 1588 | ISIC_0032794 1589 | ISIC_0031866 1590 | ISIC_0033943 1591 | ISIC_0034107 1592 | ISIC_0029932 1593 | ISIC_0028151 1594 | ISIC_0030147 1595 | ISIC_0032876 1596 | ISIC_0027416 1597 | ISIC_0028224 1598 | ISIC_0028444 1599 | ISIC_0026478 1600 | ISIC_0026905 1601 | ISIC_0032319 1602 | ISIC_0033413 1603 | ISIC_0030273 1604 | ISIC_0034103 1605 | ISIC_0030173 1606 | ISIC_0028126 1607 | ISIC_0033454 1608 | ISIC_0026193 1609 | ISIC_0032642 1610 | ISIC_0028023 1611 | ISIC_0025488 1612 | ISIC_0034307 1613 | ISIC_0029310 1614 | ISIC_0028908 1615 | ISIC_0033921 1616 | ISIC_0034295 1617 | ISIC_0026028 1618 | ISIC_0027190 1619 | ISIC_0027300 1620 | ISIC_0029069 1621 | ISIC_0033384 1622 | ISIC_0028362 1623 | ISIC_0029686 1624 | ISIC_0032831 1625 | ISIC_0032651 1626 | ISIC_0029822 1627 | ISIC_0030951 1628 | ISIC_0030242 1629 | ISIC_0032186 1630 | ISIC_0025027 1631 | ISIC_0024679 1632 | ISIC_0032492 1633 | ISIC_0032329 1634 | ISIC_0031040 1635 | ISIC_0024618 1636 | ISIC_0032422 1637 | ISIC_0024348 1638 | ISIC_0031339 1639 | ISIC_0031351 1640 | ISIC_0031502 1641 | ISIC_0032915 1642 | ISIC_0029059 1643 | ISIC_0027097 1644 | ISIC_0033202 1645 | ISIC_0025038 1646 | ISIC_0030900 1647 | ISIC_0025250 1648 | ISIC_0029562 1649 | ISIC_0033516 1650 | ISIC_0032493 1651 | ISIC_0029871 1652 | ISIC_0033405 1653 | ISIC_0025999 1654 | ISIC_0024825 1655 | ISIC_0025120 1656 | ISIC_0025124 1657 | ISIC_0029136 1658 | ISIC_0034015 1659 | ISIC_0031364 1660 | ISIC_0029523 1661 | ISIC_0028040 1662 | ISIC_0030578 1663 | ISIC_0030316 1664 | ISIC_0031139 1665 | ISIC_0032096 1666 | ISIC_0026898 1667 | ISIC_0024685 1668 | ISIC_0026557 1669 | ISIC_0028745 1670 | ISIC_0026454 1671 | ISIC_0028032 1672 | ISIC_0026872 1673 | ISIC_0030344 1674 | ISIC_0026109 1675 | ISIC_0034159 1676 | ISIC_0031451 1677 | ISIC_0027927 1678 | ISIC_0030414 1679 | ISIC_0025900 1680 | ISIC_0031367 1681 | ISIC_0029196 1682 | ISIC_0031916 1683 | ISIC_0033213 1684 | ISIC_0026865 1685 | ISIC_0028514 1686 | ISIC_0029401 1687 | ISIC_0031834 1688 | ISIC_0025888 1689 | ISIC_0026853 1690 | ISIC_0032817 1691 | ISIC_0031384 1692 | ISIC_0028627 1693 | ISIC_0029444 1694 | ISIC_0031190 1695 | ISIC_0030847 1696 | ISIC_0026069 1697 | ISIC_0030990 1698 | ISIC_0030778 1699 | ISIC_0027149 1700 | ISIC_0029250 1701 | ISIC_0032711 1702 | ISIC_0028739 1703 | ISIC_0033301 1704 | ISIC_0032427 1705 | ISIC_0030662 1706 | ISIC_0029584 1707 | ISIC_0031586 1708 | ISIC_0028438 1709 | ISIC_0024411 1710 | ISIC_0029252 1711 | ISIC_0027112 1712 | ISIC_0030815 1713 | ISIC_0033963 1714 | ISIC_0031131 1715 | ISIC_0027259 1716 | ISIC_0033063 1717 | ISIC_0031825 1718 | ISIC_0028651 1719 | ISIC_0032108 1720 | ISIC_0033972 1721 | ISIC_0027542 1722 | ISIC_0027135 1723 | ISIC_0026808 1724 | ISIC_0024623 1725 | ISIC_0025506 1726 | ISIC_0032225 1727 | ISIC_0026584 1728 | ISIC_0031769 1729 | ISIC_0033150 1730 | ISIC_0032568 1731 | ISIC_0032294 1732 | ISIC_0026964 1733 | ISIC_0025480 1734 | ISIC_0025394 1735 | ISIC_0033040 1736 | ISIC_0032326 1737 | ISIC_0032908 1738 | ISIC_0028370 1739 | ISIC_0034217 1740 | ISIC_0031601 1741 | ISIC_0026397 1742 | ISIC_0028941 1743 | ISIC_0024907 1744 | ISIC_0029734 1745 | ISIC_0032059 1746 | ISIC_0026525 1747 | ISIC_0030164 1748 | ISIC_0033335 1749 | ISIC_0024930 1750 | ISIC_0033304 1751 | ISIC_0029284 1752 | 
ISIC_0029886 1753 | ISIC_0027421 1754 | ISIC_0033135 1755 | ISIC_0032921 1756 | ISIC_0029129 1757 | ISIC_0031417 1758 | ISIC_0026302 1759 | ISIC_0032396 1760 | ISIC_0028778 1761 | ISIC_0034230 1762 | ISIC_0026792 1763 | ISIC_0032648 1764 | ISIC_0025886 1765 | ISIC_0030351 1766 | ISIC_0030176 1767 | ISIC_0033124 1768 | ISIC_0031431 1769 | ISIC_0027308 1770 | ISIC_0028907 1771 | ISIC_0024449 1772 | ISIC_0029119 1773 | ISIC_0031033 1774 | ISIC_0029557 1775 | ISIC_0027758 1776 | ISIC_0027911 1777 | ISIC_0032177 1778 | ISIC_0028726 1779 | ISIC_0026477 1780 | ISIC_0029281 1781 | ISIC_0029477 1782 | ISIC_0029026 1783 | ISIC_0027138 1784 | ISIC_0025833 1785 | ISIC_0031730 1786 | ISIC_0033115 1787 | ISIC_0033822 1788 | ISIC_0030384 1789 | ISIC_0024996 1790 | ISIC_0026594 1791 | ISIC_0027608 1792 | ISIC_0030091 1793 | ISIC_0033848 1794 | ISIC_0030908 1795 | ISIC_0030088 1796 | ISIC_0030094 1797 | ISIC_0028867 1798 | ISIC_0025493 1799 | ISIC_0025497 1800 | ISIC_0032643 1801 | ISIC_0030656 1802 | ISIC_0027878 1803 | ISIC_0031282 1804 | ISIC_0026140 1805 | ISIC_0033609 1806 | ISIC_0031705 1807 | ISIC_0030032 1808 | ISIC_0026805 1809 | ISIC_0030341 1810 | ISIC_0028502 1811 | ISIC_0028133 1812 | ISIC_0025184 1813 | ISIC_0025761 1814 | ISIC_0032889 1815 | ISIC_0029677 1816 | ISIC_0032261 1817 | ISIC_0032973 1818 | ISIC_0026000 1819 | ISIC_0033793 1820 | ISIC_0026398 1821 | ISIC_0031238 1822 | ISIC_0033635 1823 | ISIC_0029637 1824 | ISIC_0032094 1825 | ISIC_0025138 1826 | ISIC_0031643 1827 | ISIC_0024897 1828 | ISIC_0030022 1829 | ISIC_0030037 1830 | ISIC_0024415 1831 | ISIC_0031804 1832 | ISIC_0025358 1833 | ISIC_0025617 1834 | ISIC_0024974 1835 | ISIC_0032751 1836 | ISIC_0031762 1837 | ISIC_0029474 1838 | ISIC_0029894 1839 | ISIC_0025149 1840 | ISIC_0028815 1841 | ISIC_0033382 1842 | ISIC_0028853 1843 | ISIC_0029660 1844 | ISIC_0029058 1845 | ISIC_0029983 1846 | ISIC_0029888 1847 | ISIC_0025791 1848 | ISIC_0029729 1849 | ISIC_0027043 1850 | ISIC_0026573 1851 | ISIC_0030471 1852 | ISIC_0027114 1853 | ISIC_0025175 1854 | ISIC_0026774 1855 | ISIC_0031799 1856 | ISIC_0032948 1857 | ISIC_0029962 1858 | ISIC_0026285 1859 | ISIC_0025002 1860 | ISIC_0026047 1861 | ISIC_0028024 1862 | ISIC_0033309 1863 | ISIC_0032102 1864 | ISIC_0026495 1865 | ISIC_0029312 1866 | ISIC_0029923 1867 | ISIC_0029451 1868 | ISIC_0032453 1869 | ISIC_0024811 1870 | ISIC_0034193 1871 | ISIC_0027609 1872 | ISIC_0032903 1873 | ISIC_0032203 1874 | ISIC_0027946 1875 | ISIC_0030767 1876 | ISIC_0024489 1877 | ISIC_0031570 1878 | ISIC_0027388 1879 | ISIC_0026182 1880 | ISIC_0030119 1881 | ISIC_0024915 1882 | ISIC_0026339 1883 | ISIC_0028287 1884 | ISIC_0030935 1885 | ISIC_0030728 1886 | ISIC_0026274 1887 | ISIC_0031080 1888 | ISIC_0024326 1889 | ISIC_0031489 1890 | ISIC_0032825 1891 | ISIC_0027576 1892 | ISIC_0026824 1893 | ISIC_0030290 1894 | ISIC_0032168 1895 | ISIC_0026484 1896 | ISIC_0029285 1897 | ISIC_0027716 1898 | ISIC_0034219 1899 | ISIC_0024795 1900 | ISIC_0032046 1901 | ISIC_0033104 1902 | ISIC_0027518 1903 | ISIC_0024458 1904 | ISIC_0025195 1905 | ISIC_0030729 1906 | ISIC_0024575 1907 | ISIC_0024402 1908 | ISIC_0031982 1909 | ISIC_0031019 1910 | ISIC_0028191 1911 | ISIC_0034070 1912 | ISIC_0032016 1913 | ISIC_0025373 1914 | ISIC_0031806 1915 | ISIC_0029501 1916 | ISIC_0027066 1917 | ISIC_0030527 1918 | ISIC_0030713 1919 | ISIC_0025021 1920 | ISIC_0028365 1921 | ISIC_0034279 1922 | ISIC_0028871 1923 | ISIC_0026104 1924 | ISIC_0024767 1925 | ISIC_0029679 1926 | ISIC_0027250 1927 | ISIC_0033843 1928 | ISIC_0029097 1929 | ISIC_0029577 
1930 | ISIC_0030154 1931 | ISIC_0029330 1932 | ISIC_0031776 1933 | ISIC_0025520 1934 | ISIC_0025678 1935 | ISIC_0024549 1936 | ISIC_0028345 1937 | ISIC_0032267 1938 | ISIC_0031561 1939 | ISIC_0025985 1940 | ISIC_0031892 1941 | ISIC_0033810 1942 | ISIC_0024954 1943 | ISIC_0028715 1944 | ISIC_0024461 1945 | ISIC_0027734 1946 | ISIC_0033801 1947 | ISIC_0025456 1948 | ISIC_0027353 1949 | ISIC_0025352 1950 | ISIC_0034138 1951 | ISIC_0025612 1952 | ISIC_0032212 1953 | ISIC_0026550 1954 | ISIC_0024321 1955 | ISIC_0026975 1956 | ISIC_0033858 1957 | ISIC_0026150 1958 | ISIC_0027404 1959 | ISIC_0025918 1960 | ISIC_0032995 1961 | ISIC_0031414 1962 | ISIC_0025643 1963 | ISIC_0025784 1964 | ISIC_0033593 1965 | ISIC_0024848 1966 | ISIC_0028380 1967 | ISIC_0024737 1968 | ISIC_0031077 1969 | ISIC_0033226 1970 | ISIC_0026867 1971 | ISIC_0033835 1972 | ISIC_0027912 1973 | ISIC_0027183 1974 | ISIC_0028721 1975 | ISIC_0025337 1976 | ISIC_0030427 1977 | ISIC_0030591 1978 | ISIC_0025416 1979 | ISIC_0026194 1980 | ISIC_0034025 1981 | ISIC_0025059 1982 | ISIC_0027355 1983 | ISIC_0032457 1984 | ISIC_0024868 1985 | ISIC_0030848 1986 | ISIC_0032399 1987 | ISIC_0029070 1988 | ISIC_0027868 1989 | ISIC_0032750 1990 | ISIC_0026822 1991 | ISIC_0027268 1992 | ISIC_0029680 1993 | ISIC_0029702 1994 | ISIC_0033700 1995 | ISIC_0027768 1996 | ISIC_0032057 1997 | ISIC_0028531 1998 | ISIC_0024450 1999 | ISIC_0029362 2000 | ISIC_0027201 2001 | ISIC_0032119 2002 | ISIC_0025211 2003 | ISIC_0029498 2004 | ISIC_0024541 2005 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | 4 | def parse_transform(transform: str, image_size=224, **transform_kwargs): 5 | if transform == 'RandomColorJitter': 6 | return transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.0)], p=1.0) 7 | elif transform == 'RandomGrayscale': 8 | return transforms.RandomGrayscale(p=0.1) 9 | elif transform == 'RandomGaussianBlur': 10 | return transforms.RandomApply([transforms.GaussianBlur(kernel_size=(5, 5))], p=0.3) 11 | elif transform == 'RandomCrop': 12 | return transforms.RandomCrop(image_size) 13 | elif transform == 'RandomResizedCrop': 14 | return transforms.RandomResizedCrop(image_size) 15 | elif transform == 'CenterCrop': 16 | return transforms.CenterCrop(image_size) 17 | elif transform == 'Resize_up': 18 | return transforms.Resize( 19 | [int(image_size * 1.15), 20 | int(image_size * 1.15)]) 21 | elif transform == 'Normalize': 22 | return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 23 | elif transform == 'Resize': 24 | return transforms.Resize( 25 | [int(image_size), 26 | int(image_size)]) 27 | elif transform == 'RandomRotation': 28 | return transforms.RandomRotation(degrees=10) 29 | else: 30 | method = getattr(transforms, transform) 31 | return method(**transform_kwargs) 32 | 33 | 34 | def get_composed_transform(augmentation: str = None, image_size=224) -> transforms.Compose: 35 | if augmentation == 'base': 36 | transform_list = ['RandomResizedCrop', 'RandomColorJitter', 'RandomHorizontalFlip', 'ToTensor', 37 | 'Normalize'] 38 | elif augmentation == 'strong': 39 | transform_list = ['RandomResizedCrop', 'RandomColorJitter', 'RandomGrayscale', 'RandomGaussianBlur', 40 | 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 41 | elif augmentation is None or augmentation.lower() == 'none': 42 | transform_list = ['Resize', 'ToTensor', 'Normalize'] 
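    # 'base' is the standard supervised-training augmentation preset; 'strong' additionally applies
    # SimCLR-style grayscale and Gaussian blur; None/'none' is the deterministic evaluation pipeline
    # (square resize only), e.g. get_composed_transform('strong', image_size=224).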
43 |     else:
44 |         raise ValueError('Unsupported augmentation: {}'.format(augmentation))
45 | 
46 |     transform_funcs = [parse_transform(x, image_size=image_size) for x in transform_list]
47 |     transform = transforms.Compose(transform_funcs)
48 |     return transform
49 | -------------------------------------------------------------------------------- /finetune.py: --------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import math
4 | import os
5 | 
6 | import pandas as pd
7 | import torch.nn as nn
8 | 
9 | from backbone import get_backbone_class
10 | from datasets.dataloader import get_labeled_episodic_dataloader
11 | from io_utils import parse_args
12 | from model import get_model_class
13 | from model.classifier_head import get_classifier_head_class
14 | from paths import get_output_directory, get_ft_output_directory, get_ft_train_history_path, get_ft_test_history_path, \
15 |     get_final_pretrain_state_path, get_pretrain_state_path, get_ft_params_path
16 | from utils import *
17 | 
18 | 
19 | def main(params):
20 |     base_output_dir = get_output_directory(params)
21 |     output_dir = get_ft_output_directory(params)
22 |     print('Running fine-tune with output folder:')
23 |     print(output_dir)
24 |     print()
25 | 
26 |     # Settings
27 |     n_episodes = 600
28 |     bs = params.ft_batch_size
29 |     w = params.n_way
30 |     s = params.n_shot
31 |     q = params.n_query_shot
32 |     # Whether to optimize for fixed features (when there is no augmentation and only the head is updated)
33 |     use_fixed_features = params.ft_augmentation is None and params.ft_parts == 'head'
34 | 
35 |     # Model
36 |     backbone = get_backbone_class(params.backbone)()
37 |     body = get_model_class(params.model)(backbone, params)
38 |     if params.ft_features is not None:
39 |         if params.ft_features not in body.supported_feature_selectors:
40 |             raise ValueError(
41 |                 'Feature selector "{}" is not supported for model "{}"'.format(params.ft_features, params.model))
42 | 
43 |     # Dataloaders
44 |     # Note that both dataloaders sample identical episodes, via episode_seed
45 |     support_epochs = 1 if use_fixed_features else params.ft_epochs
46 |     support_loader = get_labeled_episodic_dataloader(params.target_dataset, n_way=w, n_shot=s, support=True,
47 |                                                      n_query_shot=q, n_episodes=n_episodes, n_epochs=support_epochs,
48 |                                                      augmentation=params.ft_augmentation,
49 |                                                      unlabeled_ratio=params.unlabeled_ratio,
50 |                                                      num_workers=params.num_workers,
51 |                                                      split_seed=params.split_seed, episode_seed=params.ft_episode_seed)
52 |     query_loader = get_labeled_episodic_dataloader(params.target_dataset, n_way=w, n_shot=s, support=False,
53 |                                                    n_query_shot=q, n_episodes=n_episodes, augmentation=None,
54 |                                                    unlabeled_ratio=params.unlabeled_ratio,
55 |                                                    num_workers=params.num_workers,
56 |                                                    split_seed=params.split_seed,
57 |                                                    episode_seed=params.ft_episode_seed)
58 |     assert (len(query_loader) == n_episodes)
59 |     assert (len(support_loader) == n_episodes * support_epochs)
60 | 
61 |     query_iterator = iter(query_loader)
62 |     support_iterator = iter(support_loader)
63 |     support_batches = math.ceil(w * s / bs)
64 | 
65 |     # Output (history, params)
66 |     train_history_path = get_ft_train_history_path(output_dir)
67 |     test_history_path = get_ft_test_history_path(output_dir)
68 |     params_path = get_ft_params_path(output_dir)
69 |     print('Saving finetune params to {}'.format(params_path))
70 |     print('Saving finetune train history to {}'.format(train_history_path))
71 |     print('Saving finetune test history to {}'.format(test_history_path))
72 |     with open(params_path,
'w') as f: 73 | json.dump(vars(params), f, indent=4) 74 | df_train = pd.DataFrame(None, index=list(range(1, n_episodes + 1)), 75 | columns=['epoch{}'.format(e + 1) for e in range(params.ft_epochs)]) 76 | df_test = pd.DataFrame(None, index=list(range(1, n_episodes + 1)), 77 | columns=['epoch{}'.format(e + 1) for e in range(params.ft_epochs)]) 78 | 79 | # Pre-train state 80 | if params.ft_pretrain_epoch is None: 81 | body_state_path = get_final_pretrain_state_path(base_output_dir) 82 | else: 83 | body_state_path = get_pretrain_state_path(base_output_dir, params.ft_pretrain_epoch) 84 | if not os.path.exists(body_state_path): 85 | raise ValueError('Invalid pre-train state path: ' + body_state_path) 86 | print('Using pre-train state:') 87 | print(body_state_path) 88 | print() 89 | state = torch.load(body_state_path) 90 | 91 | # HOTFIX 92 | # print("HOTFIX: removing classifier weights from state") 93 | # del state["classifier.weight"] 94 | # del state["classifier.bias"] 95 | 96 | # Loss function 97 | loss_fn = nn.CrossEntropyLoss().cuda() 98 | 99 | print('Starting fine-tune') 100 | if use_fixed_features: 101 | print('Running optimized fixed-feature fine-tuning (no augmentation, fixed body)') 102 | print() 103 | 104 | for episode in range(n_episodes): 105 | # Reset models for each episode 106 | # classifier.bias issue 107 | 108 | # HOTFIX: load state dict non-strict to ignore classifier weights 109 | # body.load_state_dict(copy.deepcopy(state), strict=False) # note, override model.load_state_dict to change this behavior. 110 | body.load_state_dict(copy.deepcopy(state)) # note, override model.load_state_dict to change this behavior. 111 | head = get_classifier_head_class(params.ft_head)(body.final_feat_dim, params.n_way, params) # TODO: apply ft_features 112 | body.cuda() 113 | head.cuda() 114 | 115 | opt_params = [] 116 | if params.ft_train_head: 117 | opt_params.append({'params': head.parameters()}) 118 | if params.ft_train_body: 119 | opt_params.append({'params': body.parameters()}) 120 | optimizer = torch.optim.SGD(opt_params, lr=params.ft_lr, momentum=0.9, dampening=0.9, weight_decay=0.001) 121 | 122 | # Labels are always [0, 0, ..., 1, ..., w-1] 123 | x_support = None 124 | f_support = None 125 | y_support = torch.arange(w).repeat_interleave(s).cuda() 126 | x_query = next(query_iterator)[0].cuda() 127 | 128 | f_query = None 129 | y_query = torch.arange(w).repeat_interleave(q).cuda() 130 | 131 | if use_fixed_features: # load data and extract features once per episode 132 | with torch.no_grad(): 133 | x_support, _ = next(support_iterator) 134 | x_support = x_support.cuda() 135 | f_support = body.forward_features(x_support, params.ft_features) 136 | f_query = body.forward_features(x_query, params.ft_features) 137 | 138 | train_acc_history = [] 139 | test_acc_history = [] 140 | for epoch in range(params.ft_epochs): 141 | # Train 142 | body.train() 143 | head.train() 144 | 145 | if not use_fixed_features: # load data every epoch 146 | x_support, _ = next(support_iterator) 147 | x_support = x_support.cuda() 148 | 149 | total_loss = 0 150 | correct = 0 151 | indices = np.random.permutation(w * s) 152 | for i in range(support_batches): 153 | start_index = i * bs 154 | end_index = min(i * bs + bs, w * s) 155 | batch_indices = indices[start_index:end_index] 156 | y = y_support[batch_indices] 157 | 158 | if use_fixed_features: 159 | f = f_support[batch_indices] 160 | else: 161 | f = body.forward_features(x_support[batch_indices], params.ft_features) 162 | p = head(f) 163 | 164 | correct += torch.eq(y, 
p.argmax(dim=1)).sum() 165 | loss = loss_fn(p, y) 166 | optimizer.zero_grad() 167 | loss.backward() 168 | optimizer.step() 169 | 170 | total_loss += loss.item() 171 | 172 | train_loss = total_loss / support_batches 173 | train_acc = correct / (w * s) 174 | 175 | # Evaluation 176 | body.eval() 177 | head.eval() 178 | 179 | if params.ft_intermediate_test or epoch == params.ft_epochs - 1: 180 | with torch.no_grad(): 181 | if not use_fixed_features: 182 | f_query = body.forward_features(x_query, params.ft_features) 183 | p_query = head(f_query) 184 | test_acc = torch.eq(y_query, p_query.argmax(dim=1)).sum() / (w * q) 185 | else: 186 | test_acc = torch.tensor(0) 187 | 188 | print_epoch_logs = False 189 | if print_epoch_logs and (epoch + 1) % 10 == 0: 190 | fmt = 'Epoch {:03d}: Loss={:6.3f} Train ACC={:6.3f} Test ACC={:6.3f}' 191 | print(fmt.format(epoch + 1, train_loss, train_acc, test_acc)) 192 | 193 | train_acc_history.append(train_acc.item()) 194 | test_acc_history.append(test_acc.item()) 195 | 196 | df_train.loc[episode + 1] = train_acc_history 197 | df_train.to_csv(train_history_path) 198 | df_test.loc[episode + 1] = test_acc_history 199 | df_test.to_csv(test_history_path) 200 | 201 | fmt = 'Episode {:03d}: train_loss={:6.4f} train_acc={:6.2f} test_acc={:6.2f}' 202 | print(fmt.format(episode, train_loss, train_acc_history[-1] * 100, test_acc_history[-1] * 100)) 203 | 204 | fmt = 'Final Results: Acc={:5.2f} Std={:5.2f}' 205 | print(fmt.format(df_test.mean()[-1] * 100, 1.96 * df_test.std()[-1] / np.sqrt(n_episodes) * 100)) 206 | 207 | print('Saved history to:') 208 | print(train_history_path) 209 | print(test_history_path) 210 | df_train.to_csv(train_history_path) 211 | df_test.to_csv(test_history_path) 212 | 213 | 214 | if __name__ == '__main__': 215 | np.random.seed(10) 216 | params = parse_args() 217 | 218 | targets = params.target_dataset 219 | if targets is None: 220 | targets = [targets] 221 | elif len(targets) > 1: 222 | print('#' * 80) 223 | print("Running finetune iteratively for multiple target datasets: {}".format(targets)) 224 | print('#' * 80) 225 | 226 | for target in targets: 227 | params.target_dataset = target 228 | main(params) 229 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | 4 | SOURCES=("miniImageNet" "tieredImageNet" "ImageNet") 5 | SOURCE=${SOURCES[2]} 6 | 7 | TARGETS=("CropDisease" "ISIC" "EuroSAT" "ChestX" "places" "cub" "plantae" "cars") 8 | TARGET=${TARGETS[0]} 9 | 10 | # BACKBONE=resnet10 # for mini 11 | BACKBONE=resnet18 # for tiered and full imagenet 12 | 13 | N_SHOT=5 14 | 15 | 16 | # Source SL 17 | python finetune.py --ls --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "base" --tag "default" --n_shot $N_SHOT 18 | 19 | # Target SSL 20 | python finetune.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default" --n_shot $N_SHOT 21 | python finetune.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default" --n_shot $N_SHOT 22 | 23 | # MSL (Source SL + Target SSL) 24 | python finetune.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78" --n_shot $N_SHOT 25 | python finetune.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78" --n_shot 
$N_SHOT 26 | 27 | # Two-Stage SSL (Source SL -> Target SSL) 28 | python finetune.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default" --n_shot $N_SHOT 29 | python finetune.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default" --n_shot $N_SHOT 30 | 31 | # Two-Stage MSL (Source SL -> Source SL + Target SSL) 32 | python finetune.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78" --n_shot $N_SHOT 33 | python finetune.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78" --n_shot $N_SHOT 34 | -------------------------------------------------------------------------------- /io_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description='CD-FSL') 6 | parser.add_argument('--dataset', default='miniImageNet', help='Dataset used to train the base model') 7 | parser.add_argument('--backbone', default='resnet10', help='Refer to backbone._backbone_class_map') 8 | parser.add_argument('--model', default='base', help='Pre-training method. Refer to model._model_class_map') 9 | 10 | parser.add_argument('--num_classes', default=200, type=int, 11 | help='Deprecated. Value is overwritten based on `source_dataset`') 12 | 13 | parser.add_argument('--source_dataset', default='miniImageNet') 14 | parser.add_argument('--target_dataset', type=str, 15 | nargs='+') # replaces dataset_names / HOTFIX: changed to list to allow for multiple targets with one CLI command 16 | parser.add_argument('--imagenet_pretrained', action="store_true", help='Use ImageNet pretrained weights') 17 | 18 | # Split related params 19 | parser.add_argument('--unlabeled_ratio', default=20, type=int, 20 | help='Percentage of dataset used for the unlabeled split') 21 | parser.add_argument('--split_seed', default=1, type=int, 22 | help='Random seed used for the split. If set to 1 and unlabeled_ratio==20, uses the split defined by STARTUP') 23 | 24 | # Pre-train params (determine the pre-trained model output directory) 25 | # These must be specified during evaluation and fine-tuning to select the pre-trained model 26 | parser.add_argument('--pls', action='store_true', 27 | help='Second-step pre-training on top of a model trained with source labeled data') 28 | parser.add_argument('--put', action='store_true', 29 | help='Second-step pre-training on top of a model trained with target unlabeled data') 30 | parser.add_argument('--pmsl', action='store_true', 31 | help='Second-step pre-training on top of a model trained with MSL (instead of pls_put)') 32 | parser.add_argument('--ls', action='store_true', help='Use labeled source data for pre-training') 33 | parser.add_argument('--us', action='store_true', help='Use unlabeled source data for pre-training') 34 | parser.add_argument('--ut', action='store_true', help='Use unlabeled target data for pre-training') 35 | parser.add_argument('--tag', default='default', type=str, 36 | help='Tag used to differentiate output directories for pre-trained models') 37 | parser.add_argument('--pls_tag', default=None, type=str, help='Deprecated. Please use `previous_tag`.') 38 | parser.add_argument('--previous_tag', default=None, type=str, 39 | help='Tag of the previous pre-trained model for pls, put, pmsl.
Uses --tag by default.') 40 | 41 | # Pre-train params (non-identifying, i.e., they do not affect the output directory) 42 | # You must specify --tag to differentiate models with different non-identifying parameters. 43 | parser.add_argument('--augmentation', default='strong', type=str, 44 | help="Augmentation used for pre-training {'base', 'strong'}") # similar to aug_mode 45 | parser.add_argument('--batch_size', default=64, type=int, 46 | help='Batch size for pre-training.') 47 | parser.add_argument('--ls_batch_size', default=None, type=int, 48 | help='Batch size for LS source pre-training.') # if None, reverts to batch_size 49 | parser.add_argument('--lr', default=None, type=float, help='LR for pre-training.') 50 | parser.add_argument('--gamma', default=0.5, type=float, help='Gamma value for {LS,US} + UT.') 51 | parser.add_argument('--gamma_schedule', default=None, type=str, help='None | "linear"') 52 | parser.add_argument('--epochs', default=1000, type=int, help='Pre-training epochs.') 53 | parser.add_argument('--model_save_interval', default=50, type=int, 54 | help='Save model state every N epochs during pre-training.') 55 | parser.add_argument('--optimizer', default=None, type=str, 56 | help="Optimizer used during pre-training {'sgd', 'adam'}. If None, the default depends on the model") 57 | parser.add_argument('--scheduler', default="MultiStepLR", type=str, 58 | help="Scheduler to use (refer to `pretrain.py`)") 59 | parser.add_argument('--scheduler_milestones', default=[400, 600, 800], type=int, nargs="+", 60 | help="Milestones for (Repeated)MultiStepLR scheduler") 61 | parser.add_argument('--num_workers', default=None, type=int) 62 | 63 | # Fine-tune params 64 | parser.add_argument('--n_shot', default=5, type=int, help='Number of labeled examples per class (same as n_support)') 65 | parser.add_argument('--n_way', default=5, type=int) 66 | parser.add_argument('--n_query_shot', default=15, type=int) 67 | 68 | parser.add_argument('--ft_tag', default='default', type=str, 69 | help='Tag used to differentiate output directories for fine-tuned models') 70 | parser.add_argument('--ft_head', default='linear', help='See `model.classifier_head.CLASSIFIER_HEAD_CLASS_MAP`') 71 | parser.add_argument('--ft_epochs', default=100, type=int) 72 | parser.add_argument('--ft_pretrain_epoch', default=None, type=int) 73 | parser.add_argument('--ft_batch_size', default=4, type=int) 74 | parser.add_argument('--ft_lr', default=1e-2, type=float, help='Learning rate for fine-tuning') 75 | parser.add_argument('--ft_augmentation', default=None, type=str, 76 | help="Augmentation used for fine-tuning {None, 'base', 'strong'}") 77 | parser.add_argument('--ft_parts', default='head', type=str, help="Where to fine-tune: {'full', 'body', 'head'}") 78 | parser.add_argument('--ft_features', default=None, type=str, 79 | help='Specify which features to use from the base model (see model/base.py)') 80 | parser.add_argument('--ft_intermediate_test', action='store_true', help='Evaluate on query set during fine-tuning') 81 | parser.add_argument('--ft_episode_seed', default=0, type=int) 82 | 83 | # Model parameters (make sure to prepend with `model_`) 84 | parser.add_argument('--model_simclr_projection_dim', default=128, type=int) 85 | parser.add_argument('--model_simclr_temperature', default=1.0, type=float) 86 | 87 | # Batch normalization (likely deprecated) 88 | parser.add_argument('--track_bn', action='store_true', help='Track BN running stats') 89 |
parser.add_argument('--freeze_bn', action='store_true', 90 | help='freeze bn stats, i.e., use accumulated stats of pretrained model during inference. Note, track_bn must be on to do this.') 91 | parser.add_argument('--reinit_bn_stats', action='store_true', 92 | help='Re-initialize BN running statistics every iteration') 93 | 94 | params = parser.parse_args() 95 | 96 | # Double-checking parameters 97 | if params.freeze_bn and not params.track_bn: 98 | raise AssertionError('Invalid parameter combination') 99 | if params.reinit_bn_stats: 100 | raise AssertionError('Plz consult w/ anon author.') 101 | if params.ut and not params.target_dataset: 102 | raise AssertionError('Invalid parameter combination') 103 | if params.ft_parts not in ["head", "body", "full"]: 104 | raise AssertionError('Invalid params.ft_parts: {}'.format(params.ft_parts)) 105 | 106 | # pls, put, pmsl parameters 107 | if sum((params.pls, params.put, params.pmsl)) > 1: 108 | raise AssertionError('You may only specify one of params.{pls,put,pmsl}') 109 | 110 | # Assign num_classes (*_new) 111 | if params.source_dataset == 'miniImageNet': 112 | params.num_classes = 64 113 | elif params.source_dataset == 'tieredImageNet': 114 | params.num_classes = 351 115 | elif params.source_dataset == 'ImageNet': 116 | params.num_classes = 1000 117 | elif params.source_dataset == 'CropDisease': 118 | params.num_classes = 38 119 | elif params.source_dataset == 'EuroSAT': 120 | params.num_classes = 10 121 | elif params.source_dataset == 'ISIC': 122 | params.num_classes = 7 123 | elif params.source_dataset == 'ChestX': 124 | params.num_classes = 7 125 | elif params.source_dataset == 'places': 126 | params.num_classes = 16 127 | elif params.source_dataset == 'plantae': 128 | params.num_classes = 69 129 | elif params.source_dataset == 'cars': 130 | params.num_classes = 196 131 | elif params.source_dataset == 'cub': 132 | params.num_classes = 200 133 | elif params.source_dataset == 'none': 134 | params.num_classes = 5 135 | else: 136 | raise ValueError('Invalid `source_dataset` argument: {}'.format(params.source_dataset)) 137 | 138 | # Default workers 139 | if params.num_workers is None: 140 | params.num_workers = 3 141 | if params.target_dataset in ["cars", "cub", "plantae"]: 142 | params.num_workers = 4 143 | if params.target_dataset in ["ChestX"]: 144 | params.num_workers = 6 145 | print("Using default num_workers={}".format(params.num_workers)) 146 | params.num_workers *= 2 # TEMP 147 | 148 | # Default optimizers 149 | if params.optimizer is None: 150 | if params.model in ['simsiam', 'byol']: 151 | params.optimizer = 'adam' if not params.ls else 'sgd' 152 | else: 153 | params.optimizer = 'sgd' 154 | print("Using default optimizer for model {}: {}".format(params.model, params.optimizer)) 155 | 156 | # Default learning rate 157 | if params.lr is None: 158 | if params.model in ['simsiam', 'byol']: 159 | params.lr = 3e-4 if not params.ls else 0.1 160 | elif params.model in ['moco']: 161 | params.lr = 0.01 162 | else: 163 | params.lr = 0.1 164 | print("Using default lr for model {}: {}".format(params.model, params.lr)) 165 | 166 | # Default ls_batch_size 167 | if params.ls_batch_size is None: 168 | params.ls_batch_size = params.batch_size 169 | 170 | params.ft_train_body = params.ft_parts in ['body', 'full'] 171 | params.ft_train_head = params.ft_parts in ['head', 'full'] 172 | 173 | if params.previous_tag is None: 174 | if params.pls_tag: # support for deprecated argument (changed 5/8/2022) 175 | print("Warning: params.pls_tag is deprecated. 
Please use params.previous_tag") 176 | params.previous_tag = params.pls_tag 177 | elif params.pls or params.put or params.pmsl: 178 | print("Using params.tag for params.previous_tag") 179 | params.previous_tag = params.tag 180 | 181 | return params 182 | -------------------------------------------------------------------------------- /methods/__init__.py: -------------------------------------------------------------------------------- 1 | from . import meta_template 2 | from . import maml 3 | from . import boil 4 | from . import protonet 5 | 6 | 7 | -------------------------------------------------------------------------------- /methods/baselinefinetune.py: -------------------------------------------------------------------------------- 1 | import backbone 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from methods.meta_template import MetaTemplate 8 | 9 | class BaselineFinetune(MetaTemplate): 10 | def __init__(self, model_func, n_way, n_support, loss_type = "softmax"): 11 | super(BaselineFinetune, self).__init__( model_func, n_way, n_support) 12 | self.loss_type = loss_type 13 | 14 | def set_forward(self, x, is_feature = True): 15 | return self.set_forward_adaptation(x, is_feature); #Baseline always do adaptation 16 | 17 | def set_forward_adaptation(self, x, is_feature = True): 18 | assert is_feature == True, 'Baseline only support testing with feature' 19 | z_support, z_query = self.parse_feature(x,is_feature) 20 | 21 | z_support = z_support.contiguous().view(self.n_way* self.n_support, -1 ) 22 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 ) 23 | 24 | y_support = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support )) 25 | y_support = Variable(y_support.cuda()) 26 | 27 | if self.loss_type == 'softmax': 28 | linear_clf = nn.Linear(self.feat_dim, self.n_way) 29 | 30 | elif self.loss_type == 'dist': 31 | linear_clf = backbone.distLinear(self.feat_dim, self.n_way) 32 | 33 | 34 | linear_clf = linear_clf.cuda() 35 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001) 36 | 37 | loss_function = nn.CrossEntropyLoss() 38 | loss_function = loss_function.cuda() 39 | 40 | batch_size = 4 41 | support_size = self.n_way* self.n_support 42 | for epoch in range(100): 43 | rand_id = np.random.permutation(support_size) 44 | for i in range(0, support_size , batch_size): 45 | set_optimizer.zero_grad() 46 | selected_id = torch.from_numpy( rand_id[i: min(i+batch_size, support_size) ]).cuda() 47 | z_batch = z_support[selected_id] 48 | 49 | scores = linear_clf(z_batch) 50 | 51 | y_batch = y_support[selected_id] 52 | 53 | loss = loss_function(scores,y_batch) 54 | loss.backward() 55 | set_optimizer.step() 56 | 57 | scores = linear_clf(z_query) 58 | return scores 59 | 60 | def set_forward_loss(self,x): 61 | raise ValueError('Baseline predict on pretrained feature and do not support finetune backbone') 62 | -------------------------------------------------------------------------------- /methods/baselinetrain.py: -------------------------------------------------------------------------------- 1 | import backbone 2 | import utils 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | class BaselineTrain(nn.Module): 11 | def __init__(self, model_func, num_class, loss_type = 'softmax'): 12 | super(BaselineTrain, self).__init__() 
13 | self.feature = model_func 14 | 15 | if loss_type == 'softmax': 16 | self.classifier = nn.Linear(self.feature.final_feat_dim, num_class) 17 | self.classifier.bias.data.fill_(0) 18 | elif loss_type == 'dist': #Baseline ++ 19 | self.classifier = backbone.distLinear(self.feature.final_feat_dim, num_class) 20 | 21 | self.loss_type = loss_type #'softmax' #'dist' 22 | self.num_class = num_class 23 | self.loss_fn = nn.CrossEntropyLoss() 24 | self.top1 = utils.AverageMeter() 25 | 26 | def forward(self,x): 27 | x = Variable(x.cuda()) 28 | out = self.feature.forward(x) 29 | scores = self.classifier.forward(out) 30 | return scores 31 | 32 | def forward_loss(self, x, y): 33 | y = Variable(y.cuda()) 34 | 35 | scores = self.forward(x) 36 | 37 | _, predicted = torch.max(scores.data, 1) 38 | correct = predicted.eq(y.data).cpu().sum() 39 | self.top1.update(correct.item()*100 / (y.size(0)+0.0), y.size(0)) 40 | 41 | return self.loss_fn(scores, y) 42 | 43 | def train_loop(self, epoch, train_loader, optimizer, scheduler): 44 | print_freq = 10 45 | avg_loss=0 46 | for i, (x,y) in enumerate(train_loader): 47 | optimizer.zero_grad() 48 | loss = self.forward_loss(x, y) 49 | loss.backward() 50 | optimizer.step() 51 | 52 | avg_loss = avg_loss+loss.item() 53 | if i % print_freq==0: 54 | #print(optimizer.state_dict()['param_groups'][0]['lr']) 55 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} | Top1 Val {:f} | Top1 Avg {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1), self.top1.val, self.top1.avg)) 56 | if scheduler is not None: 57 | scheduler.step() 58 | 59 | def test_loop(self, val_loader): 60 | return -1 #no validation, just save model during iteration 61 | 62 | -------------------------------------------------------------------------------- /methods/boil.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml 2 | 3 | import backbone 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from methods.meta_template import MetaTemplate 10 | 11 | class BOIL(MetaTemplate): 12 | def __init__(self, model_func, n_way, n_support, approx = False): 13 | super(BOIL, self).__init__( model_func, n_way, n_support, change_way = False) 14 | 15 | self.loss_fn = nn.CrossEntropyLoss() 16 | self.classifier = backbone.Linear_fw(self.feat_dim, n_way) 17 | self.classifier.bias.data.fill_(0) 18 | 19 | self.n_task = 4 20 | self.task_update_num = 1 21 | self.train_lr = 0.5 22 | self.approx = approx #first order approx. 
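    # BOIL (Body Only update in Inner Loop) is identical to MAML except inside
    # set_forward() below: the classifier head's weight and bias (the last two
    # parameters) keep their original values during the inner loop, so only the
    # feature extractor ("body") is adapted to each task.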
23 | 24 | def forward(self,x): 25 | out = self.feature.forward(x) 26 | scores = self.classifier.forward(out) 27 | return scores 28 | 29 | def set_forward(self, x, is_feature = False): 30 | assert is_feature == False, 'BOIL does not support fixed features' 31 | 32 | x = x.cuda() 33 | x_var = Variable(x) 34 | x_a_i = x_var[:,:self.n_support,:,:,:].contiguous().view( self.n_way* self.n_support, *x.size()[2:]) #support data 35 | x_b_i = x_var[:,self.n_support:,:,:,:].contiguous().view( self.n_way* self.n_query, *x.size()[2:]) #query data 36 | y_a_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_support ) )).cuda() #label for support data 37 | 38 | fast_parameters = list(self.parameters()) #the first gradient calculated in line 45 is based on the original weights 39 | len_parameters = len(fast_parameters) 40 | for weight in self.parameters(): 41 | weight.fast = None 42 | self.zero_grad() 43 | 44 | for task_step in range(self.task_update_num): 45 | scores = self.forward(x_a_i) 46 | set_loss = self.loss_fn( scores, y_a_i) 47 | grad = torch.autograd.grad(set_loss, fast_parameters, create_graph=True) #build the full graph to support gradient of gradient 48 | if self.approx: 49 | grad = [ g.detach() for g in grad ] #do not calculate gradient of gradient if using first-order approximation 50 | fast_parameters = [] 51 | for k, weight in enumerate(self.parameters()): 52 | #for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py 53 | if k == len_parameters-2 or k == len_parameters-1: 54 | weight.fast = weight 55 | else: 56 | if weight.fast is None: 57 | weight.fast = weight - self.train_lr * grad[k] #create weight.fast 58 | else: 59 | weight.fast = weight.fast - self.train_lr * grad[k] #create an updated weight.fast; note that '-' does not subtract in place, but creates a new weight.fast 60 | fast_parameters.append(weight.fast) #gradients calculated in line 45 are based on the newest fast weights, but the graph retains the link to the old weight.fasts 61 | 62 | scores = self.forward(x_b_i) 63 | return scores 64 | 65 | def set_forward_adaptation(self,x, is_feature = False): #overrides parent function 66 | raise ValueError('BOIL performs further adaptation simply by increasing task_update_num') 67 | 68 | def set_forward_loss(self, x): 69 | scores = self.set_forward(x, is_feature = False) 70 | y_b_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_query ) )).cuda() 71 | loss = self.loss_fn(scores, y_b_i) 72 | 73 | return loss 74 | 75 | def train_loop(self, epoch, train_loader, optimizer): #overrides parent function 76 | print_freq = 10 77 | avg_loss=0 78 | task_count = 0 79 | loss_all = [] 80 | optimizer.zero_grad() 81 | 82 | #train 83 | for i, (x,_) in enumerate(train_loader): 84 | self.n_query = x.size(1) - self.n_support 85 | assert self.n_way == x.size(0), "BOIL does not support way change" 86 | 87 | loss = self.set_forward_loss(x) 88 | avg_loss = avg_loss+loss.item() 89 | loss_all.append(loss) 90 | 91 | task_count += 1 92 | 93 | if task_count == self.n_task: #BOIL updates several tasks at once 94 | loss_q = torch.stack(loss_all).sum(0) 95 | loss_q.backward() 96 | 97 | optimizer.step() 98 | task_count = 0 99 | loss_all = [] 100 | optimizer.zero_grad() 101 | if i % print_freq==0: 102 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1))) 103 | 104 | def test_loop(self, test_loader, return_std = False): #overrides parent function 105 | correct =0 106 | count = 0 107 | acc_all = [] 108 | 109 | iter_num =
len(test_loader) 110 | for i, (x,_) in enumerate(test_loader): 111 | self.n_query = x.size(1) - self.n_support 112 | assert self.n_way == x.size(0), "BOIL does not support way change" 113 | correct_this, count_this = self.correct(x) 114 | acc_all.append(correct_this/ count_this *100 ) 115 | 116 | acc_all = np.asarray(acc_all) 117 | acc_mean = np.mean(acc_all) 118 | acc_std = np.std(acc_all) 119 | print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num))) 120 | if return_std: 121 | return acc_mean, acc_std 122 | else: 123 | return acc_mean 124 | 125 | def get_logits(self, x): 126 | self.n_query = x.size(1) - self.n_support 127 | logits = self.set_forward(x) 128 | return logits -------------------------------------------------------------------------------- /methods/byol.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from functools import wraps 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from torchvision import transforms as T 10 | from .baselinetrain import BaselineTrain 11 | 12 | # helper functions 13 | 14 | def default(val, def_val): 15 | return def_val if val is None else val 16 | 17 | def flatten(t): 18 | return t.reshape(t.shape[0], -1) 19 | 20 | def singleton(cache_key): 21 | def inner_fn(fn): 22 | @wraps(fn) 23 | def wrapper(self, *args, **kwargs): 24 | instance = getattr(self, cache_key) 25 | if instance is not None: 26 | return instance 27 | 28 | instance = fn(self, *args, **kwargs) 29 | setattr(self, cache_key, instance) 30 | return instance 31 | return wrapper 32 | return inner_fn 33 | 34 | def get_module_device(module): 35 | return next(module.parameters()).device 36 | 37 | def set_requires_grad(model, val): 38 | for p in model.parameters(): 39 | p.requires_grad = val 40 | 41 | # loss fn 42 | 43 | def loss_fn(x, y): 44 | x = F.normalize(x, dim=-1, p=2) 45 | y = F.normalize(y, dim=-1, p=2) 46 | return 2 - 2 * (x * y).sum(dim=-1) 47 | 48 | # augmentation utils 49 | 50 | class RandomApply(nn.Module): 51 | def __init__(self, fn, p): 52 | super().__init__() 53 | self.fn = fn 54 | self.p = p 55 | def forward(self, x): 56 | if random.random() > self.p: 57 | return x 58 | return self.fn(x) 59 | 60 | # exponential moving average 61 | 62 | class EMA(): 63 | def __init__(self, beta): 64 | super().__init__() 65 | self.beta = beta 66 | 67 | def update_average(self, old, new): 68 | if old is None: 69 | return new 70 | return old * self.beta + (1 - self.beta) * new 71 | 72 | def update_moving_average(ema_updater, ma_model, current_model): 73 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 74 | old_weight, up_weight = ma_params.data, current_params.data 75 | ma_params.data = ema_updater.update_average(old_weight, up_weight) 76 | 77 | # MLP class for projector and predictor 78 | 79 | class MLP(nn.Module): 80 | def __init__(self, dim, projection_size, hidden_size = 4096): 81 | super().__init__() 82 | self.net = nn.Sequential( 83 | nn.Linear(dim, hidden_size), 84 | nn.BatchNorm1d(hidden_size), 85 | nn.ReLU(inplace=True), 86 | nn.Linear(hidden_size, projection_size) 87 | ) 88 | 89 | def forward(self, x): 90 | return self.net(x) 91 | 92 | # a wrapper class for the base neural network 93 | # will manage the interception of the hidden layer output 94 | # and pipe it into the projector and predictor nets 95 | 96 | class NetWrapper(nn.Module): 97 | def __init__(self, net, projection_size,
projection_hidden_size, layer = -1): # default layer = -2 since network includes classifier. Ours does not have classifier. 98 | super().__init__() 99 | self.net = net 100 | self.layer = layer 101 | 102 | self.projector = None 103 | self.projection_size = projection_size 104 | self.projection_hidden_size = projection_hidden_size 105 | 106 | self.hidden = {} 107 | self.hook_registered = False 108 | 109 | def _find_layer(self): 110 | if type(self.layer) == str: 111 | modules = dict([*self.net.named_modules()]) 112 | return modules.get(self.layer, None) 113 | elif type(self.layer) == int: 114 | children = [*self.net.children()] 115 | return children[self.layer] 116 | return None 117 | 118 | def _hook(self, _, input, output): 119 | device = input[0].device 120 | self.hidden[device] = flatten(output) 121 | 122 | def _register_hook(self): 123 | layer = self._find_layer() 124 | assert layer is not None, f'hidden layer ({self.layer}) not found' 125 | handle = layer.register_forward_hook(self._hook) 126 | self.hook_registered = True 127 | 128 | @singleton('projector') 129 | def _get_projector(self, hidden): 130 | _, dim = hidden.shape 131 | projector = MLP(dim, self.projection_size, self.projection_hidden_size) 132 | return projector.to(hidden) 133 | 134 | def get_representation(self, x): 135 | if self.layer == -1: 136 | return self.net(x) 137 | 138 | if not self.hook_registered: 139 | self._register_hook() 140 | 141 | self.hidden.clear() 142 | _ = self.net(x) 143 | hidden = self.hidden[x.device] 144 | self.hidden.clear() 145 | 146 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output' 147 | return hidden 148 | 149 | def forward(self, x, return_projection = True): 150 | representation = self.get_representation(x) 151 | 152 | if not return_projection: 153 | return representation 154 | 155 | projector = self._get_projector(representation) 156 | projection = projector(representation) 157 | return projection, representation 158 | 159 | # main class 160 | 161 | class BYOL(nn.Module): 162 | def __init__( 163 | self, 164 | net, 165 | image_size, 166 | hidden_layer = -1, 167 | projection_size = 256, 168 | projection_hidden_size = 4096, 169 | augment_fn = None, 170 | augment_fn2 = None, 171 | moving_average_decay = 0.99, 172 | use_momentum = True 173 | ): 174 | super().__init__() 175 | assert isinstance(net, BaselineTrain) 176 | self.net = net 177 | 178 | # default SimCLR augmentation 179 | 180 | # DEFAULT_AUG = torch.nn.Sequential( 181 | # RandomApply( 182 | # T.ColorJitter(0.8, 0.8, 0.8, 0.2), 183 | # p = 0.3 184 | # ), 185 | # T.RandomGrayscale(p=0.2), 186 | # T.RandomHorizontalFlip(), 187 | # RandomApply( 188 | # T.GaussianBlur((3, 3), (1.0, 2.0)), 189 | # p = 0.2 190 | # ), 191 | # T.RandomResizedCrop((image_size, image_size)), 192 | # T.Normalize( 193 | # mean=torch.tensor([0.485, 0.456, 0.406]), 194 | # std=torch.tensor([0.229, 0.224, 0.225])), 195 | # ) 196 | 197 | # self.augment1 = default(augment_fn, DEFAULT_AUG) 198 | # self.augment2 = default(augment_fn2, self.augment1) 199 | 200 | self.online_encoder = NetWrapper(net.feature, projection_size, projection_hidden_size, layer=hidden_layer) 201 | 202 | self.use_momentum = use_momentum 203 | self.target_encoder = None 204 | self.target_ema_updater = EMA(moving_average_decay) 205 | 206 | self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) 207 | 208 | # get device of network and make wrapper same device 209 | device = get_module_device(net.feature) 210 | self.to(device) 211 | 212 | # send a 
mock image tensor to instantiate singleton parameters 213 | self.forward(torch.randn(2, 3, image_size, image_size, device=device), torch.randn(2, 3, image_size, image_size, device=device)) 214 | 215 | @singleton('target_encoder') 216 | def _get_target_encoder(self): 217 | target_encoder = copy.deepcopy(self.online_encoder) 218 | set_requires_grad(target_encoder, False) 219 | return target_encoder 220 | 221 | def reset_moving_average(self): 222 | del self.target_encoder 223 | self.target_encoder = None 224 | 225 | def update_moving_average(self): 226 | assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' 227 | assert self.target_encoder is not None, 'target encoder has not been created yet' 228 | update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) 229 | 230 | def forward( 231 | self, 232 | x1, x2, 233 | return_embedding = False, 234 | return_projection = True 235 | ): 236 | assert not (self.training and x1.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer' 237 | 238 | if return_embedding: 239 | return self.online_encoder(torch.cat([x1, x2], 0), return_projection = return_projection) 240 | 241 | # image_one, image_two = self.augment1(x), self.augment2(x) 242 | 243 | online_proj_one, _ = self.online_encoder(x1) 244 | online_proj_two, _ = self.online_encoder(x2) 245 | 246 | online_pred_one = self.online_predictor(online_proj_one) 247 | online_pred_two = self.online_predictor(online_proj_two) 248 | 249 | with torch.no_grad(): 250 | target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder 251 | target_proj_one, _ = target_encoder(x1) 252 | target_proj_two, _ = target_encoder(x2) 253 | target_proj_one.detach_() 254 | target_proj_two.detach_() 255 | 256 | loss_one = loss_fn(online_pred_one, target_proj_two.detach()) 257 | loss_two = loss_fn(online_pred_two, target_proj_one.detach()) 258 | 259 | loss = loss_one + loss_two 260 | return loss.mean() 261 | -------------------------------------------------------------------------------- /methods/maml.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/dragen1860/MAML-Pytorch and https://github.com/katerakelly/pytorch-maml 2 | 3 | import backbone 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from methods.meta_template import MetaTemplate 10 | 11 | class MAML(MetaTemplate): 12 | def __init__(self, model_func, n_way, n_support, approx = False): 13 | super(MAML, self).__init__( model_func, n_way, n_support, change_way = False) 14 | 15 | self.loss_fn = nn.CrossEntropyLoss() 16 | self.classifier = backbone.Linear_fw(self.feat_dim, n_way) 17 | self.classifier.bias.data.fill_(0) 18 | 19 | self.n_task = 4 20 | self.task_update_num = 1 21 | self.train_lr = 0.5 22 | self.approx = approx #first order approx. 
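    # Inner loop (in set_forward below): starting from w.fast = w, each step applies
    #   w.fast <- w.fast - train_lr * d(support_loss)/d(w.fast)
    # With create_graph=True, the query loss backpropagates through these updates
    # (second-order MAML); approx=True detaches the inner gradients instead
    # (first-order MAML, a.k.a. FOMAML).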
23 | 24 | def forward(self,x): 25 | out = self.feature.forward(x) 26 | scores = self.classifier.forward(out) 27 | return scores 28 | 29 | def set_forward(self, x, is_feature = False): 30 | assert is_feature == False, 'MAML does not support fixed features' 31 | 32 | x = x.cuda() 33 | x_var = Variable(x) 34 | x_a_i = x_var[:,:self.n_support,:,:,:].contiguous().view( self.n_way* self.n_support, *x.size()[2:]) #support data 35 | x_b_i = x_var[:,self.n_support:,:,:,:].contiguous().view( self.n_way* self.n_query, *x.size()[2:]) #query data 36 | y_a_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_support ) )).cuda() #label for support data 37 | 38 | fast_parameters = list(self.parameters()) #the first gradient calculated in line 45 is based on the original weights 39 | for weight in self.parameters(): 40 | weight.fast = None 41 | self.zero_grad() 42 | 43 | for task_step in range(self.task_update_num): 44 | scores = self.forward(x_a_i) 45 | set_loss = self.loss_fn(scores, y_a_i) 46 | grad = torch.autograd.grad(set_loss, fast_parameters, create_graph=True) #build the full graph to support gradient of gradient 47 | if self.approx: 48 | grad = [ g.detach() for g in grad ] #do not calculate gradient of gradient if using first-order approximation 49 | fast_parameters = [] 50 | for k, weight in enumerate(self.parameters()): 51 | #for usage of weight.fast, please see Linear_fw, Conv_fw in backbone.py 52 | if weight.fast is None: 53 | weight.fast = weight - self.train_lr * grad[k] #create weight.fast 54 | else: 55 | weight.fast = weight.fast - self.train_lr * grad[k] #create an updated weight.fast; note that '-' does not subtract in place, but creates a new weight.fast 56 | fast_parameters.append(weight.fast) #gradients calculated in line 45 are based on the newest fast weights, but the graph retains the link to the old weight.fasts 57 | 58 | scores = self.forward(x_b_i) 59 | return scores 60 | 61 | def set_forward_adaptation(self,x, is_feature = False): #overrides parent function 62 | raise ValueError('MAML performs further adaptation simply by increasing task_update_num') 63 | 64 | 65 | def set_forward_loss(self, x): 66 | scores = self.set_forward(x, is_feature = False) 67 | y_b_i = Variable( torch.from_numpy( np.repeat(range( self.n_way ), self.n_query ) )).cuda() 68 | loss = self.loss_fn(scores, y_b_i) 69 | 70 | return loss 71 | 72 | def train_loop(self, epoch, train_loader, optimizer): #overrides parent function 73 | print_freq = 10 74 | avg_loss=0 75 | task_count = 0 76 | loss_all = [] 77 | optimizer.zero_grad() 78 | 79 | #train 80 | for i, (x,_) in enumerate(train_loader): 81 | self.n_query = x.size(1) - self.n_support 82 | assert self.n_way == x.size(0), "MAML does not support way change" 83 | 84 | loss = self.set_forward_loss(x) 85 | avg_loss = avg_loss+loss.item() 86 | loss_all.append(loss) 87 | 88 | task_count += 1 89 | 90 | if task_count == self.n_task: #MAML updates several tasks at once 91 | loss_q = torch.stack(loss_all).sum(0) 92 | loss_q.backward() 93 | 94 | optimizer.step() 95 | task_count = 0 96 | loss_all = [] 97 | optimizer.zero_grad() 98 | if i % print_freq==0: 99 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1))) 100 | 101 | def test_loop(self, test_loader, return_std = False): #overrides parent function 102 | correct =0 103 | count = 0 104 | acc_all = [] 105 | 106 | iter_num = len(test_loader) 107 | for i, (x,_) in enumerate(test_loader): 108 | self.n_query = x.size(1) - self.n_support 109 | assert self.n_way == x.size(0), "MAML does not support way change" 110 | correct_this, count_this = self.correct(x) 111 | acc_all.append(correct_this/ count_this *100 ) 112 | 113 | acc_all = np.asarray(acc_all) 114 | acc_mean = np.mean(acc_all) 115 | acc_std = np.std(acc_all) 116 | print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num))) 117 | if return_std: 118 | return acc_mean, acc_std 119 | else: 120 | return acc_mean 121 | 122 | def get_logits(self, x): 123 | self.n_query = x.size(1) - self.n_support 124 | logits = self.set_forward(x) 125 | return logits -------------------------------------------------------------------------------- /methods/meta_template.py: -------------------------------------------------------------------------------- 1 | import backbone 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import utils 8 | from abc import abstractmethod 9 | 10 | class MetaTemplate(nn.Module): 11 | def __init__(self, model_func, n_way, n_support, change_way = True): 12 | super(MetaTemplate, self).__init__() 13 | self.n_way = n_way 14 | self.n_support = n_support 15 | self.n_query = 15 # (if -1, it is inferred from the input) 16 | self.feature = model_func 17 | self.feat_dim = self.feature.final_feat_dim 18 | self.change_way = change_way # some methods allow different n_way during training and test 19 | 20 | @abstractmethod 21 | def set_forward(self,x,is_feature): 22 | pass 23 | 24 | @abstractmethod 25 | def set_forward_loss(self, x): 26 | pass 27 | 28 | def forward(self,x): 29 | out = self.feature.forward(x) 30 | return out 31 | 32 | def parse_feature(self,x,is_feature): 33 | x = Variable(x.cuda()) 34 | if is_feature: 35 | z_all = x 36 | else: 37 | x = x.contiguous().view( self.n_way * (self.n_support + self.n_query), *x.size()[2:]) 38 | z_all = self.feature.forward(x) 39 | z_all = z_all.view( self.n_way, self.n_support + self.n_query, -1) 40 | z_support = z_all[:, :self.n_support] 41 | z_query = z_all[:, self.n_support:] 42 | 43 | return z_support, z_query 44 | 45 | def correct(self, x): 46 | scores = self.set_forward(x) 47 | y_query = np.repeat(range( self.n_way ), self.n_query ) 48 | 49 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) 50 | topk_ind = topk_labels.cpu().numpy() 51 | top1_correct = np.sum(topk_ind[:,0] == y_query) 52 | return float(top1_correct), len(y_query) 53 | 54 | def train_loop(self, epoch, train_loader, optimizer): 55 | print_freq = 10 56 | 57 | avg_loss=0 58 | for i, (x,_ ) in enumerate(train_loader): 59 | self.n_query = x.size(1) - self.n_support 60 | if self.change_way: 61 | self.n_way = x.size(0) 62 | optimizer.zero_grad() 63 | loss = self.set_forward_loss( x ) 64 | loss.backward() 65 | optimizer.step() 66 | avg_loss = avg_loss+loss.item() 67 | 68 | if i % print_freq==0: 69 | #print(optimizer.state_dict()['param_groups'][0]['lr']) 70 | print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1))) 71 | 72 | def test_loop(self, test_loader, record = None): 73 | correct =0 74 | count = 0 75 | acc_all = [] 76 | 77 | iter_num = len(test_loader) 78 | for i, (x,_) in enumerate(test_loader): 79 | self.n_query = x.size(1) - self.n_support 80 | if self.change_way: 81 | self.n_way = x.size(0) 82 | correct_this, count_this = self.correct(x) 83 | acc_all.append(correct_this/ count_this*100 ) 84 | 85 | acc_all = np.asarray(acc_all) 86 | acc_mean = np.mean(acc_all) 87 | acc_std = np.std(acc_all)
88 | print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num))) 89 | 90 | return acc_mean 91 | 92 | def set_forward_adaptation(self, x, is_feature = True): # further adaptation; by default, fix the features and train a new softmax classifier 93 | assert is_feature == True, 'Feature is fixed in further adaptation' 94 | z_support, z_query = self.parse_feature(x,is_feature) 95 | 96 | z_support = z_support.contiguous().view(self.n_way* self.n_support, -1 ) 97 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 ) 98 | 99 | y_support = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support )) 100 | y_support = Variable(y_support.cuda()) 101 | 102 | linear_clf = nn.Linear(self.feat_dim, self.n_way) 103 | linear_clf = linear_clf.cuda() 104 | 105 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001) 106 | 107 | loss_function = nn.CrossEntropyLoss() 108 | loss_function = loss_function.cuda() 109 | 110 | batch_size = 4 111 | support_size = self.n_way* self.n_support 112 | for epoch in range(100): 113 | rand_id = np.random.permutation(support_size) 114 | for i in range(0, support_size , batch_size): 115 | set_optimizer.zero_grad() 116 | selected_id = torch.from_numpy( rand_id[i: min(i+batch_size, support_size) ]).cuda() 117 | z_batch = z_support[selected_id] 118 | y_batch = y_support[selected_id] 119 | scores = linear_clf(z_batch) 120 | loss = loss_function(scores,y_batch) 121 | loss.backward() 122 | set_optimizer.step() 123 | 124 | scores = linear_clf(z_query) 125 | return scores 126 | -------------------------------------------------------------------------------- /methods/protonet.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/jakesnell/prototypical-networks 2 | 3 | import backbone 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from methods.meta_template import MetaTemplate 10 | 11 | class ProtoNet(MetaTemplate): 12 | def __init__(self, model_func, n_way, n_support): 13 | super(ProtoNet, self).__init__( model_func, n_way, n_support) 14 | self.loss_fn = nn.CrossEntropyLoss() 15 | 16 | 17 | def set_forward(self, x, is_feature = False): 18 | z_support, z_query = self.parse_feature(x,is_feature) 19 | 20 | z_support = z_support.contiguous() 21 | z_proto = z_support.view(self.n_way, self.n_support, -1 ).mean(1) # z_proto: [n_way, n_dim], the per-class mean of support features 22 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 ) 23 | 24 | dists = euclidean_dist(z_query, z_proto) 25 | scores = -dists 26 | return scores 27 | 28 | 29 | def set_forward_loss(self, x): 30 | y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query )) 31 | y_query = Variable(y_query.cuda()) 32 | 33 | scores = self.set_forward(x) 34 | 35 | return self.loss_fn(scores, y_query ) 36 | 37 | 38 | def euclidean_dist( x, y): 39 | # x: N x D 40 | # y: M x D 41 | n = x.size(0) 42 | m = y.size(0) 43 | d = x.size(1) 44 | assert d == y.size(1) 45 | 46 | x = x.unsqueeze(1).expand(n, m, d) 47 | y = y.unsqueeze(0).expand(n, m, d) 48 | 49 | return torch.pow(x - y, 2).sum(2) 50 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.base import BaseModel 2 | from model.byol import BYOL 3 | from
model.moco import MoCo 4 | from model.simclr import SimCLR 5 | from model.simsiam import SimSiam 6 | 7 | _model_class_map = { 8 | 'base': BaseModel, 9 | 'simclr': SimCLR, 10 | 'byol': BYOL, 11 | 'moco': MoCo, 12 | 'simsiam': SimSiam, 13 | } 14 | 15 | 16 | def get_model_class(key): 17 | if key in _model_class_map: 18 | return _model_class_map[key] 19 | else: 20 | raise ValueError('Invalid model: {}'.format(key)) 21 | -------------------------------------------------------------------------------- /model/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from argparse import Namespace 3 | from typing import Tuple 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class BaseModel(nn.Module): 10 | """ 11 | BaseModel subclasses self-contain all modules and losses required for pre-training. 12 | 13 | - supported_feature_selectors: Feature selectors (see `forward_features()`) are used during fine-tuning 14 | to select which features (from which layer the features should be extracted) should be used for downstream 15 | tasks. This class attribute should be set for subclasses to prevent mistakes regarding the feature_selector 16 | argument (see `params.ft_features`). 17 | """ 18 | supported_feature_selectors = [] 19 | 20 | def __init__(self, backbone: nn.Module, params: Namespace): 21 | super().__init__() 22 | self.backbone = backbone 23 | self.params = params 24 | self.classifier = nn.Linear(backbone.final_feat_dim, params.num_classes) 25 | self.classifier.bias.data.fill_(0) 26 | self.cls_loss_function = nn.CrossEntropyLoss() 27 | self.final_feat_dim = backbone.final_feat_dim 28 | 29 | def forward_features(self, x, feature_selector: str = None): 30 | """ 31 | You'll likely need to override this method for SSL models. 32 | """ 33 | return self.backbone(x) 34 | 35 | def forward(self, x): 36 | x = self.backbone(x) 37 | x = self.classifier(x) 38 | return x 39 | 40 | def compute_cls_loss_and_accuracy(self, x, y, return_predictions=False) -> Tuple: 41 | scores = self.forward(x) 42 | _, predicted = torch.max(scores.data, 1) 43 | accuracy = predicted.eq(y.data).cpu().sum() / x.shape[0] 44 | if return_predictions: 45 | return self.cls_loss_function(scores, y), accuracy, predicted 46 | else: 47 | return self.cls_loss_function(scores, y), accuracy 48 | 49 | def on_step_start(self): 50 | pass 51 | 52 | def on_step_end(self): 53 | pass 54 | 55 | def on_epoch_start(self): 56 | pass 57 | 58 | def on_epoch_end(self): 59 | pass 60 | 61 | 62 | class BaseSelfSupervisedModel(BaseModel): 63 | @abstractmethod 64 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 65 | """ 66 | If SSL is based on paired input: 67 | By default: x1, x2 represent the input pair. 68 | If x2=None: x1 alone contains the full concatenated input pair. 69 | Else: 70 | x1 contains the input. 
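        If return_features=True, subclasses also return the backbone features
        alongside the loss (see e.g. model/byol.py and model/simclr.py).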
71 | """ 72 | raise NotImplementedError() 73 | -------------------------------------------------------------------------------- /model/byol.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from argparse import Namespace 4 | from functools import wraps 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from model.base import BaseSelfSupervisedModel 11 | 12 | 13 | def _singleton(cache_key): 14 | def inner_fn(fn): 15 | @wraps(fn) 16 | def wrapper(self, *args, **kwargs): 17 | instance = getattr(self, cache_key) 18 | if instance is not None: 19 | return instance 20 | 21 | instance = fn(self, *args, **kwargs) 22 | setattr(self, cache_key, instance) 23 | return instance 24 | 25 | return wrapper 26 | 27 | return inner_fn 28 | 29 | 30 | def _get_module_device(module): 31 | return next(module.parameters()).device 32 | 33 | 34 | def _set_requires_grad(model, val): 35 | for p in model.parameters(): 36 | p.requires_grad = val 37 | 38 | 39 | def _loss_fn(x, y): 40 | x = F.normalize(x, dim=-1, p=2) 41 | y = F.normalize(y, dim=-1, p=2) 42 | return 2 - 2 * (x * y).sum(dim=-1) 43 | 44 | 45 | class RandomApply(nn.Module): 46 | def __init__(self, fn, p): 47 | super().__init__() 48 | self.fn = fn 49 | self.p = p 50 | 51 | def forward(self, x): 52 | if random.random() > self.p: 53 | return x 54 | return self.fn(x) 55 | 56 | 57 | class EMA: 58 | def __init__(self, beta): 59 | super().__init__() 60 | self.beta = beta 61 | 62 | def update_average(self, old, new): 63 | if old is None: 64 | return new 65 | return old * self.beta + (1 - self.beta) * new 66 | 67 | 68 | def _update_moving_average(ema_updater, ma_model, current_model): 69 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 70 | old_weight, up_weight = ma_params.data, current_params.data 71 | ma_params.data = ema_updater.update_average(old_weight, up_weight) 72 | 73 | 74 | class MLP(nn.Module): 75 | def __init__(self, dim, projection_size, hidden_size=4096): 76 | super().__init__() 77 | self.net = nn.Sequential( 78 | nn.Linear(dim, hidden_size), 79 | nn.BatchNorm1d(hidden_size), 80 | nn.ReLU(inplace=True), 81 | nn.Linear(hidden_size, projection_size) 82 | ) 83 | 84 | def forward(self, x): 85 | return self.net(x) 86 | 87 | 88 | class NetWrapper(nn.Module): 89 | def __init__(self, net, projection_size, projection_hidden_size, 90 | layer=-1): # default layer = -2 since network includes classifier. Ours does not have classifier. 
91 | super().__init__() 92 | self.net = net 93 | self.layer = layer 94 | 95 | self.projector = None 96 | self.projection_size = projection_size 97 | self.projection_hidden_size = projection_hidden_size 98 | 99 | self.hidden = {} 100 | self.hook_registered = False 101 | 102 | def _find_layer(self): 103 | if type(self.layer) == str: 104 | modules = dict([*self.net.named_modules()]) 105 | return modules.get(self.layer, None) 106 | elif type(self.layer) == int: 107 | children = [*self.net.children()] 108 | return children[self.layer] 109 | return None 110 | 111 | def _hook(self, _, input, output): 112 | device = input[0].device 113 | self.hidden[device] = output.reshape(output.shape[0], -1) # flatten 114 | 115 | def _register_hook(self): 116 | layer = self._find_layer() 117 | assert layer is not None, f'hidden layer ({self.layer}) not found' 118 | handle = layer.register_forward_hook(self._hook) 119 | self.hook_registered = True 120 | 121 | @_singleton('projector') 122 | def _get_projector(self, hidden): 123 | _, dim = hidden.shape 124 | projector = MLP(dim, self.projection_size, self.projection_hidden_size) 125 | return projector.to(hidden) 126 | 127 | def get_representation(self, x): 128 | if self.layer == -1: 129 | return self.net(x) 130 | 131 | if not self.hook_registered: 132 | self._register_hook() 133 | 134 | self.hidden.clear() 135 | _ = self.net(x) 136 | hidden = self.hidden[x.device] 137 | self.hidden.clear() 138 | 139 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output' 140 | return hidden 141 | 142 | def forward(self, x, return_projection=True): 143 | representation = self.get_representation(x) 144 | 145 | if not return_projection: 146 | return representation 147 | 148 | projector = self._get_projector(representation) 149 | projection = projector(representation) 150 | return projection, representation 151 | 152 | 153 | class BYOL(BaseSelfSupervisedModel): 154 | def __init__(self, backbone: nn.Module, params: Namespace, use_momentum=True): 155 | super().__init__(backbone, params) 156 | 157 | image_size = 224 158 | hidden_layer = -1 159 | projection_size = 256 160 | projection_hidden_size = 4096 161 | moving_average_decay = 0.99 162 | use_momentum = use_momentum 163 | 164 | self.online_encoder = NetWrapper(self.backbone, projection_size, projection_hidden_size, layer=hidden_layer) 165 | 166 | self.use_momentum = use_momentum 167 | self.target_encoder = None 168 | self.target_ema_updater = EMA(moving_average_decay) 169 | 170 | self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) 171 | 172 | # get device of network and make wrapper same device 173 | device = _get_module_device(backbone) 174 | self.to(device) 175 | 176 | # send a mock image tensor to instantiate singleton parameters 177 | self.compute_ssl_loss(torch.randn(2, 3, image_size, image_size, device=device), 178 | torch.randn(2, 3, image_size, image_size, device=device)) 179 | 180 | @_singleton('target_encoder') 181 | def _get_target_encoder(self): 182 | target_encoder = copy.deepcopy(self.online_encoder) 183 | _set_requires_grad(target_encoder, False) 184 | return target_encoder 185 | 186 | def _reset_moving_average(self): 187 | del self.target_encoder 188 | self.target_encoder = None 189 | 190 | def _update_moving_average(self): 191 | assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' 192 | assert self.target_encoder is not None, 'target encoder has not been created yet' 193 | 
_update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) 194 | 195 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 196 | if x2 is None: 197 | x = x1 198 | batch_size = int(x.shape[0] / 2) 199 | x1 = x[:batch_size] 200 | x2 = x[batch_size:] 201 | 202 | assert not (self.training and x1.shape[ 203 | 0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer' 204 | 205 | online_proj_one, _ = self.online_encoder(x1) 206 | online_proj_two, _ = self.online_encoder(x2) 207 | 208 | online_pred_one = self.online_predictor(online_proj_one) 209 | online_pred_two = self.online_predictor(online_proj_two) 210 | 211 | with torch.no_grad(): 212 | target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder 213 | target_proj_one, _ = target_encoder(x1) 214 | target_proj_two, _ = target_encoder(x2) 215 | target_proj_one.detach_() 216 | target_proj_two.detach_() 217 | 218 | loss_one = _loss_fn(online_pred_one, target_proj_two.detach()) 219 | loss_two = _loss_fn(online_pred_two, target_proj_one.detach()) 220 | 221 | loss = loss_one + loss_two 222 | loss = loss.mean() 223 | 224 | if return_features: 225 | if x2 is None: 226 | return loss, torch.cat([online_proj_one, online_proj_two]) 227 | else: 228 | return loss, online_proj_one, online_proj_two 229 | else: 230 | return loss 231 | 232 | def on_step_end(self): 233 | if self.use_momentum: 234 | self._update_moving_average() 235 | 236 | -------------------------------------------------------------------------------- /model/classifier_head.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LinearClassifier(nn.Module): 5 | def __init__(self, in_dim, out_dim, params): 6 | super().__init__() 7 | self.fc = nn.Linear(in_dim, out_dim) 8 | 9 | def forward(self, x): 10 | x = self.fc(x) 11 | return x 12 | 13 | 14 | class TwoLayerMLPClassifier(nn.Module): 15 | def __init__(self, in_dim, out_dim, params): 16 | super().__init__() 17 | self.fc1 = nn.Linear(in_dim, in_dim) 18 | self.relu = nn.ReLU() 19 | self.fc2 = nn.Linear(in_dim, out_dim) 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.relu(x) 24 | x = self.fc2(x) 25 | return x 26 | 27 | 28 | CLASSIFIER_HEAD_CLASS_MAP = { 29 | 'linear': LinearClassifier, 30 | 'two_layer_mlp': TwoLayerMLPClassifier, 31 | } 32 | 33 | def get_classifier_head_class(key): 34 | if key in CLASSIFIER_HEAD_CLASS_MAP: 35 | return CLASSIFIER_HEAD_CLASS_MAP[key] 36 | else: 37 | raise ValueError('Invalid classifier head specifier: {}'.format(key)) 38 | -------------------------------------------------------------------------------- /model/moco.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from argparse import Namespace 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from model.base import BaseSelfSupervisedModel 9 | 10 | 11 | class MoCo(BaseSelfSupervisedModel): 12 | def __init__(self, backbone: nn.Module, params: Namespace): 13 | super().__init__(backbone, params) 14 | 15 | dim = 128 16 | mlp = False 17 | self.K = 1024 18 | self.m = 0.999 19 | self.T = 1.0 20 | 21 | self.encoder_q = self.backbone 22 | self.encoder_k = copy.deepcopy(self.backbone) 23 | 24 | if not mlp: 25 | self.projector_q = nn.Linear(self.encoder_q.final_feat_dim, dim) 26 | self.projector_k = nn.Linear(self.encoder_k.final_feat_dim, dim) 27 | else: 28 | mlp_dim = 
self.encoder_q.final_feat_dim # consistent with the non-MLP branch above 29 | self.projector_q = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim)) 30 | self.projector_k = nn.Sequential(nn.Linear(mlp_dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim)) 31 | 32 | self.encoder_k.requires_grad_(False) 33 | self.projector_k.requires_grad_(False) 34 | # Just in case (copied from old code) 35 | for param_k in self.encoder_k.parameters(): 36 | param_k.requires_grad = False 37 | for param_k in self.projector_k.parameters(): 38 | param_k.requires_grad = False 39 | 40 | self.register_buffer("queue", torch.randn(dim, self.K)) 41 | self.queue = F.normalize(self.queue, dim=0) 42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 43 | 44 | self.ce_loss = nn.CrossEntropyLoss() 45 | 46 | @torch.no_grad() 47 | def _momentum_update_key_encoder(self): 48 | """ 49 | Momentum update of the key encoder 50 | """ 51 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 52 | param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) 53 | for param_q_, param_k_ in zip(self.projector_q.parameters(), self.projector_k.parameters()): 54 | param_k_.data = param_k_.data * self.m + param_q_.data * (1. - self.m) 55 | 56 | @torch.no_grad() 57 | def _dequeue_and_enqueue(self, keys): 58 | batch_size = keys.shape[0] 59 | ptr = int(self.queue_ptr) 60 | assert self.K % batch_size == 0 # for simplicity 61 | 62 | # replace the keys at ptr (dequeue and enqueue) 63 | self.queue[:, ptr:ptr + batch_size] = keys.T 64 | ptr = (ptr + batch_size) % self.K # move pointer 65 | 66 | self.queue_ptr[0] = ptr 67 | 68 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 69 | if x2 is None: 70 | x = x1 71 | batch_size = int(x.shape[0] / 2) 72 | im_q = x[:batch_size] 73 | im_k = x[batch_size:] 74 | else: 75 | im_q = x1 76 | im_k = x2 77 | 78 | q_features = self.encoder_q(im_q) 79 | q = self.projector_q(q_features) # queries: NxC 80 | q = F.normalize(q, dim=1) 81 | 82 | # compute key features 83 | with torch.no_grad(): # no gradient to keys 84 | self._momentum_update_key_encoder() # update the key encoder 85 | 86 | k_features = self.encoder_k(im_k) 87 | k = self.projector_k(k_features) # keys: NxC 88 | k = F.normalize(k, dim=1) 89 | 90 | # compute logits (Einstein sum is more intuitive) 91 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) # positive logits: Nx1 92 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) # negative logits: NxK 93 | 94 | logits = torch.cat([l_pos, l_neg], dim=1) # logits: Nx(1+K) 95 | logits /= self.T # apply temperature 96 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() # labels: positive key indicators 97 | 98 | self._dequeue_and_enqueue(k) 99 | 100 | loss = self.ce_loss(logits, labels) 101 | 102 | if return_features: 103 | if x2 is None: 104 | return loss, torch.cat([q_features, k_features]) 105 | else: 106 | return loss, q_features, k_features 107 | else: 108 | return loss 109 | -------------------------------------------------------------------------------- /model/simclr.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from functools import lru_cache 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | from model.base import BaseSelfSupervisedModel 9 | 10 | 11 | class ProjectionHead(nn.Module): 12 | def __init__(self, in_dim, out_dim): 13 | super(ProjectionHead, self).__init__() 14 | self.in_dim = in_dim 15 | 
self.out_dim = out_dim 16 | 17 | self.fc1 = nn.Linear(in_dim, in_dim) 18 | self.relu = nn.ReLU() 19 | self.fc2 = nn.Linear(in_dim, out_dim) 20 | 21 | def forward(self, x): 22 | return self.fc2(self.relu(self.fc1(x))) 23 | 24 | 25 | class NTXentLoss(nn.Module): 26 | def __init__(self, temperature, use_cosine_similarity): 27 | super(NTXentLoss, self).__init__() 28 | self.temperature = temperature 29 | self.softmax = torch.nn.Softmax(dim=-1) 30 | self.similarity_function = self._get_similarity_function(use_cosine_similarity) 31 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 32 | 33 | def _get_similarity_function(self, use_cosine_similarity): 34 | if use_cosine_similarity: 35 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 36 | return self._cosine_similarity_fn 37 | else: 38 | return self._dot_similarity 39 | 40 | @lru_cache(maxsize=4) 41 | def _get_correlated_mask(self, batch_size): 42 | diag = np.eye(2 * batch_size) 43 | l1 = np.eye(2 * batch_size, 2 * batch_size, 44 | k=-batch_size) 45 | l2 = np.eye(2 * batch_size, 2 * batch_size, 46 | k=batch_size) 47 | mask = torch.from_numpy((diag + l1 + l2)) 48 | mask = (1 - mask).type(torch.bool) 49 | return mask 50 | 51 | @staticmethod 52 | def _dot_similarity(x, y): 53 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 54 | # x shape: (N, 1, C) 55 | # y shape: (1, C, 2N) 56 | # v shape: (N, 2N) 57 | return v 58 | 59 | def _cosine_similarity_fn(self, x, y): # _fn suffix avoids clashing with the _cosine_similarity module attribute 60 | # x shape: (N, 1, C) 61 | # y shape: (1, 2N, C) 62 | # v shape: (N, 2N) 63 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 64 | return v 65 | 66 | def forward(self, zis, zjs): 67 | batch_size = zis.shape[0] 68 | representations = torch.cat([zjs, zis], dim=0) 69 | device = representations.device 70 | 71 | similarity_matrix = self.similarity_function( 72 | representations, representations) 73 | 74 | # filter out the scores from the positive samples 75 | l_pos = torch.diag(similarity_matrix, batch_size) 76 | r_pos = torch.diag(similarity_matrix, -batch_size) 77 | positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1) 78 | 79 | mask = self._get_correlated_mask(batch_size).to(device) 80 | negatives = similarity_matrix[mask].view(2 * batch_size, -1) 81 | 82 | logits = torch.cat((positives, negatives), dim=1) 83 | logits /= self.temperature 84 | 85 | labels = torch.zeros(2 * batch_size).to(device).long() 86 | loss = self.criterion(logits, labels) 87 | 88 | return loss / (2 * batch_size) 89 | 90 | 91 | class SimCLR(BaseSelfSupervisedModel): 92 | 93 | def __init__(self, backbone: nn.Module, params: Namespace): 94 | super().__init__(backbone, params) 95 | self.head = ProjectionHead(backbone.final_feat_dim, out_dim=params.model_simclr_projection_dim) 96 | self.ssl_loss_fn = NTXentLoss(temperature=params.model_simclr_temperature, use_cosine_similarity=True) 97 | self.final_feat_dim = self.backbone.final_feat_dim 98 | 99 | def compute_ssl_loss(self, x1, x2=None, return_features=False): 100 | if x2 is None: 101 | x = x1 102 | else: 103 | x = torch.cat([x1, x2]) 104 | batch_size = int(x.shape[0] / 2) 105 | 106 | f = self.backbone(x) 107 | f1, f2 = f[:batch_size], f[batch_size:] 108 | p1 = self.head(f1) 109 | p2 = self.head(f2) 110 | loss = self.ssl_loss_fn(p1, p2) 111 | 112 | if return_features: 113 | if x2 is None: 114 | return loss, f 115 | else: 116 | return loss, f1, f2 117 | else: 118 | return loss 119 | -------------------------------------------------------------------------------- /model/simsiam.py: 
-------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | from torch import nn 4 | 5 | from model import BYOL 6 | 7 | 8 | class SimSiam(BYOL): 9 | def __init__(self, backbone: nn.Module, params: Namespace): 10 | super().__init__(backbone, params, use_momentum=False) 11 | -------------------------------------------------------------------------------- /paths.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | from argparse import Namespace 5 | 6 | import configs 7 | 8 | DATASET_KEYS = { 9 | 'miniImageNet': 'mini', 10 | 'miniImageNet_test': 'mini_test', 11 | 'tieredImageNet': 'tiered', 12 | 'tieredImageNet_test': 'tiered_test', 13 | 'ImageNet': 'imagenet', 14 | 'CropDisease': 'crop', 15 | 'EuroSAT': 'euro', 16 | 'ISIC': 'isic', 17 | 'ChestX': 'chest', 18 | 'cars': 'cars', 19 | 'cub': 'cub', 20 | 'places': 'places', 21 | 'plantae': 'plantae', 22 | } 23 | 24 | BACKBONE_KEYS = { 25 | 'resnet10': 'resnet10', 26 | 'resnet18': 'resnet18', 27 | 'resnet50': 'resnet50', 28 | } 29 | 30 | MODEL_KEYS = { 31 | 'base': 'base', 32 | 'simclr': 'simclr', 33 | 'simsiam': 'simsiam', 34 | 'moco': 'moco', 35 | 'swav': 'swav', 36 | 'byol': 'byol', 37 | } 38 | 39 | 40 | def get_output_directory(params: Namespace, previous=False, makedirs=True): 41 | """ 42 | :param params: 43 | :param previous: get previous output directory for pls, put, pmsl modes 44 | :return: 45 | """ 46 | if previous and not (params.pls or params.put or params.pmsl): 47 | raise ValueError('Invalid arguments for previous=True') 48 | 49 | path = configs.save_dir 50 | path = os.path.join(path, 'output') 51 | path = os.path.join(path, DATASET_KEYS[params.source_dataset]) 52 | 53 | pretrain_specifiers = [BACKBONE_KEYS[params.backbone]] 54 | if previous: 55 | if params.pls: 56 | pretrain_specifiers.append(MODEL_KEYS['base']) 57 | else: 58 | pretrain_specifiers.append(MODEL_KEYS[params.model]) 59 | 60 | if params.pls: 61 | pretrain_specifiers.append('LS') 62 | elif params.put: 63 | pretrain_specifiers.append('UT') 64 | elif params.pmsl: 65 | pretrain_specifiers.append('LS_UT') 66 | else: 67 | raise AssertionError("Invalid parameters") 68 | pretrain_specifiers.append(params.previous_tag) 69 | else: 70 | pretrain_specifiers.append(MODEL_KEYS[params.model]) 71 | if params.pls: 72 | pretrain_specifiers.append('PLS') 73 | if params.put: 74 | pretrain_specifiers.append('PUT') 75 | if params.pmsl: 76 | pretrain_specifiers.append('PMSL') 77 | if params.ls: 78 | pretrain_specifiers.append('LS') 79 | if params.us: 80 | pretrain_specifiers.append('US') 81 | if params.ut: 82 | pretrain_specifiers.append('UT') 83 | pretrain_specifiers.append(params.tag) 84 | 85 | path = os.path.join(path, '_'.join(pretrain_specifiers)) 86 | 87 | if previous: 88 | if params.put or params.pmsl: 89 | path = os.path.join(path, DATASET_KEYS[params.target_dataset]) 90 | else: 91 | if params.put or params.pmsl or params.ut: 92 | path = os.path.join(path, DATASET_KEYS[params.target_dataset]) 93 | 94 | if makedirs: 95 | os.makedirs(path, exist_ok=True) 96 | 97 | return path 98 | 99 | 100 | def get_pretrain_history_path(output_directory): 101 | basename = 'pretrain_history.csv' 102 | return os.path.join(output_directory, basename) 103 | 104 | 105 | def get_pretrain_state_path(output_directory, epoch=0): 106 | """ 107 | :param output_directory: 108 | :param epoch: Number of completed epochs. I.e., 0 = initial. 
109 | :return: 110 | """ 111 | basename = 'pretrain_state_{:04d}.pt'.format(epoch) 112 | return os.path.join(output_directory, basename) 113 | 114 | 115 | def get_final_pretrain_state_path(output_directory): 116 | glob_pattern = os.path.join(output_directory, 'pretrain_state_*.pt') 117 | paths = glob.glob(glob_pattern) 118 | 119 | pattern = re.compile(r'pretrain_state_(\d{4})\.pt') 120 | paths_by_epoch = dict() 121 | for path in paths: 122 | match = pattern.search(path) 123 | if match: 124 | paths_by_epoch[match.group(1)] = path 125 | 126 | if len(paths_by_epoch) == 0: 127 | raise FileNotFoundError('Could not find valid pre-train state file in {}'.format(output_directory)) 128 | 129 | max_epoch = max(paths_by_epoch.keys()) # zero-padded epoch strings compare correctly 130 | return paths_by_epoch[max_epoch] 131 | 132 | 133 | def get_pretrain_params_path(output_directory): 134 | return os.path.join(output_directory, 'pretrain_params.json') 135 | 136 | 137 | def get_ft_output_directory(params, makedirs=True): 138 | path = get_output_directory(params, makedirs=makedirs) 139 | if not params.ut: 140 | path = os.path.join(path, params.target_dataset) 141 | ft_basename = '{:02d}way_{:03d}shot_{}_{}'.format(params.n_way, params.n_shot, params.ft_parts, params.ft_tag) 142 | path = os.path.join(path, ft_basename) 143 | 144 | if makedirs: 145 | os.makedirs(path, exist_ok=True) 146 | 147 | return path 148 | 149 | 150 | def get_ft_params_path(output_directory): 151 | return os.path.join(output_directory, 'params.json') 152 | 153 | 154 | def get_ft_train_history_path(output_directory): 155 | return os.path.join(output_directory, 'train_history.csv') 156 | 157 | 158 | def get_ft_test_history_path(output_directory): 159 | return os.path.join(output_directory, 'test_history.csv') 160 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | import torch.optim 8 | from tqdm import tqdm 9 | 10 | from backbone import get_backbone_class 11 | from datasets.dataloader import get_dataloader, get_unlabeled_dataloader 12 | from io_utils import parse_args 13 | from model import get_model_class 14 | from paths import get_output_directory, get_final_pretrain_state_path, get_pretrain_state_path, \ 15 | get_pretrain_params_path, get_pretrain_history_path 16 | from scheduler import RepeatedMultiStepLR 17 | 18 | 19 | def _get_dataloaders(params): 20 | labeled_source_bs = params.ls_batch_size 21 | batch_size = params.batch_size 22 | unlabeled_source_bs = batch_size 23 | unlabeled_target_bs = batch_size 24 | 25 | if params.us and params.ut: 26 | unlabeled_source_bs //= 2 27 | unlabeled_target_bs //= 2 28 | 29 | ls, us, ut = None, None, None 30 | if params.ls: 31 | print('Using source data {} (labeled)'.format(params.source_dataset)) 32 | ls = get_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation, 33 | batch_size=labeled_source_bs, num_workers=params.num_workers) 34 | 35 | if params.us: 36 | print('Using source data {} (unlabeled)'.format(params.source_dataset)) 37 | us = get_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation, 38 | batch_size=unlabeled_source_bs, num_workers=params.num_workers, 39 | siamese=True) # important 40 | 41 | if params.ut: 42 | print('Using target data {} (unlabeled)'.format(params.target_dataset)) 43 | ut = get_unlabeled_dataloader(dataset_name=params.target_dataset, 
augmentation=params.augmentation, 44 | batch_size=unlabeled_target_bs, num_workers=params.num_workers, siamese=True, 45 | unlabeled_ratio=params.unlabeled_ratio) 46 | 47 | return ls, us, ut 48 | 49 | 50 | def main(params): 51 | backbone = get_backbone_class(params.backbone)() 52 | model = get_model_class(params.model)(backbone, params) 53 | output_dir = get_output_directory(params) 54 | labeled_source_loader, unlabeled_source_loader, unlabeled_target_loader = _get_dataloaders(params) 55 | 56 | params_path = get_pretrain_params_path(output_dir) 57 | with open(params_path, 'w') as f: 58 | json.dump(vars(params), f, indent=4) 59 | pretrain_history_path = get_pretrain_history_path(output_dir) 60 | print('Saving pretrain params to {}'.format(params_path)) 61 | print('Saving pretrain history to {}'.format(pretrain_history_path)) 62 | 63 | if params.pls or params.put or params.pmsl: 64 | # Load previous pre-trained weights for second-step pre-training 65 | previous_base_output_dir = get_output_directory(params, previous=True) 66 | state_path = get_final_pretrain_state_path(previous_base_output_dir) 67 | print('Loading previous state for second-step pre-training:') 68 | print(state_path) 69 | 70 | # Note, override model.load_state_dict to change this behavior. 71 | state = torch.load(state_path) 72 | # del state["classifier.weight"] # hotfixes for using weights pre-trained on different source datasets w/o LS 73 | # del state["classifier.bias"] 74 | missing, unexpected = model.load_state_dict(state, strict=False) 75 | if len(unexpected): 76 | raise Exception("Unexpected keys from previous state: {}".format(unexpected)) 77 | elif params.imagenet_pretrained: 78 | print("Loading ImageNet pretrained weights") 79 | backbone.load_imagenet_weights() 80 | 81 | model.train() 82 | model.cuda() 83 | 84 | if params.optimizer == 'sgd': 85 | optimizer = torch.optim.SGD(model.parameters(), 86 | lr=params.lr, momentum=0.9, 87 | weight_decay=1e-4, 88 | nesterov=False) 89 | elif params.optimizer == 'adam': 90 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr) 91 | else: 92 | raise ValueError('Invalid value for params.optimizer: {}'.format(params.optimizer)) 93 | 94 | if params.scheduler == "MultiStepLR": 95 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params.scheduler_milestones, gamma=0.1) 96 | elif params.scheduler == "RepeatedMultiStepLR": 97 | scheduler = RepeatedMultiStepLR(optimizer, milestones=params.scheduler_milestones, interval=1000, gamma=0.1) 98 | else: 99 | raise ValueError("Invalid value for params.scheduler: {}".format(params.scheduler)) 100 | 101 | pretrain_history = { 102 | 'loss': [0] * params.epochs, 103 | 'source_loss': [0] * params.epochs, 104 | 'target_loss': [0] * params.epochs, 105 | } 106 | 107 | for epoch in range(params.epochs): 108 | print('EPOCH {}'.format(epoch).center(40).center(80, '#')) 109 | 110 | epoch_loss = 0 111 | epoch_source_loss = 0 112 | epoch_target_loss = 0 113 | steps = 0 114 | 115 | if epoch == 0: 116 | state_path = get_pretrain_state_path(output_dir, epoch=0) 117 | print('Saving pre-train state to:') 118 | print(state_path) 119 | torch.save(model.state_dict(), state_path) 120 | 121 | model.on_epoch_start() 122 | model.train() 123 | 124 | if params.ls and not params.us and not params.ut: # only ls (type 1) 125 | for x, y in tqdm(labeled_source_loader): 126 | model.on_step_start() 127 | optimizer.zero_grad() 128 | loss, _ = model.compute_cls_loss_and_accuracy(x.cuda(), y.cuda()) 129 | loss.backward() 130 | optimizer.step() 131 | 
model.on_step_end() 132 | 133 | epoch_loss += loss.item() 134 | epoch_source_loss += loss.item() 135 | steps += 1 136 | elif not params.ls and params.us and not params.ut: # only us (type 2) 137 | for x, _ in tqdm(unlabeled_source_loader): 138 | model.on_step_start() 139 | optimizer.zero_grad() 140 | loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda()) 141 | loss.backward() 142 | optimizer.step() 143 | model.on_step_end() 144 | 145 | epoch_loss += loss.item() 146 | epoch_source_loss += loss.item() 147 | steps += 1 148 | elif params.ut: # ut (epoch is based on unlabeled target) 149 | max_epochs = params.epochs 150 | max_steps = len(unlabeled_target_loader) * max_epochs 151 | for i, (x, _) in enumerate(tqdm(unlabeled_target_loader)): 152 | current_step = epoch * len(unlabeled_target_loader) + i 153 | model.on_step_start() 154 | optimizer.zero_grad() 155 | target_loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda()) # UT loss 156 | epoch_target_loss += target_loss.item() 157 | source_loss = None 158 | if params.ls: # type 4, 7 159 | try: 160 | sx, sy = next(labeled_source_loader_iter) 161 | except (StopIteration, NameError): 162 | labeled_source_loader_iter = iter(labeled_source_loader) 163 | sx, sy = next(labeled_source_loader_iter) 164 | source_loss = model.compute_cls_loss_and_accuracy(sx.cuda(), sy.cuda())[0] # LS loss 165 | epoch_source_loss += source_loss.item() 166 | if params.us: # type 5, 8 167 | try: 168 | sx, sy = next(unlabeled_source_loader_iter) 169 | except (StopIteration, NameError): 170 | unlabeled_source_loader_iter = iter(unlabeled_source_loader) 171 | sx, sy = next(unlabeled_source_loader_iter) 172 | source_loss = model.compute_ssl_loss(sx[0].cuda(), sx[1].cuda()) # US loss 173 | epoch_source_loss += source_loss.item() 174 | if source_loss is not None: # explicit check; truth-testing a loss tensor is ambiguous 175 | if params.gamma_schedule is None: 176 | gamma = params.gamma 177 | elif params.gamma_schedule == "linear": 178 | gamma = current_step / (max_steps - 1) # gamma \in [0, 1] 179 | assert 0 <= gamma <= 1 # temp 180 | else: 181 | raise AssertionError("Invalid params.gamma_schedule (should be checked during argparse)") 182 | loss = source_loss * (1 - gamma) + target_loss * gamma 183 | else: 184 | loss = target_loss 185 | loss.backward() 186 | optimizer.step() 187 | model.on_step_end() 188 | 189 | epoch_loss += loss.item() 190 | steps += 1 191 | else: 192 | raise AssertionError('Unknown training combination.') 193 | 194 | if scheduler is not None: 195 | scheduler.step() 196 | model.on_epoch_end() 197 | 198 | mean_loss = epoch_loss / steps 199 | mean_source_loss = epoch_source_loss / steps 200 | mean_target_loss = epoch_target_loss / steps 201 | fmt = 'Epoch {:04d}: loss={:6.4f} source_loss={:6.4f} target_loss={:6.4f}' 202 | print(fmt.format(epoch, mean_loss, mean_source_loss, mean_target_loss)) 203 | 204 | pretrain_history['loss'][epoch] = mean_loss 205 | pretrain_history['source_loss'][epoch] = mean_source_loss 206 | pretrain_history['target_loss'][epoch] = mean_target_loss 207 | 208 | pd.DataFrame(pretrain_history).to_csv(pretrain_history_path) 209 | 210 | epoch += 1 # count completed epochs for checkpoint naming (cf. get_pretrain_state_path) 211 | if epoch % params.model_save_interval == 0 or epoch == params.epochs: 212 | state_path = get_pretrain_state_path(output_dir, epoch=epoch) 213 | print('Saving pre-train state to:') 214 | print(state_path) 215 | torch.save(model.state_dict(), state_path) 216 | 217 | 218 | if __name__ == '__main__': 219 | np.random.seed(10) 220 | params = parse_args() 221 | 222 | targets = params.target_dataset 223 | if targets is None: 224 | targets = [targets] 225 | elif 
len(targets) > 1: 226 | print('#' * 80) 227 | print("Running pretrain iteratively for multiple target datasets: {}".format(targets)) 228 | print('#' * 80) 229 | 230 | for target in targets: 231 | params.target_dataset = target 232 | main(params) 233 | -------------------------------------------------------------------------------- /pretrain.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | 4 | SOURCES=("miniImageNet" "tieredImageNet" "ImageNet") 5 | SOURCE=${SOURCES[2]} 6 | 7 | TARGETS=("CropDisease" "ISIC" "EuroSAT" "ChestX" "places" "cub" "plantae" "cars") 8 | TARGET=${TARGETS[0]} 9 | 10 | # BACKBONE=resnet10 # for mini 11 | BACKBONE=resnet18 # for tiered and full imagenet 12 | 13 | 14 | # Source SL (note, we adapt the torchvision pre-trained model for ResNet18 + ImageNet. Do not use this command as-is.) 15 | python pretrain.py --ls --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "base" --tag "default" 16 | 17 | # Target SSL 18 | python pretrain.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default" 19 | python pretrain.py --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default" 20 | 21 | # MSL (Source SL + Target SSL) 22 | python pretrain.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78" 23 | python pretrain.py --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78" 24 | 25 | # Two-Stage SSL (Source SL -> Target SSL) 26 | python pretrain.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "default" --previous_tag "default" 27 | python pretrain.py --pls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "default" --previous_tag "default" 28 | 29 | # Two-Stage MSL (Source SL -> Source SL + Target SSL) 30 | python pretrain.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "simclr" --tag "gamma78" --previous_tag "default" 31 | python pretrain.py --pls --ls --ut --source_dataset $SOURCE --target_dataset $TARGET --backbone $BACKBONE --model "byol" --tag "gamma78" --previous_tag "default" -------------------------------------------------------------------------------- /pretrain_new_lu20.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | import torch.optim 8 | from tqdm import tqdm 9 | 10 | from backbone import get_backbone_class 11 | from datasets.dataloader import get_dataloader, get_unlabeled_dataloader 12 | from io_utils import parse_args 13 | from model import get_model_class 14 | from paths import get_output_directory, get_final_pretrain_state_path, get_pretrain_state_path, \ 15 | get_pretrain_params_path, get_pretrain_history_path 16 | 17 | 18 | def _get_dataloaders(params): 19 | batch_size = params.batch_size 20 | labeled_source_bs = batch_size 21 | unlabeled_source_bs = batch_size 22 | unlabeled_target_bs = batch_size 23 | 24 | if params.us and params.ut: 25 | unlabeled_source_bs //= 2 26 | unlabeled_target_bs //= 2 27 | 28 | ls, us, ut = None, None, None 29 | if params.ls: 30 | print('Using source data {} (labeled)'.format(params.source_dataset)) 31 | ls = 
get_unlabeled_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation, 32 | batch_size=labeled_source_bs, siamese=False, unlabeled_ratio=params.unlabeled_ratio, 33 | num_workers=params.num_workers, split_seed=params.split_seed) 34 | 35 | if params.us: 36 | raise NotImplementedError # the code below is unreachable; kept for reference 37 | print('Using source data {} (unlabeled)'.format(params.source_dataset)) 38 | us = get_dataloader(dataset_name=params.source_dataset, augmentation=params.augmentation, 39 | batch_size=unlabeled_source_bs, num_workers=params.num_workers, 40 | siamese=True) # important 41 | 42 | if params.ut: 43 | print('Using target data {} (unlabeled)'.format(params.target_dataset)) 44 | ut = get_unlabeled_dataloader(dataset_name=params.target_dataset, augmentation=params.augmentation, 45 | batch_size=unlabeled_target_bs, num_workers=params.num_workers, siamese=True, 46 | unlabeled_ratio=params.unlabeled_ratio) 47 | 48 | return ls, us, ut 49 | 50 | 51 | def main(params): 52 | backbone = get_backbone_class(params.backbone)() 53 | model = get_model_class(params.model)(backbone, params) 54 | output_dir = get_output_directory(params) 55 | labeled_source_loader, unlabeled_source_loader, unlabeled_target_loader = _get_dataloaders(params) 56 | 57 | params_path = get_pretrain_params_path(output_dir) 58 | with open(params_path, 'w') as f: 59 | json.dump(vars(params), f, indent=4) 60 | pretrain_history_path = get_pretrain_history_path(output_dir) 61 | print('Saving pretrain params to {}'.format(params_path)) 62 | print('Saving pretrain history to {}'.format(pretrain_history_path)) 63 | 64 | if params.pls: 65 | # Load previous pre-trained weights for second-step pre-training 66 | previous_base_output_dir = get_output_directory(params, previous=True) 67 | state_path = get_final_pretrain_state_path(previous_base_output_dir) 68 | print('Loading previous state for second-step pre-training:') 69 | print(state_path) 70 | 71 | # Note, override model.load_state_dict to change this behavior. 
72 | state = torch.load(state_path) 73 | missing, unexpected = model.load_state_dict(state, strict=False) 74 | if len(unexpected): 75 | raise Exception("Unexpected keys from previous state: {}".format(unexpected)) 76 | 77 | model.train() 78 | model.cuda() 79 | 80 | if params.optimizer == 'sgd': 81 | optimizer = torch.optim.SGD(model.parameters(), 82 | lr=params.lr, momentum=0.9, 83 | weight_decay=1e-4, 84 | nesterov=False) 85 | elif params.optimizer == 'adam': 86 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr) 87 | else: 88 | raise ValueError('Invalid value for params.optimizer: {}'.format(params.optimizer)) 89 | 90 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 91 | milestones=[400, 600, 800], 92 | gamma=0.1) 93 | 94 | pretrain_history = { 95 | 'loss': [0] * params.epochs, 96 | 'source_loss': [0] * params.epochs, 97 | 'target_loss': [0] * params.epochs, 98 | } 99 | 100 | for epoch in range(params.epochs): 101 | print('EPOCH {}'.format(epoch).center(40).center(80, '#')) 102 | 103 | epoch_loss = 0 104 | epoch_source_loss = 0 105 | epoch_target_loss = 0 106 | steps = 0 107 | 108 | if epoch == 0: 109 | state_path = get_pretrain_state_path(output_dir, epoch=0) 110 | print('Saving pre-train state to:') 111 | print(state_path) 112 | torch.save(model.state_dict(), state_path) 113 | 114 | model.on_epoch_start() 115 | model.train() 116 | 117 | if params.ls and not params.us and not params.ut: # only ls (type 1) 118 | for x, y in tqdm(labeled_source_loader): 119 | model.on_step_start() 120 | optimizer.zero_grad() 121 | loss, _ = model.compute_cls_loss_and_accuracy(x.cuda(), y.cuda()) 122 | loss.backward() 123 | optimizer.step() 124 | model.on_step_end() 125 | 126 | epoch_loss += loss.item() 127 | epoch_source_loss += loss.item() 128 | steps += 1 129 | elif not params.ls and params.us and not params.ut: # only us (type 2) 130 | for x, _ in tqdm(unlabeled_source_loader): 131 | model.on_step_start() 132 | optimizer.zero_grad() 133 | loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda()) 134 | loss.backward() 135 | optimizer.step() 136 | model.on_step_end() 137 | 138 | epoch_loss += loss.item() 139 | epoch_source_loss += loss.item() 140 | steps += 1 141 | elif params.ut: # ut (epoch is based on unlabeled target) 142 | for x, _ in tqdm(unlabeled_target_loader): 143 | model.on_step_start() 144 | optimizer.zero_grad() 145 | target_loss = model.compute_ssl_loss(x[0].cuda(), x[1].cuda()) # UT loss 146 | epoch_target_loss += target_loss.item() 147 | source_loss = None 148 | if params.ls: # type 4, 7 149 | try: 150 | sx, sy = next(labeled_source_loader_iter) 151 | except (StopIteration, NameError): 152 | labeled_source_loader_iter = iter(labeled_source_loader) 153 | sx, sy = next(labeled_source_loader_iter) 154 | source_loss = model.compute_cls_loss_and_accuracy(sx.cuda(), sy.cuda())[0] # LS loss 155 | epoch_source_loss += source_loss.item() 156 | if params.us: # type 5, 8 157 | try: 158 | sx, sy = next(unlabeled_source_loader_iter) 159 | except (StopIteration, NameError): 160 | unlabeled_source_loader_iter = iter(unlabeled_source_loader) 161 | sx, sy = next(unlabeled_source_loader_iter) 162 | source_loss = model.compute_ssl_loss(sx[0].cuda(), sx[1].cuda()) # US loss 163 | epoch_source_loss += source_loss.item() 164 | 165 | if source_loss is not None: # explicit check; truth-testing a loss tensor is ambiguous 166 | loss = source_loss * (1 - params.gamma) + target_loss * params.gamma 167 | else: 168 | loss = target_loss 169 | loss.backward() 170 | optimizer.step() 171 | model.on_step_end() 172 | 173 | epoch_loss += loss.item() 174 | steps += 1
175 | else: 176 | raise AssertionError('Unknown training combination.') 177 | 178 | if scheduler is not None: 179 | scheduler.step() 180 | model.on_epoch_end() 181 | 182 | mean_loss = epoch_loss / steps 183 | mean_source_loss = epoch_source_loss / steps 184 | mean_target_loss = epoch_target_loss / steps 185 | fmt = 'Epoch {:04d}: loss={:6.4f} source_loss={:6.4f} target_loss={:6.4f}' 186 | print(fmt.format(epoch, mean_loss, mean_source_loss, mean_target_loss)) 187 | 188 | pretrain_history['loss'][epoch] = mean_loss 189 | pretrain_history['source_loss'][epoch] = mean_source_loss 190 | pretrain_history['target_loss'][epoch] = mean_target_loss 191 | 192 | pd.DataFrame(pretrain_history).to_csv(pretrain_history_path) 193 | 194 | epoch += 1 195 | if epoch % params.model_save_interval == 0 or epoch == params.epochs: 196 | state_path = get_pretrain_state_path(output_dir, epoch=epoch) 197 | print('Saving pre-train state to:') 198 | print(state_path) 199 | torch.save(model.state_dict(), state_path) 200 | 201 | 202 | if __name__ == '__main__': 203 | np.random.seed(10) 204 | params = parse_args('pretrain') 205 | 206 | targets = params.target_dataset 207 | if targets is None: 208 | targets = [targets] 209 | elif len(targets) > 1: 210 | print('#' * 80) 211 | print("Running pretrain iteratively for multiple target datasets: {}".format(targets)) 212 | print('#' * 80) 213 | 214 | for target in targets: 215 | params.target_dataset = target 216 | main(params) 217 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | numpy==1.22.0 3 | pandas==1.3.5 4 | Pillow==10.2.0 5 | torch==1.13.1 6 | torchvision==0.14.1 7 | tqdm==4.62.3 8 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Adam 3 | from torch.optim.lr_scheduler import LambdaLR 4 | from torchvision.models import resnet18 5 | 6 | 7 | class RepeatedMultiStepLR(LambdaLR): 8 | def __init__(self, optimizer, milestones=(400, 600, 800), gamma=0.1, interval=1000, **kwargs): 9 | self.milestones = milestones 10 | self.interval = interval 11 | self.gamma = gamma 12 | super().__init__(optimizer, self._lambda, **kwargs) 13 | 14 | def _lambda(self, epoch): 15 | factor = 1 16 | for milestone in self.milestones: 17 | if epoch % self.interval >= milestone: 18 | factor *= self.gamma 19 | return factor 20 | 21 | 22 | def main(): 23 | resnet = resnet18() 24 | 25 | optimizer1 = Adam(resnet.parameters(), lr=0.1) 26 | optimizer2 = Adam(resnet.parameters(), lr=0.1) 27 | 28 | s1 = torch.optim.lr_scheduler.MultiStepLR(optimizer1, milestones=[400, 600, 800], gamma=0.1) 29 | s2 = RepeatedMultiStepLR(optimizer2, milestones=[400, 600, 800]) 30 | s1_history = [] 31 | s2_history = [] 32 | 33 | for i in range(2000): 34 | # print("Epoch {:04d}: {:.6f} / {:.6f}".format(i, s1.get_last_lr()[0], s2.get_last_lr()[0])) 35 | s1_history.append(s1.get_last_lr()[0]) 36 | s2_history.append(s2.get_last_lr()[0]) 37 | s1.step() 38 | s2.step() 39 | 40 | assert (s1_history[:1000] == s2_history[:1000]) 41 | assert (s1_history[:1000] == s2_history[1000:]) 42 | 43 | print("Manual test passed!") 44 | 45 | 46 | if __name__ == "__main__": # manual unit test 47 | main() 48 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='CD-FSL', 5 | version='', 6 | packages=['data', 'model', 'methods', 'datasets'], 7 | url='', 8 | license='', 9 | author='', 10 | author_email='', 11 | description='' 12 | ) 13 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def adjust_learning_rate(optimizer, epoch, lr=0.01, step1=30, step2=60, step3=90): 5 | """Sets the learning rate to the initial LR decayed by 10 every X epochs""" 6 | if epoch >= step3: 7 | lr = lr * 0.001 8 | elif epoch >= step2: 9 | lr = lr * 0.01 10 | elif epoch >= step1: 11 | lr = lr * 0.1 12 | else: 13 | pass # epoch < step1: keep the initial lr 14 | for param_group in optimizer.param_groups: 15 | param_group['lr'] = lr 16 | 17 | class AverageMeter(object): 18 | """Computes and stores the average and current value""" 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=1): 29 | self.val = val 30 | self.sum += val * n 31 | self.count += n 32 | self.avg = self.sum / self.count 33 | 34 | 35 | def one_hot(y, num_class): 36 | return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1) 37 | 38 | def sparsity(cl_data_file): 39 | class_list = cl_data_file.keys() 40 | cl_sparsity = [] 41 | for cl in class_list: 42 | cl_sparsity.append(np.mean([np.sum(x != 0) for x in cl_data_file[cl]])) 43 | 44 | return np.mean(cl_sparsity) --------------------------------------------------------------------------------
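
For reference, the path helpers in paths.py compose as follows when locating a finished pre-training run. This is a minimal sketch, not part of the repository: the Namespace fields are assumed to mirror the flags consumed by get_output_directory (normally populated by io_utils.parse_args), and the example flag values are hypothetical.

from argparse import Namespace

import torch

from paths import get_output_directory, get_final_pretrain_state_path

# Hypothetical flag values; in the scripts these come from io_utils.parse_args.
params = Namespace(source_dataset='miniImageNet', target_dataset='EuroSAT',
                   backbone='resnet10', model='simclr', tag='default',
                   ls=False, us=False, ut=True, pls=False, put=False, pmsl=False)

output_dir = get_output_directory(params)  # <configs.save_dir>/output/mini/resnet10_simclr_UT_default/euro
state_path = get_final_pretrain_state_path(output_dir)  # highest-numbered pretrain_state_*.pt
state = torch.load(state_path)  # pass to model.load_state_dict(state, strict=False), as pretrain.py does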
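
Similarly, the MSL branch of the two pre-training scripts mixes the supervised source loss and the self-supervised target loss with a weight gamma that is either fixed (params.gamma; the "gamma78" tags in pretrain.sh suggest 0.78) or ramped linearly from 0 to 1 over all training steps. Below is a minimal sketch of that arithmetic; the function name and the fixed-gamma default are illustrative, not repository API.

def msl_mixed_loss(source_loss, target_loss, current_step, max_steps,
                   gamma_schedule=None, gamma=0.5):
    """Mix source and target losses as in pretrain.py's UT branch (illustrative helper)."""
    if gamma_schedule == 'linear':
        # ramp from 0 (all source loss) at step 0 to 1 (all target loss) at the last step
        gamma = current_step / (max_steps - 1)
    assert 0 <= gamma <= 1
    return source_loss * (1 - gamma) + target_loss * gamma

# Halfway through training with the linear schedule, the two losses are weighted equally:
# msl_mixed_loss(2.0, 4.0, current_step=500, max_steps=1001, gamma_schedule='linear') == 3.0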