├── .gitignore
├── LICENSE
├── README.md
├── SimCLR
│   ├── SIMCLR.py
│   ├── configs.py
│   ├── data
│   ├── datasets
│   ├── methods
│   ├── models
│   ├── nx_ent.py
│   ├── run.sh
│   └── utils
├── data
│   ├── __init__.py
│   ├── additional_transforms.py
│   ├── datamgr.py
│   ├── dataset.py
│   └── feature_loader.py
├── datasets
│   ├── Chest_few_shot.py
│   ├── CropDisease_few_shot.py
│   ├── DTD_few_shot.py
│   ├── EuroSAT_few_shot.py
│   ├── ISIC_few_shot.py
│   ├── ImageNet_few_shot.py
│   ├── __init__.py
│   ├── additional_transforms.py
│   ├── caltech256_few_shot.py
│   ├── cifar_few_shot.py
│   ├── miniImageNet_few_shot.py
│   ├── split_seed_1
│   │   ├── ChestX_labeled_80.csv
│   │   ├── ChestX_unlabeled_20.csv
│   │   ├── CropDisease_labeled_80.csv
│   │   ├── CropDisease_unlabeled_20.csv
│   │   ├── EuroSAT_labeled_80.csv
│   │   ├── EuroSAT_unlabeled_20.csv
│   │   ├── ISIC_labeled_80.csv
│   │   ├── ISIC_unlabeled_20.csv
│   │   ├── miniImageNet_test_labeled_50.csv
│   │   ├── miniImageNet_test_labeled_75.csv
│   │   ├── miniImageNet_test_labeled_80.csv
│   │   ├── miniImageNet_test_unlabeled_20.csv
│   │   ├── miniImageNet_test_unlabeled_25.csv
│   │   ├── miniImageNet_test_unlabeled_50.csv
│   │   ├── tiered_ImageNet_test_labeled_25.csv
│   │   ├── tiered_ImageNet_test_labeled_50.csv
│   │   ├── tiered_ImageNet_test_labeled_75.csv
│   │   ├── tiered_ImageNet_test_labeled_80.csv
│   │   ├── tiered_ImageNet_test_labeled_90.csv
│   │   ├── tiered_ImageNet_test_unlabeled_10.csv
│   │   ├── tiered_ImageNet_test_unlabeled_20.csv
│   │   ├── tiered_ImageNet_test_unlabeled_25.csv
│   │   ├── tiered_ImageNet_test_unlabeled_50.csv
│   │   └── tiered_ImageNet_test_unlabeled_75.csv
│   └── tiered_ImageNet_few_shot.py
├── evaluation
│   ├── compile_result.py
│   ├── configs.py
│   ├── data
│   ├── datasets
│   ├── finetune.py
│   ├── methods
│   ├── models
│   ├── run.sh
│   └── utils
├── methods
│   ├── __init__.py
│   ├── baselinefinetune.py
│   ├── baselinetrain.py
│   ├── meta_template.py
│   ├── models
│   └── protonet.py
├── models
│   ├── __init__.py
│   ├── dataparallel_wrapper.py
│   ├── resnet.py
│   ├── resnet10.py
│   └── resnet12.py
├── student_STARTUP
│   ├── STARTUP.py
│   ├── configs.py
│   ├── data
│   ├── datasets
│   ├── methods
│   ├── models
│   ├── nx_ent.py
│   ├── run.sh
│   └── utils
├── student_STARTUP_no_self_supervision
│   ├── STARTUP_no_SS.py
│   ├── configs.py
│   ├── data
│   ├── datasets
│   ├── methods
│   ├── models
│   ├── run.sh
│   └── utils
├── teacher_ImageNet
│   ├── convert_imagenet_weight.py
│   ├── models
│   └── run.sh
├── teacher_miniImageNet
│   ├── configs.py
│   ├── data
│   ├── datasets
│   ├── io_utils.py
│   ├── methods
│   ├── models
│   ├── run.sh
│   ├── train.py
│   └── utils
└── utils
    ├── AverageMeterSet.py
    ├── __init__.py
    ├── accuracy.py
    ├── average_model.py
    ├── cdfsl_utils.py
    ├── count_paramters.py
    ├── create_logger.py
    └── savelog.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# user-defined files
*.log
*.tar
*.pkl
*.csv
wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Cheng Perng Phoo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Self-training for Few-shot Transfer Across Extreme Task Differences (STARTUP)

## Introduction
This repo contains the official implementation of the following ICLR 2021 paper:

**Title:** Self-training for Few-shot Transfer Across Extreme Task Differences
**Authors:** Cheng Perng Phoo, Bharath Hariharan
**Institution:** Cornell University
**Arxiv:** https://arxiv.org/abs/2010.07734
**Abstract:**
Most few-shot learning techniques are pre-trained on a large, labeled "base dataset". In problem domains where such large labeled datasets are not available for pre-training (e.g., X-ray, satellite images), one must resort to pre-training in a different "source" problem domain (e.g., ImageNet), which can be very different from the desired target task. Traditional few-shot and transfer learning techniques fail in the presence of such extreme differences between the source and target tasks. In this paper, we present a simple and effective solution to tackle this extreme domain gap: self-training a source domain representation on unlabeled data from the target domain. We show that this improves one-shot performance on the target domain by 2.9 points on average on the challenging BSCD-FSL benchmark consisting of datasets from multiple domains.

### Requirements
This codebase is tested with:
1. PyTorch 1.7.1
2. Torchvision 0.8.2
3. NumPy
4. Pandas
5. wandb (used for logging. More here: https://wandb.ai/)


## Running Experiments
### Step 0: Dataset Preparation
**MiniImageNet and CD-FSL:** Download the datasets for the CD-FSL benchmark following steps 1 and 2 here: https://github.com/IBM/cdfsl-benchmark
**tieredImageNet:** Prepare the tieredImageNet dataset following https://github.com/mileyan/simple_shot. Note that after running the preparation script, you will need to split the saved images into three different folders: train, val, test.

### Step 1: Teacher Training on the Base Dataset
We provide scripts to produce teachers for different base datasets. Regardless of the base dataset, please follow these steps to produce a teacher:
1. Go into the directory `teacher_miniImageNet/` (`teacher_ImageNet/` for ImageNet).
2. Take care of the `TODO:` in `run.sh` and `configs.py` (if applicable).
3. Run `bash run.sh` to produce the teachers.

Note that for miniImageNet and tieredImageNet, the training script is adapted from the official script provided by the CD-FSL benchmark. For ImageNet, we simply download the pre-trained models from PyTorch and convert them to the relevant format.

### Step 2: Student Training
To train STARTUP's representation, please follow these steps:
1. Go into the directory `student_STARTUP/` (`student_STARTUP_no_self_supervision/` for the version without SimCLR).
2. Take care of the `TODO:` in `run.sh` and `configs.py`.
3. Run `bash run.sh` to produce the student/STARTUP representation.

### Step 3: Evaluation
To evaluate the different representations, go into `evaluation/`, modify the `TODO:` in `run.sh` and `configs.py`, and run `bash run.sh`.


## Notes
1. When producing the results for the submitted paper, we did not set `torch.backends.cudnn.deterministic` and `torch.backends.cudnn.benchmark` properly, which caused non-deterministic behavior. We have rerun our experiments and the updated numbers can be found here: https://docs.google.com/spreadsheets/d/1O1e9xdI1SxVvRWK9VVxcO8yefZhePAHGikypWfhRv8c/edit?usp=sharing. Although some of the numbers have changed, the conclusion of the paper remains unchanged: STARTUP is able to outperform all the baselines, bringing forth tremendous improvements to cross-domain few-shot learning.
2. All training is done on an Nvidia Titan RTX GPU. Evaluation of the different representations is performed on an Nvidia RTX 2080Ti. Regardless of the GPU model, CUDA 11 is used.
3. This repo is built upon the official CD-FSL benchmark repo: https://github.com/IBM/cdfsl-benchmark/tree/9c6a42f4bb3d2638bb85d3e9df3d46e78107bc53. We thank the creators of the CD-FSL benchmark for releasing their code to the public.
4. You can download the model checkpoints here: https://drive.google.com/file/d/1UxOkQB4X29UvnyAgESv2Pyye3mecTmQp/view?usp=sharing
5. If you find this codebase or STARTUP useful, please consider citing our paper:
```
@inproceedings{phoo2021STARTUP,
    title={Self-training for Few-shot Transfer Across Extreme Task Differences},
    author={Phoo, Cheng Perng and Hariharan, Bharath},
    booktitle={Proceedings of the International Conference on Learning Representations},
    year={2021}
}
```

--------------------------------------------------------------------------------
/SimCLR/configs.py:
--------------------------------------------------------------------------------
# TODO: Please set the directory to the target datasets accordingly
miniImageNet_path = '/scratch/datasets/CD-FSL/miniImageNet_test'
tiered_ImageNet_path = '/scratch/datasets/tiered_imagenet/tiered_imagenet/original_split/test'

ISIC_path = "/scratch/datasets/CD-FSL/ISIC"
ChestX_path = "/scratch/datasets/CD-FSL/chestX"
CropDisease_path = "/scratch/datasets/CD-FSL/CropDiseases"
EuroSAT_path = "/scratch/datasets/CD-FSL/EuroSAT/2750"
--------------------------------------------------------------------------------
/SimCLR/data:
--------------------------------------------------------------------------------
../data
--------------------------------------------------------------------------------
/SimCLR/datasets:
--------------------------------------------------------------------------------
../datasets
--------------------------------------------------------------------------------
/SimCLR/methods:
--------------------------------------------------------------------------------
../methods
--------------------------------------------------------------------------------
/SimCLR/models:
--------------------------------------------------------------------------------
../models
--------------------------------------------------------------------------------
/SimCLR/nx_ent.py:
--------------------------------------------------------------------------------
# ported from https://github.com/sthalles/SimCLR/blob/master/loss/nt_xent.py

import torch
import numpy as np


class NTXentLoss(torch.nn.Module):

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(
            use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 *
                    self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 *
                    self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(
            representations, representations)

        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(
            2 * self.batch_size, -1)

        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        loss = self.criterion(logits, labels)

        return loss / (2 * self.batch_size)
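
The loss above assumes `zis` and `zjs` are two aligned batches of projections (two augmented views of the same images); the positives sit on the `±batch_size` diagonals of the `2N × 2N` similarity matrix. A minimal smoke test is sketched below; the batch size and embedding dimension are arbitrary choices for illustration, not values used by the repo:

```
# Hypothetical usage sketch for NTXentLoss (not part of the repo); run from SimCLR/.
import torch
from nx_ent import NTXentLoss

device = torch.device('cpu')
batch_size, feat_dim = 8, 128

criterion = NTXentLoss(device, batch_size, temperature=0.5, use_cosine_similarity=True)

# zis / zjs: projections of two augmented views of the same 8 images
zis = torch.randn(batch_size, feat_dim)
zjs = torch.randn(batch_size, feat_dim)

loss = criterion(zis, zjs)  # scalar: NT-Xent loss averaged over the 2N views
print(loss.item())
```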
--------------------------------------------------------------------------------
/SimCLR/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# bash script to train SimCLR representation
export CUDA_VISIBLE_DEVICES=1


# Before running the commands, please take care of the TODO appropriately
for target_testset in "ChestX" "ISIC" "EuroSAT" "CropDisease"
do
    # TODO: Please set the following arguments appropriately
    #   --dir: directory to save the student representation
    #   --model: backbone type (supports resnet10, resnet12 and resnet18)
    #   --teacher_path: initialization for the representation. Remove this flag
    #                   if you want to start training from scratch
    # E.g. the following command trains a SimCLR representation (initialized using the weights specified at
    # ../teacher_miniImageNet/logs_deterministic/checkpoints/miniImageNet/ResNet10_baseline_256_aug/399.tar).
    # The student representation is saved at SimCLR_miniImageNet/$target_testset\_unlabeled_20/checkpoint_best.pkl
    python SIMCLR.py \
    --dir SimCLR_miniImageNet/$target_testset\_unlabeled_20 \
    --target_dataset $target_testset \
    --image_size 224 \
    --target_subset_split datasets/split_seed_1/$target_testset\_unlabeled_20.csv \
    --bsize 256 \
    --epochs 1000 \
    --save_freq 50 \
    --print_freq 10 \
    --seed 1 \
    --wd 1e-4 \
    --num_workers 4 \
    --model resnet10 \
    --teacher_path ../teacher_miniImageNet/logs_deterministic/checkpoints/miniImageNet/ResNet10_baseline_256_aug/399.tar \
    --teacher_path_version 0 \
    --eval_freq 2 \
    --batch_validate \
    --resume_latest
done
--------------------------------------------------------------------------------
/SimCLR/utils:
--------------------------------------------------------------------------------
../utils
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from . import datamgr
from . import dataset
from . import additional_transforms
from . import feature_loader
--------------------------------------------------------------------------------
/data/additional_transforms.py:
--------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
from PIL import ImageEnhance

transformtypedict = dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast,
                         Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color)


class ImageJitter(object):
    def __init__(self, transformdict):
        self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict]

    def __call__(self, img):
        out = img
        randtensor = torch.rand(len(self.transforms))

        for i, (transformer, alpha) in enumerate(self.transforms):
            r = alpha * (randtensor[i] * 2.0 - 1.0) + 1
            out = transformer(out).enhance(r).convert('RGB')

        return out
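
`ImageJitter` draws one random enhancement factor per property, uniform in `[1 - alpha, 1 + alpha]`, and applies the enhancements in sequence. A minimal sketch of how it is used (the image path is a placeholder, not a file in the repo):

```
# Hypothetical usage sketch for ImageJitter (not part of the repo); run from the repo root.
from PIL import Image
from data.additional_transforms import ImageJitter

jitter = ImageJitter(dict(Brightness=0.4, Contrast=0.4, Color=0.4))
img = Image.open('some_image.jpg').convert('RGB')  # placeholder path
jittered = jitter(img)  # PIL image with randomly perturbed brightness/contrast/color
```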
--------------------------------------------------------------------------------
/data/datamgr.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import data.additional_transforms as add_transforms
from data.dataset import SimpleDataset, SetDataset, EpisodicBatchSampler
from abc import abstractmethod

class TransformLoader:
    def __init__(self, image_size,
                 normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                 jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param

    def parse_transform(self, transform_type):
        if transform_type == 'ImageJitter':
            method = add_transforms.ImageJitter(self.jitter_param)
            return method
        method = getattr(transforms, transform_type)
        if transform_type == 'RandomSizedCrop':
            return method(self.image_size)
        elif transform_type == 'CenterCrop':
            return method(self.image_size)
        elif transform_type == 'Scale':
            return method([int(self.image_size*1.15), int(self.image_size*1.15)])
        elif transform_type == 'Normalize':
            return method(**self.normalize_param)
        else:
            return method()

    def get_composed_transform(self, aug=False):
        # NOTE: 'RandomSizedCrop' and 'Scale' are the deprecated torchvision names for
        # 'RandomResizedCrop' and 'Resize'; they still exist in torchvision 0.8.2, which
        # this codebase targets (the datasets/*_few_shot.py files accept both names).
        if aug:
            transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            transform_list = ['Scale', 'CenterCrop', 'ToTensor', 'Normalize']

        transform_funcs = [self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform

class DataManager(object):
    @abstractmethod
    def get_data_loader(self, data_file, aug):
        pass

class SimpleDataManager(DataManager):
    def __init__(self, image_size, batch_size):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size
        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, data_file, aug):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SimpleDataset(data_file, transform)
        data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=12, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)

        return data_loader

class SetDataManager(DataManager):
    def __init__(self, image_size, n_way, n_support, n_query, n_eposide=100):
        super(SetDataManager, self).__init__()
        self.image_size = image_size
        self.n_way = n_way
        self.batch_size = n_support + n_query
        self.n_eposide = n_eposide

        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, data_file, aug):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SetDataset(data_file, self.batch_size, transform)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
        data_loader_params = dict(batch_sampler=sampler, num_workers=12, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
        return data_loader
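
Both managers consume a JSON "data file" whose format can be read off `data/dataset.py` below: a dict with parallel `image_names` and `image_labels` lists. A sketch of building such a file and wiring it through `SimpleDataManager` (file names and image paths are placeholders):

```
# Hypothetical usage sketch (not part of the repo); paths are placeholders.
import json
from data.datamgr import SimpleDataManager

meta = {
    'image_names': ['/path/to/img_0.jpg', '/path/to/img_1.jpg'],
    'image_labels': [0, 1],
}
with open('base.json', 'w') as f:
    json.dump(meta, f)

datamgr = SimpleDataManager(image_size=224, batch_size=16)
loader = datamgr.get_data_loader('base.json', aug=True)  # yields (images, labels) batches
```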
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import json
import numpy as np
import torchvision.transforms as transforms
import os

identity = lambda x: x

class SimpleDataset:
    def __init__(self, data_file, transform, target_transform=identity):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, i):
        image_path = os.path.join(self.meta['image_names'][i])
        img = Image.open(image_path).convert('RGB')
        img = self.transform(img)
        target = self.target_transform(self.meta['image_labels'][i])
        return img, target

    def __len__(self):
        return len(self.meta['image_names'])

class SetDataset:
    def __init__(self, data_file, batch_size, transform):

        with open(data_file, 'r') as f:
            self.meta = json.load(f)

        self.cl_list = np.unique(self.meta['image_labels']).tolist()

        self.sub_meta = {}
        for cl in self.cl_list:
            self.sub_meta[cl] = []

        for x, y in zip(self.meta['image_names'], self.meta['image_labels']):
            self.sub_meta[y].append(x)

        self.sub_dataloader = []
        sub_data_loader_params = dict(batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,  # use main thread only or may receive multiple batches
                                      pin_memory=False)
        for cl in self.cl_list:
            sub_dataset = SubDataset(self.sub_meta[cl], cl, transform=transform)
            self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))

    def __getitem__(self, i):
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.cl_list)


class SubDataset:
    def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity):
        self.sub_meta = sub_meta
        self.cl = cl
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, i):
        # print('%d -%d' % (self.cl, i))
        image_path = os.path.join(self.sub_meta[i])
        img = Image.open(image_path).convert('RGB')
        img = self.transform(img)
        target = self.target_transform(self.cl)
        return img, target

    def __len__(self):
        return len(self.sub_meta)

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]
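
Putting `SetDataset` and `EpisodicBatchSampler` together: each sampler draw picks `n_way` random class indices, and `SetDataset[cl]` returns one batch of `n_support + n_query` images of class `cl`, so an episode collates to a tensor of shape `[n_way, n_support + n_query, 3, H, W]`. A sketch, reusing the hypothetical `base.json` from the previous example (it assumes real image paths and enough images per class):

```
# Hypothetical usage sketch (not part of the repo); assumes base.json from above.
from data.datamgr import SetDataManager

datamgr = SetDataManager(image_size=224, n_way=5, n_support=5, n_query=15, n_eposide=100)
loader = datamgr.get_data_loader('base.json', aug=False)

for x, y in loader:
    # x: [5, 20, 3, 224, 224] -- 5 classes, 5 support + 15 query images each
    # y: [5, 20] -- the (global) class label of every image in the episode
    break
```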
--------------------------------------------------------------------------------
/data/feature_loader.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import h5py
import os

class SimpleHDF5Dataset:
    def __init__(self, file_handle=None):
        if file_handle is None:
            self.f = ''
            self.all_feats_dset = []
            self.all_labels = []
            self.total = 0
        else:
            self.f = file_handle
            self.all_feats_dset = self.f['all_feats'][...]
            self.all_labels = self.f['all_labels'][...]
            self.total = self.f['count'][0]

    def __getitem__(self, i):
        return torch.Tensor(self.all_feats_dset[i, :]), int(self.all_labels[i])

    def __len__(self):
        return self.total

def init_loader(filename):
    if os.path.isfile(filename):
        print('file %s found' % filename)
    else:
        print('file %s not found' % filename)

    with h5py.File(filename, 'r') as f:
        fileset = SimpleHDF5Dataset(f)

    # labels = [l for l in fileset.all_labels if l != 0]
    feats = fileset.all_feats_dset
    labels = fileset.all_labels

    print(feats.shape)
    print(feats[-1])
    # the feature array is pre-allocated; trim trailing all-zero rows that were never filled
    while np.sum(feats[-1]) == 0:
        feats = np.delete(feats, -1, axis=0)
        labels = np.delete(labels, -1, axis=0)

    class_list = np.unique(np.array(labels)).tolist()
    inds = range(len(labels))

    cl_data_file = {}
    for cl in class_list:
        cl_data_file[cl] = []

    for ind in inds:
        cl_data_file[labels[ind]].append(feats[ind])

    return cl_data_file
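
`init_loader` expects an HDF5 file with `all_feats`, `all_labels` and `count` datasets — presumably the layout in which features are dumped elsewhere in the codebase (that code is not shown here). A sketch of writing a compatible file with made-up sizes:

```
# Hypothetical sketch of the expected HDF5 layout (not part of the repo).
import h5py
import numpy as np

n, feat_dim = 100, 512
with h5py.File('features.hdf5', 'w') as f:
    f.create_dataset('all_feats', data=np.random.randn(n, feat_dim))
    f.create_dataset('all_labels', data=np.random.randint(0, 5, size=n))
    f.create_dataset('count', data=np.array([n]))

from data.feature_loader import init_loader
cl_data_file = init_loader('features.hdf5')  # dict: class label -> list of feature vectors
```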
--------------------------------------------------------------------------------
/datasets/Chest_few_shot.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
import datasets.additional_transforms as add_transforms
from torch.utils.data import Dataset, DataLoader
from abc import abstractmethod

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import sys
sys.path.append("../")
import configs


def identity(x): return x

class CustomDatasetFromImages(Dataset):
    def __init__(self, transform, target_transform=identity,
                 csv_path=configs.ChestX_path + "/Data_Entry_2017.csv",
                 image_path=configs.ChestX_path + "/images/", split=None):
        """
        Args:
            csv_path (string): path to csv file
            img_path (string): path to the folder where images are
            transform: pytorch transforms for transforms and tensor conversion
            target_transform: pytorch transforms for targets
            split: the filename of a csv containing a split for the data to be used.
                If None, then the full dataset is used. (Default: None)
        """
        self.img_path = image_path
        self.csv_path = csv_path
        # NOTE: "Pneumonia" appears in this list but is explicitly filtered out below
        # (and has no entry in labels_maps), so only the other seven findings are used.
        self.used_labels = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
                            "Mass", "Nodule", "Pneumonia", "Pneumothorax"]

        self.labels_maps = {"Atelectasis": 0, "Cardiomegaly": 1, "Effusion": 2,
                            "Infiltration": 3, "Mass": 4, "Nodule": 5, "Pneumothorax": 6}

        labels_set = []

        # Transforms
        self.transform = transform
        self.target_transform = target_transform
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, skiprows=[0], header=None)

        # First column contains the image paths
        self.image_name_all = np.asarray(self.data_info.iloc[:, 0])
        self.labels_all = np.asarray(self.data_info.iloc[:, 1])

        self.image_name = []
        self.labels = []

        self.split = split

        # keep only images with exactly one finding, excluding "No Finding" and "Pneumonia"
        for name, label in zip(self.image_name_all, self.labels_all):
            label = label.split("|")

            if len(label) == 1 and label[0] != "No Finding" and label[0] != "Pneumonia" and label[0] in self.used_labels:
                self.labels.append(self.labels_maps[label[0]])
                self.image_name.append(name)

        self.data_len = len(self.image_name)
        self.image_name = np.asarray(self.image_name)
        self.labels = np.asarray(self.labels)

        if split is not None:
            print("Using Split: ", split)
            split = pd.read_csv(split)['img_path'].values
            # construct the index
            ind = np.concatenate(
                [np.where(self.image_name == j)[0] for j in split])
            self.image_name = self.image_name[ind]
            self.labels = self.labels[ind]
            self.data_len = len(split)

            assert len(self.image_name) == len(split)
            assert len(self.labels) == len(split)
        # self.targets = self.labels

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_name[index]

        # Open image
        img_as_img = Image.open(self.img_path + single_image_name).resize((256, 256)).convert('RGB')
        img_as_img.load()

        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.labels[index]

        return self.transform(img_as_img), self.target_transform(single_image_label)

    def __len__(self):
        return self.data_len
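
The filtering rule above keeps an image only if its label string contains exactly one finding and that finding is one of the seven mapped classes. For example, on made-up label strings in the `Data_Entry_2017.csv` format:

```
# Illustration of the effective single-label filter (not part of the repo).
used = {"Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass", "Nodule", "Pneumothorax"}
for raw in ["Effusion", "Effusion|Mass", "No Finding", "Pneumonia"]:
    label = raw.split("|")
    keep = len(label) == 1 and label[0] in used
    print(raw, "->", "kept" if keep else "dropped")
# Effusion -> kept; the rest are dropped (multi-label, background, or excluded class)
```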
class SimpleDataset:
    def __init__(self, transform, target_transform=identity, split=None):
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.d = CustomDatasetFromImages(transform=self.transform, target_transform=self.target_transform, split=split)

    def __getitem__(self, i):
        img, target = self.d[i]
        return img, target

    def __len__(self):
        return len(self.d)


class SetDataset:
    def __init__(self, batch_size, transform, split=None):
        self.transform = transform
        self.split = split
        self.d = CustomDatasetFromImages(transform=self.transform, split=split)

        self.cl_list = sorted(np.unique(self.d.labels).tolist())

        self.sub_dataloader = []
        sub_data_loader_params = dict(batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=False)
        for cl in self.cl_list:
            ind = np.where(np.array(self.d.labels) == cl)[0].tolist()
            sub_dataset = torch.utils.data.Subset(self.d, ind)
            self.sub_dataloader.append(torch.utils.data.DataLoader(
                sub_dataset, **sub_data_loader_params))

    def __getitem__(self, i):
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.sub_dataloader)

# class SubDataset:
#     def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity):
#         self.sub_meta = sub_meta
#         self.cl = cl
#         self.transform = transform
#         self.target_transform = target_transform

#     def __getitem__(self, i):
#         img = self.transform(self.sub_meta[i])
#         target = self.target_transform(self.cl)
#         return img, target

#     def __len__(self):
#         return len(self.sub_meta)

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]

class TransformLoader:
    def __init__(self, image_size,
                 normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                 jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):

        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param

    def parse_transform(self, transform_type):
        if transform_type == 'ImageJitter':
            method = add_transforms.ImageJitter(self.jitter_param)
            return method
        method = getattr(transforms, transform_type)
        if transform_type == 'RandomSizedCrop' or transform_type == 'RandomResizedCrop':
            return method(self.image_size)
        elif transform_type == 'CenterCrop':
            return method(self.image_size)
        elif transform_type == 'Scale' or transform_type == 'Resize':
            return method([int(self.image_size*1.15), int(self.image_size*1.15)])
        elif transform_type == 'Normalize':
            return method(**self.normalize_param)
        else:
            return method()

    def get_composed_transform(self, aug=False):
        if aug:
            transform_list = ['RandomResizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            transform_list = ['Resize', 'CenterCrop', 'ToTensor', 'Normalize']

        transform_funcs = [self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform

class DataManager(object):
    @abstractmethod
    def get_data_loader(self, data_file, aug):
        pass

class SimpleDataManager(DataManager):
    def __init__(self, image_size, batch_size, split=None):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size
        self.trans_loader = TransformLoader(image_size)
        self.split = split

    def get_data_loader(self, aug, num_workers=12):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SimpleDataset(transform, split=self.split)

        data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)

        return data_loader

class SetDataManager(DataManager):
    def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide=100, split=None):
        super(SetDataManager, self).__init__()
        self.image_size = image_size
        self.n_way = n_way
        self.batch_size = n_support + n_query
        self.n_eposide = n_eposide
        self.split = split
        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, aug, num_workers=12):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SetDataset(self.batch_size, transform, self.split)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
        data_loader_params = dict(batch_sampler=sampler, num_workers=num_workers, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
        return data_loader

if __name__ == '__main__':

    base_datamgr = SetDataManager(224, n_query=16, n_support=5)
    base_loader = base_datamgr.get_data_loader(aug=True)
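
Each of the `datasets/*_few_shot.py` modules repeats this same structure (a raw dataset class plus `SimpleDataset`/`SetDataset` wrappers, a `TransformLoader`, and the two data managers), so one usage sketch covers them all. It assumes it is run from a directory that has a `configs.py` with `ChestX_path` set and a `datasets` symlink (e.g. `evaluation/`), and that the ChestX data has been prepared:

```
# Hypothetical usage sketch (not part of the repo); requires the ChestX data
# and the bundled split csv.
from datasets import Chest_few_shot

datamgr = Chest_few_shot.SetDataManager(
    224, n_way=5, n_support=5, n_query=15, n_eposide=10,
    split='datasets/split_seed_1/ChestX_unlabeled_20.csv')
loader = datamgr.get_data_loader(aug=False, num_workers=2)

x, y = next(iter(loader))
# x: [5, 20, 3, 224, 224] -- a 5-way episode, 5 support + 15 query images per class
```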
--------------------------------------------------------------------------------
/datasets/CropDisease_few_shot.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
import datasets.additional_transforms as add_transforms
from torch.utils.data import Dataset, DataLoader
from abc import abstractmethod
from torchvision.datasets import ImageFolder

import os

import copy

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import sys
sys.path.append("../")
import configs


def construct_subset(dataset, split):
    """Return a copy of an ImageFolder restricted to the image paths listed in a split csv."""
    split = pd.read_csv(split)['img_path'].values
    root = dataset.root

    class_to_idx = dataset.class_to_idx

    # create targets
    targets = [class_to_idx[os.path.dirname(i)] for i in split]

    # image_names = np.array([i[0] for i in dataset.imgs])

    # # ind
    # ind = np.concatenate(
    #     [np.where(image_names == os.path.join(root, j))[0] for j in split])

    image_names = [os.path.join(root, j) for j in split]
    dataset_subset = copy.deepcopy(dataset)

    dataset_subset.samples = [j for j in zip(image_names, targets)]
    dataset_subset.imgs = dataset_subset.samples
    dataset_subset.targets = targets
    return dataset_subset
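
`construct_subset` infers the class from the directory component of each `img_path`, so the split csv is expected to hold paths relative to the ImageFolder root, one per row under an `img_path` column. A made-up illustration of that layout (the rows are not from a real split file):

```
# Illustration of the expected split csv layout (not part of the repo).
import pandas as pd

pd.DataFrame({'img_path': [
    'Apple___Apple_scab/image_001.JPG',
    'Apple___Apple_scab/image_002.JPG',
    'Tomato___Early_blight/image_017.JPG',
]}).to_csv('toy_split.csv', index=False)
# construct_subset(ImageFolder(root), 'toy_split.csv') would keep exactly these three images
```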
identity = lambda x: x

class SimpleDataset:
    def __init__(self, transform, target_transform=identity, split=None):
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.d = ImageFolder(configs.CropDisease_path + "/dataset/train/",
                             transform=self.transform,
                             target_transform=self.target_transform)

        if split is not None:
            print("Using Split: ", split)
            self.d = construct_subset(self.d, split)

    def __getitem__(self, i):
        return self.d[i]

    def __len__(self):
        return len(self.d)


class SetDataset:
    def __init__(self, batch_size, transform, split=None):
        self.d = ImageFolder(configs.CropDisease_path + "/dataset/train/", transform=transform)
        self.split = split

        if split is not None:
            print("Using Split: ", split)
            self.d = construct_subset(self.d, split)

        self.cl_list = range(len(self.d.classes))

        self.sub_dataloader = []
        sub_data_loader_params = dict(batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=False)
        for cl in self.cl_list:
            ind = np.where(np.array(self.d.targets) == cl)[0].tolist()
            sub_dataset = torch.utils.data.Subset(self.d, ind)
            self.sub_dataloader.append(torch.utils.data.DataLoader(
                sub_dataset, **sub_data_loader_params))

    def __getitem__(self, i):
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.sub_dataloader)


class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]

class TransformLoader:
    def __init__(self, image_size,
                 normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                 jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param

    def parse_transform(self, transform_type):
        if transform_type == 'ImageJitter':
            method = add_transforms.ImageJitter(self.jitter_param)
            return method
        method = getattr(transforms, transform_type)
        if transform_type == 'RandomSizedCrop' or transform_type == 'RandomResizedCrop':
            return method(self.image_size)
        elif transform_type == 'CenterCrop':
            return method(self.image_size)
        elif transform_type == 'Scale' or transform_type == 'Resize':
            return method([int(self.image_size*1.15), int(self.image_size*1.15)])
        elif transform_type == 'Normalize':
            return method(**self.normalize_param)
        else:
            return method()

    def get_composed_transform(self, aug=False):
        if aug:
            transform_list = ['RandomResizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            transform_list = ['Resize', 'CenterCrop', 'ToTensor', 'Normalize']

        transform_funcs = [self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform

class DataManager(object):
    @abstractmethod
    def get_data_loader(self, data_file, aug):
        pass

class SimpleDataManager(DataManager):
    def __init__(self, image_size, batch_size, split=None):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size
        self.trans_loader = TransformLoader(image_size)
        self.split = split

    def get_data_loader(self, aug, num_workers=12):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SimpleDataset(transform, split=self.split)

        data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)

        return data_loader

class SetDataManager(DataManager):
    def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide=100, split=None):
        super(SetDataManager, self).__init__()
        self.image_size = image_size
        self.n_way = n_way
        self.batch_size = n_support + n_query
        self.n_eposide = n_eposide

        self.trans_loader = TransformLoader(image_size)
        self.split = split

    def get_data_loader(self, aug, num_workers=12):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SetDataset(self.batch_size, transform, split=self.split)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
        data_loader_params = dict(batch_sampler=sampler, num_workers=num_workers, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
        return data_loader

if __name__ == '__main__':

    train_few_shot_params = dict(n_way=5, n_support=5)
    base_datamgr = SetDataManager(224, n_query=16)
    base_loader = base_datamgr.get_data_loader(aug=True)

    cnt = 1
    for i, (x, label) in enumerate(base_loader):
        if i < cnt:
            print(label)
        else:
            break
--------------------------------------------------------------------------------
/datasets/DTD_few_shot.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
import datasets.additional_transforms as add_transforms
from torch.utils.data import Dataset, DataLoader
from abc import abstractmethod
from torchvision.datasets import ImageFolder

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import sys
sys.path.append("../")
from configs import *   # expects DTD_path to be defined in configs.py

identity = lambda x: x

class SimpleDataset:
    def __init__(self, transform, target_transform=identity):
        self.transform = transform
        self.target_transform = target_transform

        self.meta = {}

        self.meta['image_names'] = []
        self.meta['image_labels'] = []

        d = ImageFolder(DTD_path)

        # NOTE: iterating over the ImageFolder here loads every image into memory up front
        for i, (data, label) in enumerate(d):
            self.meta['image_names'].append(data)
            self.meta['image_labels'].append(label)

    def __getitem__(self, i):

        img = self.transform(self.meta['image_names'][i])
        target = self.target_transform(self.meta['image_labels'][i])

        return img, target

    def __len__(self):
        return len(self.meta['image_names'])


class SetDataset:
    def __init__(self, batch_size, transform):

        self.sub_meta = {}
        self.cl_list = range(47)   # DTD has 47 texture classes

        for cl in self.cl_list:
            self.sub_meta[cl] = []

        d = ImageFolder(DTD_path)

        for i, (data, label) in enumerate(d):
            self.sub_meta[label].append(data)

        self.sub_dataloader = []
        sub_data_loader_params = dict(batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,  # use main thread only or may receive multiple batches
                                      pin_memory=False)
        for cl in self.cl_list:
            sub_dataset = SubDataset(self.sub_meta[cl], cl, transform=transform)
            self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))

    def __getitem__(self, i):
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.sub_dataloader)
class SubDataset:
    def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity):
        self.sub_meta = sub_meta
        self.cl = cl
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, i):

        img = self.transform(self.sub_meta[i])
        target = self.target_transform(self.cl)
        return img, target

    def __len__(self):
        return len(self.sub_meta)

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]

class TransformLoader:
    def __init__(self, image_size,
                 normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                 jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param

    def parse_transform(self, transform_type):
        if transform_type == 'ImageJitter':
            method = add_transforms.ImageJitter(self.jitter_param)
            return method
        method = getattr(transforms, transform_type)
        if transform_type == 'RandomSizedCrop':
            return method(self.image_size)
        elif transform_type == 'CenterCrop':
            return method(self.image_size)
        elif transform_type == 'Scale':
            return method([int(self.image_size*1.15), int(self.image_size*1.15)])
        elif transform_type == 'Normalize':
            return method(**self.normalize_param)
        else:
            return method()

    def get_composed_transform(self, aug=False):
        if aug:
            transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            transform_list = ['Scale', 'CenterCrop', 'ToTensor', 'Normalize']

        transform_funcs = [self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform

class DataManager(object):
    @abstractmethod
    def get_data_loader(self, data_file, aug):
        pass

class SimpleDataManager(DataManager):
    def __init__(self, image_size, batch_size):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size
        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, aug):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SimpleDataset(transform)

        data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=12, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)

        return data_loader

class SetDataManager(DataManager):
    def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide=100):
        super(SetDataManager, self).__init__()
        self.image_size = image_size
        self.n_way = n_way
        self.batch_size = n_support + n_query
        self.n_eposide = n_eposide

        self.trans_loader = TransformLoader(image_size)

    def get_data_loader(self, aug):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SetDataset(self.batch_size, transform)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
        data_loader_params = dict(batch_sampler=sampler, num_workers=12, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
        return data_loader

if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
/datasets/EuroSAT_few_shot.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
import datasets.additional_transforms as add_transforms
from torch.utils.data import Dataset, DataLoader
from abc import abstractmethod
from torchvision.datasets import ImageFolder

import copy
import os

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import sys
sys.path.append("../")
import configs

identity = lambda x: x


def construct_subset(dataset, split):
    # same helper as in CropDisease_few_shot.py: restrict an ImageFolder to the paths in a split csv
    split = pd.read_csv(split)['img_path'].values
    root = dataset.root

    class_to_idx = dataset.class_to_idx

    # create targets
    targets = [class_to_idx[os.path.dirname(i)] for i in split]

    # image_names = np.array([i[0] for i in dataset.imgs])

    # # ind
    # ind = np.concatenate(
    #     [np.where(image_names == os.path.join(root, j))[0] for j in split])

    image_names = [os.path.join(root, j) for j in split]
    dataset_subset = copy.deepcopy(dataset)

    dataset_subset.samples = [j for j in zip(image_names, targets)]
    dataset_subset.imgs = dataset_subset.samples
    dataset_subset.targets = targets
    return dataset_subset

class SimpleDataset:
    def __init__(self, transform, target_transform=identity, split=None):
        self.transform = transform
        self.target_transform = target_transform

        self.d = ImageFolder(configs.EuroSAT_path, transform=transform, target_transform=target_transform)
        self.split = split
        if split is not None:
            print("Using Split: ", split)
            self.d = construct_subset(self.d, split)

    def __getitem__(self, i):
        return self.d[i]

    def __len__(self):
        return len(self.d)


class SetDataset:
    def __init__(self, batch_size, transform, split=None):
        self.d = ImageFolder(configs.EuroSAT_path, transform=transform)
        self.split = split
        if split is not None:
            print("Using Split: ", split)
            self.d = construct_subset(self.d, split)
        self.cl_list = range(len(self.d.classes))

        self.sub_dataloader = []
        sub_data_loader_params = dict(batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=False)
        for cl in self.cl_list:
            ind = np.where(np.array(self.d.targets) == cl)[0].tolist()
            sub_dataset = torch.utils.data.Subset(self.d, ind)
            self.sub_dataloader.append(torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params))

    def __getitem__(self, i):
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.sub_dataloader)

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes
    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]

class TransformLoader:
    def __init__(self, image_size,
                 normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                 jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)):
        self.image_size = image_size
        self.normalize_param = normalize_param
        self.jitter_param = jitter_param

    def parse_transform(self, transform_type):
        if transform_type == 'ImageJitter':
            method = add_transforms.ImageJitter(self.jitter_param)
            return method
        method = getattr(transforms, transform_type)
        if transform_type == 'RandomSizedCrop' or transform_type == 'RandomResizedCrop':
            return method(self.image_size)
        elif transform_type == 'CenterCrop':
            return method(self.image_size)
        elif transform_type == 'Scale' or transform_type == 'Resize':
            return method([int(self.image_size*1.15), int(self.image_size*1.15)])
        elif transform_type == 'Normalize':
            return method(**self.normalize_param)
        else:
            return method()

    def get_composed_transform(self, aug=False):
        if aug:
            transform_list = ['RandomResizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize']
        else:
            transform_list = ['Resize', 'CenterCrop', 'ToTensor', 'Normalize']

        transform_funcs = [self.parse_transform(x) for x in transform_list]
        transform = transforms.Compose(transform_funcs)
        return transform

class DataManager(object):
    @abstractmethod
    def get_data_loader(self, data_file, aug):
        pass

class SimpleDataManager(DataManager):
    def __init__(self, image_size, batch_size, split=None):
        super(SimpleDataManager, self).__init__()
        self.batch_size = batch_size
        self.trans_loader = TransformLoader(image_size)
        self.split = split

    def get_data_loader(self, aug, num_workers=12):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SimpleDataset(transform, split=self.split)

        data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)

        return data_loader

class SetDataManager(DataManager):
    def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide=100, split=None):
        super(SetDataManager, self).__init__()
        self.image_size = image_size
        self.n_way = n_way
        self.batch_size = n_support + n_query
        self.n_eposide = n_eposide

        self.trans_loader = TransformLoader(image_size)
        self.split = split

    def get_data_loader(self, aug, num_workers=12):  # parameters that would change on train/val set
        transform = self.trans_loader.get_composed_transform(aug)
        dataset = SetDataset(self.batch_size, transform, split=self.split)
        sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide)
        data_loader_params = dict(batch_sampler=sampler, num_workers=num_workers, pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
        return data_loader

if __name__ == '__main__':
    pass
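
Since EuroSAT's `__main__` block is empty, here is a minimal hedged sketch of exercising its managers. It assumes `configs.EuroSAT_path` points at the unpacked EuroSAT `2750/` folder and that it is run from a directory with a `configs.py` and a `datasets` symlink (e.g. `evaluation/`):

```
# Hypothetical usage sketch (not part of the repo).
from datasets import EuroSAT_few_shot

datamgr = EuroSAT_few_shot.SimpleDataManager(
    224, batch_size=64, split='datasets/split_seed_1/EuroSAT_unlabeled_20.csv')
loader = datamgr.get_data_loader(aug=True, num_workers=2)

x, y = next(iter(loader))
# x: [64, 3, 224, 224] augmented images; y: [64] ImageFolder class indices
```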
--------------------------------------------------------------------------------
/datasets/ISIC_few_shot.py:
--------------------------------------------------------------------------------
# This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate

import torch
from PIL import Image
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
import datasets.additional_transforms as add_transforms
from torch.utils.data import Dataset, DataLoader
from abc import abstractmethod

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import sys
sys.path.append("../")

import configs

def identity(x): return x

class CustomDatasetFromImages(Dataset):
    def __init__(self, transform, target_transform=identity,
                 csv_path=configs.ISIC_path + "/ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv",
                 image_path=configs.ISIC_path + "/ISIC2018_Task3_Training_Input/", split=None):
        """
        Args:
            csv_path (string): path to csv file
            img_path (string): path to the folder where images are
            transform: pytorch transforms for transforms and tensor conversion
            target_transform: pytorch transforms for targets
            split: the filename of a csv containing a split for the data to be used.
                If None, then the full dataset is used. (Default: None)
        """
        self.img_path = image_path
        self.csv_path = csv_path

        # Transforms
        self.transform = transform
        self.target_transform = target_transform
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, skiprows=[0], header=None)

        # First column contains the image paths
        self.image_name = np.asarray(self.data_info.iloc[:, 0])

        # The remaining columns are a one-hot diagnosis encoding; argmax turns it into a class index
        self.labels = np.asarray(self.data_info.iloc[:, 1:])
        self.labels = (self.labels != 0).argmax(axis=1)

        # Calculate len
        self.data_len = len(self.image_name)
        self.split = split

        if split is not None:
            print("Using Split: ", split)
            split = pd.read_csv(split)['img_path'].values
            # construct the index
            ind = np.concatenate([np.where(self.image_name == j)[0] for j in split])
            self.image_name = self.image_name[ind]
            self.labels = self.labels[ind]
            self.data_len = len(split)

            assert len(self.image_name) == len(split)
            assert len(self.labels) == len(split)
        # self.targets = self.labels

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_name[index]
        # Open image
        temp = Image.open(self.img_path + single_image_name + ".jpg")
        img_as_img = temp.copy()
        # Get label(class) of the image based on the cropped pandas column
        single_image_label = self.labels[index]

        return self.transform(img_as_img), self.target_transform(single_image_label)

    def __len__(self):
        return self.data_len
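
The `(labels != 0).argmax(axis=1)` line decodes the ISIC ground-truth csv, where each row is a one-hot vector over the diagnosis columns:

```
# Illustration of the one-hot decoding (made-up rows, not real ISIC data).
import numpy as np

onehot = np.array([[0., 1., 0.],
                   [1., 0., 0.],
                   [0., 0., 1.]])
print((onehot != 0).argmax(axis=1))  # [1 0 2] -- column index of the single nonzero entry
```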
CustomDatasetFromImages(transform=self.transform, split=split) 102 | 103 | self.cl_list = sorted(np.unique(self.d.labels).tolist()) 104 | 105 | self.sub_dataloader = [] 106 | sub_data_loader_params = dict(batch_size = batch_size, 107 | shuffle = True, 108 | num_workers = 0, 109 | pin_memory = False) 110 | for cl in self.cl_list: 111 | ind = np.where(np.array(self.d.labels) == cl)[0].tolist() 112 | sub_dataset = torch.utils.data.Subset(self.d, ind) 113 | self.sub_dataloader.append(torch.utils.data.DataLoader( 114 | sub_dataset, **sub_data_loader_params)) 115 | 116 | def __getitem__(self, i): 117 | return next(iter(self.sub_dataloader[i])) 118 | 119 | def __len__(self): 120 | return len(self.sub_dataloader) 121 | 122 | # class SubDataset: 123 | # def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity): 124 | # self.sub_meta = sub_meta 125 | # self.cl = cl 126 | # self.transform = transform 127 | # self.target_transform = target_transform 128 | 129 | # def __getitem__(self,i): 130 | 131 | # img = self.transform(self.sub_meta[i]) 132 | # target = self.target_transform(self.cl) 133 | # return img, target 134 | 135 | # def __len__(self): 136 | # return len(self.sub_meta) 137 | 138 | class EpisodicBatchSampler(object): 139 | def __init__(self, n_classes, n_way, n_episodes): 140 | self.n_classes = n_classes 141 | self.n_way = n_way 142 | self.n_episodes = n_episodes 143 | 144 | def __len__(self): 145 | return self.n_episodes 146 | 147 | def __iter__(self): 148 | for i in range(self.n_episodes): 149 | yield torch.randperm(self.n_classes)[:self.n_way] 150 | 151 | class TransformLoader: 152 | def __init__(self, image_size, 153 | normalize_param = dict(mean= [0.485, 0.456, 0.406] , std=[0.229, 0.224, 0.225]), 154 | jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 155 | self.image_size = image_size 156 | self.normalize_param = normalize_param 157 | self.jitter_param = jitter_param 158 | 159 | def parse_transform(self, transform_type): 160 | if transform_type=='ImageJitter': 161 | method = add_transforms.ImageJitter( self.jitter_param ) 162 | return method 163 | method = getattr(transforms, transform_type) 164 | if transform_type == 'RandomSizedCrop' or transform_type == 'RandomResizedCrop': 165 | return method(self.image_size) 166 | elif transform_type=='CenterCrop': 167 | return method(self.image_size) 168 | elif transform_type == 'Scale' or transform_type == 'Resize': 169 | return method([int(self.image_size*1.15), int(self.image_size*1.15)]) 170 | elif transform_type=='Normalize': 171 | return method(**self.normalize_param ) 172 | else: 173 | return method() 174 | 175 | def get_composed_transform(self, aug = False): 176 | if aug: 177 | transform_list = ['RandomResizedCrop', 'ImageJitter', 178 | 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 179 | else: 180 | transform_list = ['Resize', 'CenterCrop', 'ToTensor', 'Normalize'] 181 | 182 | transform_funcs = [ self.parse_transform(x) for x in transform_list] 183 | transform = transforms.Compose(transform_funcs) 184 | return transform 185 | 186 | class DataManager(object): 187 | @abstractmethod 188 | def get_data_loader(self, data_file, aug): 189 | pass 190 | 191 | class SimpleDataManager(DataManager): 192 | def __init__(self, image_size, batch_size, split=None): 193 | super(SimpleDataManager, self).__init__() 194 | self.batch_size = batch_size 195 | self.trans_loader = TransformLoader(image_size) 196 | self.split = split 197 | 198 | def get_data_loader(self, aug, num_workers=12): #parameters that would 
change on train/val set 199 | transform = self.trans_loader.get_composed_transform(aug) 200 | dataset = SimpleDataset(transform, split=self.split) 201 | 202 | data_loader_params = dict(batch_size = self.batch_size, shuffle = True, num_workers=num_workers, pin_memory = True) 203 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 204 | 205 | return data_loader 206 | 207 | class SetDataManager(DataManager): 208 | def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide = 100, split=None): 209 | super(SetDataManager, self).__init__() 210 | self.image_size = image_size 211 | self.n_way = n_way 212 | self.batch_size = n_support + n_query 213 | self.n_eposide = n_eposide 214 | 215 | self.trans_loader = TransformLoader(image_size) 216 | self.split = split 217 | 218 | def get_data_loader(self, aug, num_workers=12): #parameters that would change on train/val set 219 | transform = self.trans_loader.get_composed_transform(aug) 220 | dataset = SetDataset(self.batch_size, transform, split=self.split) 221 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide ) 222 | data_loader_params = dict(batch_sampler = sampler, num_workers = num_workers, pin_memory = True) 223 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 224 | return data_loader 225 | 226 | if __name__ == '__main__': 227 | 228 | train_few_shot_params = dict(n_way = 5, n_support = 5) 229 | base_datamgr = SetDataManager(224, n_query = 16) 230 | base_loader = base_datamgr.get_data_loader(aug = True) 231 | 232 | cnt = 1 233 | for i, (x, label) in enumerate(base_loader): 234 | if i < cnt: 235 | print(label.size()) 236 | else: 237 | break 238 | -------------------------------------------------------------------------------- /datasets/ImageNet_few_shot.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import configs 4 | import sys 5 | import torch 6 | from PIL import Image 7 | import numpy as np 8 | import pandas as pd 9 | import torchvision.transforms as transforms 10 | import datasets.additional_transforms as add_transforms 11 | from torch.utils.data import Dataset, DataLoader 12 | from abc import abstractmethod 13 | from torchvision.datasets import ImageFolder 14 | 15 | import copy 16 | import os 17 | 18 | from PIL import ImageFile 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | 21 | sys.path.append("../") 22 | 23 | 24 | def identity(x): return x 25 | 26 | 27 | def construct_subset(dataset, split): 28 | split = pd.read_csv(split)['img_path'].values 29 | root = dataset.root 30 | 31 | class_to_idx = dataset.class_to_idx 32 | 33 | # create targets 34 | targets = [class_to_idx[os.path.dirname(i)] for i in split] 35 | 36 | # image_names = np.array([i[0] for i in dataset.imgs]) 37 | 38 | # # ind 39 | # ind = np.concatenate( 40 | # [np.where(image_names == os.path.join(root, j))[0] for j in split]) 41 | 42 | image_names = [os.path.join(root, j) for j in split] 43 | dataset_subset = copy.deepcopy(dataset) 44 | 45 | dataset_subset.samples = [j for j in zip(image_names, targets)] 46 | dataset_subset.imgs = dataset_subset.samples 47 | dataset_subset.targets = targets 48 | return dataset_subset 49 | 50 | 51 | class SimpleDataset: 52 | def __init__(self, transform, target_transform=identity, split=None): 53 | self.transform = transform 54 | self.target_transform = target_transform 55 | 56 | self.d = ImageFolder( 57 | 
configs.ImageNet_path, transform=transform, target_transform=target_transform) 58 | self.split = split 59 | if split is not None: 60 | print("Using Split") 61 | self.d = construct_subset(self.d, split) 62 | 63 | def __getitem__(self, i): 64 | return self.d[i] 65 | 66 | def __len__(self): 67 | return len(self.d) 68 | 69 | 70 | class SetDataset: 71 | def __init__(self, batch_size, transform, split=None): 72 | self.d = ImageFolder(configs.ImageNet_path, transform=transform) 73 | self.split = split 74 | if split is not None: 75 | print("Using Split") 76 | self.d = construct_subset(self.d, split) 77 | self.cl_list = range(len(self.d.classes)) 78 | 79 | self.sub_dataloader = [] 80 | sub_data_loader_params = dict(batch_size=batch_size, 81 | shuffle=True, 82 | num_workers=0, 83 | pin_memory=False) 84 | for cl in self.cl_list: 85 | ind = np.where(np.array(self.d.targets) == cl)[0].tolist() 86 | sub_dataset = torch.utils.data.Subset(self.d, ind) 87 | self.sub_dataloader.append(torch.utils.data.DataLoader( 88 | sub_dataset, **sub_data_loader_params)) 89 | 90 | def __getitem__(self, i): 91 | return next(iter(self.sub_dataloader[i])) 92 | 93 | def __len__(self): 94 | return len(self.sub_dataloader) 95 | 96 | 97 | class EpisodicBatchSampler(object): 98 | def __init__(self, n_classes, n_way, n_episodes): 99 | self.n_classes = n_classes 100 | self.n_way = n_way 101 | self.n_episodes = n_episodes 102 | 103 | def __len__(self): 104 | return self.n_episodes 105 | 106 | def __iter__(self): 107 | for i in range(self.n_episodes): 108 | yield torch.randperm(self.n_classes)[:self.n_way] 109 | 110 | 111 | class TransformLoader: 112 | def __init__(self, image_size, 113 | normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[ 114 | 0.229, 0.224, 0.225]), 115 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 116 | self.image_size = image_size 117 | self.normalize_param = normalize_param 118 | self.jitter_param = jitter_param 119 | 120 | def parse_transform(self, transform_type): 121 | if transform_type == 'ImageJitter': 122 | method = add_transforms.ImageJitter(self.jitter_param) 123 | return method 124 | 125 | if transform_type == 'Scale_original' or transform_type == 'Resize_original': 126 | return transforms.Resize([int(self.image_size), int(self.image_size)]) 127 | 128 | method = getattr(transforms, transform_type) 129 | if transform_type == 'RandomSizedCrop' or transform_type == 'RandomResizedCrop': 130 | return method(self.image_size) 131 | elif transform_type == 'CenterCrop': 132 | return method(self.image_size) 133 | elif transform_type == 'Scale' or transform_type == 'Resize': 134 | return method([int(self.image_size*1.15), int(self.image_size*1.15)]) 135 | elif transform_type == 'Normalize': 136 | return method(**self.normalize_param) 137 | else: 138 | return method() 139 | 140 | def get_composed_transform(self, aug=False): 141 | if aug: 142 | transform_list = ['RandomResizedCrop', 'ImageJitter', 143 | 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 144 | else: 145 | transform_list = ['Resize', 'ToTensor', 'Normalize'] 146 | 147 | transform_funcs = [self.parse_transform(x) for x in transform_list] 148 | transform = transforms.Compose(transform_funcs) 149 | return transform 150 | 151 | 152 | class DataManager(object): 153 | @abstractmethod 154 | def get_data_loader(self, data_file, aug): 155 | pass 156 | 157 | 158 | class SimpleDataManager(DataManager): 159 | def __init__(self, image_size, batch_size, split=None): 160 | super(SimpleDataManager, self).__init__() 161 | self.batch_size = 
batch_size 162 | self.trans_loader = TransformLoader(image_size) 163 | self.split = split 164 | 165 | # parameters that would change on train/val set 166 | def get_data_loader(self, aug, num_workers=12): 167 | transform = self.trans_loader.get_composed_transform(aug) 168 | dataset = SimpleDataset(transform, split=self.split) 169 | 170 | data_loader_params = dict( 171 | batch_size=self.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 172 | data_loader = torch.utils.data.DataLoader( 173 | dataset, **data_loader_params) 174 | 175 | return data_loader 176 | 177 | 178 | class SetDataManager(DataManager): 179 | def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide=100, split=None): 180 | super(SetDataManager, self).__init__() 181 | self.image_size = image_size 182 | self.n_way = n_way 183 | self.batch_size = n_support + n_query 184 | self.n_eposide = n_eposide 185 | 186 | self.trans_loader = TransformLoader(image_size) 187 | self.split = split 188 | 189 | # parameters that would change on train/val set 190 | def get_data_loader(self, aug, num_workers=12): 191 | transform = self.trans_loader.get_composed_transform(aug) 192 | dataset = SetDataset(self.batch_size, transform, split=self.split) 193 | sampler = EpisodicBatchSampler( 194 | len(dataset), self.n_way, self.n_eposide) 195 | data_loader_params = dict( 196 | batch_sampler=sampler, num_workers=num_workers, pin_memory=True) 197 | data_loader = torch.utils.data.DataLoader( 198 | dataset, **data_loader_params) 199 | return data_loader 200 | 201 | 202 | if __name__ == '__main__': 203 | pass 204 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpphoo/STARTUP/2c9c58054a477dc0081ea5d3c77f0a5386078172/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/additional_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | from PIL import ImageEnhance 10 | 11 | transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color) 12 | 13 | 14 | 15 | class ImageJitter(object): 16 | def __init__(self, transformdict): 17 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] 18 | 19 | 20 | def __call__(self, img): 21 | out = img 22 | randtensor = torch.rand(len(self.transforms)) 23 | 24 | for i, (transformer, alpha) in enumerate(self.transforms): 25 | r = alpha*(randtensor[i]*2.0 -1.0) + 1 26 | out = transformer(out).enhance(r).convert('RGB') 27 | 28 | return out 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /datasets/caltech256_few_shot.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | from . 
import additional_transforms as add_transforms 8 | from abc import abstractmethod 9 | 10 | import os 11 | import glob 12 | from torchvision.datasets.utils import download_url, check_integrity 13 | import torch.utils.data as data 14 | 15 | class Caltech256(data.Dataset): 16 | """`Caltech256. 17 | Args: 18 | root (string): Root directory of dataset where directory 19 | ``256_ObjectCategories`` exists. 20 | train (bool, optional): Not used 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | download (bool, optional): If true, downloads the dataset from the internet and 26 | puts it in root directory. If dataset is already downloaded, it is not 27 | downloaded again. 28 | """ 29 | base_folder = '256_ObjectCategories' 30 | url = "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar" 31 | filename = "256_ObjectCategories.tar" 32 | tgz_md5 = '67b4f42ca05d46448c6bb8ecd2220f6d' 33 | 34 | def __init__(self, root, train=True, 35 | transform=None, target_transform=None, 36 | download=False): 37 | self.root = os.path.expanduser(root) 38 | self.transform = transform 39 | self.target_transform = target_transform 40 | 41 | if download: 42 | self.download() 43 | 44 | if not self._check_integrity(): 45 | raise RuntimeError('Dataset not found or corrupted.' + 46 | ' You can use download=True to download it') 47 | 48 | self.data = [] 49 | self.labels = [] 50 | 51 | for cat in range(0, 257): 52 | print (cat) 53 | 54 | cat_dirs = glob.glob(os.path.join(self.root, self.base_folder, '%03d*' % cat)) 55 | 56 | for fdir in cat_dirs: 57 | for fimg in glob.glob(os.path.join(fdir, '*.jpg')): 58 | img = Image.open(fimg).convert("RGB") 59 | 60 | self.data.append(img) 61 | self.labels.append(cat) 62 | 63 | def __getitem__(self, index): 64 | """ 65 | Args: 66 | index (int): Index 67 | Returns: 68 | tuple: (image, target) where target is index of the target class. 
69 | """ 70 | img, target = self.data[index], self.labels[index] 71 | 72 | #img = Image.fromarray(img) 73 | 74 | if self.transform is not None: 75 | img = self.transform(img) 76 | 77 | if self.target_transform is not None: 78 | target = self.target_transform(target) 79 | 80 | return img, target 81 | 82 | def __len__(self): 83 | return len(self.data) 84 | 85 | def _check_integrity(self): 86 | fpath = os.path.join(self.root, self.filename) 87 | if not check_integrity(fpath, self.tgz_md5): 88 | return False 89 | return True 90 | 91 | def download(self): 92 | import tarfile 93 | 94 | root = self.root 95 | download_url(self.url, root, self.filename, self.tgz_md5) 96 | 97 | # extract file 98 | cwd = os.getcwd() 99 | tar = tarfile.open(os.path.join(root, self.filename), "r") 100 | os.chdir(root) 101 | tar.extractall() 102 | tar.close() 103 | os.chdir(cwd) 104 | 105 | def __repr__(self): 106 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 107 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 108 | fmt_str += ' Root Location: {}\n'.format(self.root) 109 | tmp = ' Transforms (if any): ' 110 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 111 | tmp = ' Target Transforms (if any): ' 112 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 113 | return fmt_str 114 | 115 | 116 | identity = lambda x:x 117 | class SimpleDataset: 118 | def __init__(self, transform, target_transform=identity): 119 | self.transform = transform 120 | self.target_transform = target_transform 121 | 122 | 123 | self.d = Caltech256(root='./', transform=transform, target_transform=target_transform, download=True) 124 | #for i, (data, label) in enumerate(d): 125 | # self.meta['image_names'].append(data) 126 | # self.meta['image_labels'].append(label) 127 | 128 | def __getitem__(self, i): 129 | 130 | #img = self.transform(self.meta['image_names'][i]) 131 | #target = self.target_transform(self.meta['image_labels'][i]) 132 | img, target = self.d[i] 133 | 134 | return img, target 135 | 136 | def __len__(self): 137 | #return len(self.meta['image_names']) 138 | return len(self.d) 139 | 140 | 141 | class SetDataset: 142 | def __init__(self, batch_size, transform): 143 | 144 | self.sub_meta = {} 145 | self.cl_list = range(257) 146 | 147 | for cl in self.cl_list: 148 | self.sub_meta[cl] = [] 149 | 150 | d = Caltech256(root='./', download=False) 151 | for i, (data, label) in enumerate(d): 152 | self.sub_meta[label].append(data) 153 | 154 | self.sub_dataloader = [] 155 | sub_data_loader_params = dict(batch_size = batch_size, 156 | shuffle = True, 157 | num_workers = 0, #use main thread only or may receive multiple batches 158 | pin_memory = False) 159 | for cl in self.cl_list: 160 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform = transform ) 161 | self.sub_dataloader.append( torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params) ) 162 | 163 | def __getitem__(self,i): 164 | return next(iter(self.sub_dataloader[i])) 165 | 166 | def __len__(self): 167 | return len(self.sub_dataloader) 168 | 169 | class SubDataset: 170 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity): 171 | self.sub_meta = sub_meta 172 | self.cl = cl 173 | self.transform = transform 174 | self.target_transform = target_transform 175 | 176 | def __getitem__(self,i): 177 | 178 | img = self.transform(self.sub_meta[i]) 179 | target = self.target_transform(self.cl) 180 | return img, target 
181 | 182 | def __len__(self): 183 | return len(self.sub_meta) 184 | 185 | class EpisodicBatchSampler(object): 186 | def __init__(self, n_classes, n_way, n_episodes): 187 | self.n_classes = n_classes 188 | self.n_way = n_way 189 | self.n_episodes = n_episodes 190 | 191 | def __len__(self): 192 | return self.n_episodes 193 | 194 | def __iter__(self): 195 | for i in range(self.n_episodes): 196 | yield torch.randperm(self.n_classes)[:self.n_way] 197 | 198 | class TransformLoader: 199 | def __init__(self, image_size, 200 | normalize_param = dict(mean= [0.485, 0.456, 0.406] , std=[0.229, 0.224, 0.225]), 201 | jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 202 | self.image_size = image_size 203 | self.normalize_param = normalize_param 204 | self.jitter_param = jitter_param 205 | 206 | def parse_transform(self, transform_type): 207 | if transform_type=='ImageJitter': 208 | method = add_transforms.ImageJitter( self.jitter_param ) 209 | return method 210 | method = getattr(transforms, transform_type) 211 | if transform_type=='RandomSizedCrop': 212 | return method(self.image_size) 213 | elif transform_type=='CenterCrop': 214 | return method(self.image_size) 215 | elif transform_type=='Scale': 216 | return method([int(self.image_size*1.15), int(self.image_size*1.15)]) 217 | elif transform_type=='Normalize': 218 | return method(**self.normalize_param ) 219 | else: 220 | return method() 221 | 222 | def get_composed_transform(self, aug = False): 223 | if aug: 224 | transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 225 | else: 226 | transform_list = ['Scale','CenterCrop', 'ToTensor', 'Normalize'] 227 | 228 | transform_funcs = [ self.parse_transform(x) for x in transform_list] 229 | transform = transforms.Compose(transform_funcs) 230 | return transform 231 | 232 | class DataManager(object): 233 | @abstractmethod 234 | def get_data_loader(self, data_file, aug): 235 | pass 236 | 237 | class SimpleDataManager(DataManager): 238 | def __init__(self, image_size, batch_size): 239 | super(SimpleDataManager, self).__init__() 240 | self.batch_size = batch_size 241 | self.trans_loader = TransformLoader(image_size) 242 | 243 | def get_data_loader(self, aug): #parameters that would change on train/val set 244 | transform = self.trans_loader.get_composed_transform(aug) 245 | dataset = SimpleDataset(transform) 246 | 247 | data_loader_params = dict(batch_size = self.batch_size, shuffle = True, num_workers = 12, pin_memory = True) 248 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 249 | 250 | return data_loader 251 | 252 | class SetDataManager(DataManager): 253 | def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide = 100): 254 | super(SetDataManager, self).__init__() 255 | self.image_size = image_size 256 | self.n_way = n_way 257 | self.batch_size = n_support + n_query 258 | self.n_eposide = n_eposide 259 | 260 | self.trans_loader = TransformLoader(image_size) 261 | 262 | def get_data_loader(self, aug): #parameters that would change on train/val set 263 | transform = self.trans_loader.get_composed_transform(aug) 264 | dataset = SetDataset(self.batch_size, transform) 265 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide ) 266 | data_loader_params = dict(batch_sampler = sampler, num_workers = 12, pin_memory = True) 267 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 268 | return data_loader 269 | 270 | if __name__ == '__main__': 271 | pass 
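The `SetDataset` + `EpisodicBatchSampler` pairing above (repeated in each `*_few_shot.py` file) is what turns a plain classification dataset into an episodic few-shot loader: the sampler emits `n_way` random class indices per episode, and `SetDataset.__getitem__(i)` answers with one batch drawn entirely from class `i`. Below is a minimal, self-contained sketch of that interaction; the toy dataset, image sizes, and printed shapes are illustrative assumptions, not repository code.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyClassDataset(Dataset):
    """Stand-in for SetDataset: index i yields one batch drawn only from class i."""
    def __init__(self, n_classes=10, n_per_class=21):  # 21 = n_support + n_query
        self.n_classes = n_classes
        self.n_per_class = n_per_class

    def __len__(self):
        return self.n_classes

    def __getitem__(self, i):
        i = int(i)  # the sampler hands us a 0-dim index tensor
        x = torch.full((self.n_per_class, 3, 8, 8), float(i))     # fake "images"
        y = torch.full((self.n_per_class,), i, dtype=torch.long)  # class labels
        return x, y

class EpisodicBatchSampler:
    """Same logic as the sampler above: each 'batch' is n_way random class indices."""
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes, self.n_way, self.n_episodes = n_classes, n_way, n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for _ in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]

dataset = ToyClassDataset()
sampler = EpisodicBatchSampler(len(dataset), n_way=5, n_episodes=3)
loader = DataLoader(dataset, batch_sampler=sampler)
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([5, 21, 3, 8, 8]) torch.Size([5, 21])
```

Because the default collate function stacks the per-class batches, every iteration delivers a full `[n_way, n_support + n_query, ...]` episode tensor, which is exactly the shape `evaluation/finetune.py` later slices into support and query sets.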
-------------------------------------------------------------------------------- /datasets/cifar_few_shot.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | import datasets.additional_transforms as add_transforms 8 | from abc import abstractmethod 9 | from torchvision.datasets import CIFAR100, CIFAR10 10 | 11 | identity = lambda x:x 12 | class SimpleDataset: 13 | def __init__(self, mode, dataset, transform, target_transform=identity): 14 | self.transform = transform 15 | self.dataset = dataset 16 | self.target_transform = target_transform 17 | 18 | self.meta = {} 19 | 20 | self.meta['image_names'] = [] 21 | self.meta['image_labels'] = [] 22 | if self.dataset == "CIFAR100": 23 | 24 | d = CIFAR100("./", train=True, download=True) 25 | for i, (data, label) in enumerate(d): 26 | if mode == "base": 27 | if label % 3 == 0: 28 | self.meta['image_names'].append(data) 29 | self.meta['image_labels'].append(label) 30 | elif mode == "val": 31 | if label % 3 == 1: 32 | self.meta['image_names'].append(data) 33 | self.meta['image_labels'].append(label) 34 | else: 35 | if label % 3 == 2: 36 | self.meta['image_names'].append(data) 37 | self.meta['image_labels'].append(label) 38 | 39 | elif self.dataset == "CIFAR10": 40 | d = CIFAR10("./", train=True, download=True) 41 | for i, (data, label) in enumerate(d): 42 | if mode == "novel": 43 | self.meta['image_names'].append(data) 44 | self.meta['image_labels'].append(label) 45 | 46 | def __getitem__(self, i): 47 | 48 | img = self.transform(self.meta['image_names'][i]) 49 | target = self.target_transform(self.meta['image_labels'][i]) 50 | 51 | return img, target 52 | 53 | def __len__(self): 54 | return len(self.meta['image_names']) 55 | 56 | 57 | class SetDataset: 58 | def __init__(self, mode, dataset, batch_size, transform): 59 | 60 | self.sub_meta = {} 61 | self.cl_list = range(100) 62 | self.dataset = dataset 63 | 64 | if mode == "base": 65 | type_ = 0 66 | elif mode == "val": 67 | type_ = 1 68 | else: 69 | type_ = 2 70 | 71 | for cl in self.cl_list: 72 | if cl % 3 == type_: 73 | self.sub_meta[cl] = [] 74 | 75 | if self.dataset == "CIFAR100": 76 | d = CIFAR100("./", train=True, download=True) 77 | elif self.dataset == "CIFAR10": 78 | d = CIFAR10("./", train=True, download=True) 79 | 80 | 81 | for i, (data, label) in enumerate(d): 82 | if label % 3 == type_: 83 | self.sub_meta[label].append(data) 84 | 85 | self.sub_dataloader = [] 86 | sub_data_loader_params = dict(batch_size = batch_size, 87 | shuffle = True, 88 | num_workers = 0, #use main thread only or may receive multiple batches 89 | pin_memory = False) 90 | for cl in self.cl_list: 91 | if cl % 3 == type_: 92 | sub_dataset = SubDataset(self.sub_meta[cl], cl, transform = transform ) 93 | self.sub_dataloader.append( torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params) ) 94 | 95 | def __getitem__(self,i): 96 | return next(iter(self.sub_dataloader[i])) 97 | 98 | def __len__(self): 99 | return len(self.sub_dataloader) 100 | 101 | class SubDataset: 102 | def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity): 103 | self.sub_meta = sub_meta 104 | self.cl = cl 105 | self.transform = transform 106 | self.target_transform = target_transform 107 | 108 | def __getitem__(self,i): 109 | 110 | img = 
self.transform(self.sub_meta[i]) 111 | target = self.target_transform(self.cl) 112 | return img, target 113 | 114 | def __len__(self): 115 | return len(self.sub_meta) 116 | 117 | class EpisodicBatchSampler(object): 118 | def __init__(self, n_classes, n_way, n_episodes): 119 | self.n_classes = n_classes 120 | self.n_way = n_way 121 | self.n_episodes = n_episodes 122 | 123 | def __len__(self): 124 | return self.n_episodes 125 | 126 | def __iter__(self): 127 | for i in range(self.n_episodes): 128 | yield torch.randperm(self.n_classes)[:self.n_way] 129 | 130 | class TransformLoader: 131 | def __init__(self, image_size, 132 | normalize_param = dict(mean= [0.485, 0.456, 0.406] , std=[0.229, 0.224, 0.225]), 133 | jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 134 | self.image_size = image_size 135 | self.normalize_param = normalize_param 136 | self.jitter_param = jitter_param 137 | 138 | def parse_transform(self, transform_type): 139 | if transform_type=='ImageJitter': 140 | method = add_transforms.ImageJitter( self.jitter_param ) 141 | return method 142 | method = getattr(transforms, transform_type) 143 | if transform_type=='RandomSizedCrop': 144 | return method(self.image_size) 145 | elif transform_type=='CenterCrop': 146 | return method(self.image_size) 147 | elif transform_type=='Scale': 148 | return method([int(self.image_size*1.15), int(self.image_size*1.15)]) 149 | elif transform_type=='Normalize': 150 | return method(**self.normalize_param ) 151 | else: 152 | return method() 153 | 154 | def get_composed_transform(self, aug = False): 155 | if aug: 156 | transform_list = ['RandomSizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 157 | else: 158 | transform_list = ['Scale','CenterCrop', 'ToTensor', 'Normalize'] 159 | 160 | transform_funcs = [ self.parse_transform(x) for x in transform_list] 161 | transform = transforms.Compose(transform_funcs) 162 | return transform 163 | 164 | class DataManager(object): 165 | @abstractmethod 166 | def get_data_loader(self, data_file, aug): 167 | pass 168 | 169 | class SimpleDataManager(DataManager): 170 | def __init__(self, dataset, image_size, batch_size): 171 | super(SimpleDataManager, self).__init__() 172 | self.batch_size = batch_size 173 | self.trans_loader = TransformLoader(image_size) 174 | self.dataset = dataset 175 | 176 | def get_data_loader(self, mode, aug): #parameters that would change on train/val set 177 | transform = self.trans_loader.get_composed_transform(aug) 178 | dataset = SimpleDataset(mode, self.dataset, transform) 179 | 180 | data_loader_params = dict(batch_size = self.batch_size, shuffle = True, num_workers = 12, pin_memory = True) 181 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 182 | 183 | return data_loader 184 | 185 | class SetDataManager(DataManager): 186 | def __init__(self, mode, dataset, image_size, n_way=5, n_support=5, n_query=16, n_eposide = 100): 187 | super(SetDataManager, self).__init__() 188 | self.image_size = image_size 189 | self.n_way = n_way 190 | self.batch_size = n_support + n_query 191 | self.n_eposide = n_eposide 192 | self.mode = mode 193 | self.dataset = dataset 194 | 195 | self.trans_loader = TransformLoader(image_size) 196 | 197 | def get_data_loader(self, aug): #parameters that would change on train/val set 198 | transform = self.trans_loader.get_composed_transform(aug) 199 | dataset = SetDataset(self.mode, self.dataset, self.batch_size, transform) 200 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide ) 201 | 
data_loader_params = dict(batch_sampler = sampler, num_workers = 12, pin_memory = True) 202 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 203 | return data_loader 204 | 205 | if __name__ == '__main__': 206 | pass -------------------------------------------------------------------------------- /datasets/miniImageNet_few_shot.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import pandas as pd 7 | import torchvision.transforms as transforms 8 | import datasets.additional_transforms as add_transforms 9 | from torch.utils.data import Dataset, DataLoader, Subset 10 | from abc import abstractmethod 11 | from torchvision.datasets import ImageFolder 12 | 13 | from PIL import ImageFile 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | 16 | import sys 17 | sys.path.append("../") 18 | import configs 19 | 20 | import os 21 | import copy 22 | 23 | def construct_subset(dataset, split): 24 | print("Using split: ", split) 25 | split = pd.read_csv(split)['img_path'].values 26 | root = dataset.root 27 | 28 | class_to_idx = dataset.class_to_idx 29 | targets = [class_to_idx[os.path.dirname(i)] for i in split] 30 | 31 | # image_names = np.array([i[0] for i in dataset.imgs]) 32 | # # ind 33 | # ind = np.concatenate([np.where(image_names == os.path.join(root, j))[0] for j in split]) 34 | image_names = [os.path.join(root, j) for j in split] 35 | dataset_subset = copy.deepcopy(dataset) 36 | 37 | dataset_subset.samples = [j for j in zip(image_names, targets)] 38 | dataset_subset.imgs = dataset_subset.samples 39 | dataset_subset.targets = targets 40 | return dataset_subset 41 | 42 | identity = lambda x:x 43 | 44 | class SimpleDataset: 45 | def __init__(self, transform, target_transform=identity, split=None): 46 | self.transform = transform 47 | self.target_transform = target_transform 48 | self.split = None 49 | self.d = ImageFolder(configs.miniImageNet_path, transform=self.transform, 50 | target_transform=self.target_transform) 51 | 52 | if split is not None: 53 | self.d = construct_subset(self.d, split) 54 | 55 | 56 | def __getitem__(self, i): 57 | return self.d[i] 58 | 59 | def __len__(self): 60 | return len(self.d) 61 | 62 | class SetDataset: 63 | def __init__(self, batch_size, transform, split=None): 64 | ''' 65 | Split the the dataset into sub dataset (each dataset belongs to the same class) 66 | ''' 67 | 68 | self.d = ImageFolder(configs.miniImageNet_path, transform=transform) 69 | self.split = split 70 | 71 | if split is not None: 72 | self.d = construct_subset(self.d, split) 73 | 74 | self.cl_list = range(len(self.d.classes)) 75 | 76 | self.sub_dataloader = [] 77 | sub_data_loader_params = dict(batch_size = batch_size, 78 | shuffle = True, 79 | num_workers = 0, 80 | pin_memory = False) 81 | for cl in self.cl_list: 82 | ind = np.where(np.array(self.d.targets) == cl)[0].tolist() 83 | sub_dataset = torch.utils.data.Subset(self.d, ind) 84 | self.sub_dataloader.append( torch.utils.data.DataLoader(sub_dataset, **sub_data_loader_params) ) 85 | 86 | def __getitem__(self, i): 87 | return next(iter(self.sub_dataloader[i])) 88 | 89 | def __len__(self): 90 | return len(self.sub_dataloader) 91 | 92 | class EpisodicBatchSampler(object): 93 | def __init__(self, n_classes, n_way, n_episodes): 94 | self.n_classes = n_classes 95 | self.n_way = n_way 96 | self.n_episodes = n_episodes 97 | 98 | 
def __len__(self): 99 | return self.n_episodes 100 | 101 | def __iter__(self): 102 | for i in range(self.n_episodes): 103 | yield torch.randperm(self.n_classes)[:self.n_way] 104 | 105 | class TransformLoader: 106 | def __init__(self, image_size, 107 | normalize_param = dict(mean= [0.485, 0.456, 0.406] , std=[0.229, 0.224, 0.225]), 108 | jitter_param = dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 109 | self.image_size = image_size 110 | self.normalize_param = normalize_param 111 | self.jitter_param = jitter_param 112 | 113 | def parse_transform(self, transform_type): 114 | if transform_type=='ImageJitter': 115 | method = add_transforms.ImageJitter( self.jitter_param ) 116 | return method 117 | method = getattr(transforms, transform_type) 118 | if transform_type=='RandomSizedCrop' or transform_type == 'RandomResizedCrop': 119 | return method(self.image_size) 120 | elif transform_type=='CenterCrop': 121 | return method(self.image_size) 122 | elif transform_type=='Scale' or transform_type == 'Resize': 123 | return method([int(self.image_size*1.15), int(self.image_size*1.15)]) 124 | elif transform_type=='Normalize': 125 | return method(**self.normalize_param ) 126 | else: 127 | return method() 128 | 129 | def get_composed_transform(self, aug = False): 130 | if aug: 131 | transform_list = ['RandomResizedCrop', 'ImageJitter', 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 132 | else: 133 | transform_list = ['Resize','CenterCrop', 'ToTensor', 'Normalize'] 134 | 135 | transform_funcs = [ self.parse_transform(x) for x in transform_list] 136 | transform = transforms.Compose(transform_funcs) 137 | return transform 138 | 139 | class DataManager(object): 140 | @abstractmethod 141 | def get_data_loader(self, data_file, aug): 142 | pass 143 | 144 | class SimpleDataManager(DataManager): 145 | def __init__(self, image_size, batch_size, split=None): 146 | super(SimpleDataManager, self).__init__() 147 | self.batch_size = batch_size 148 | self.trans_loader = TransformLoader(image_size) 149 | self.split = split 150 | 151 | def get_data_loader(self, aug, num_workers=12): #parameters that would change on train/val set 152 | transform = self.trans_loader.get_composed_transform(aug) 153 | dataset = SimpleDataset(transform, split=self.split) 154 | 155 | data_loader_params = dict(batch_size=self.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 156 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 157 | 158 | return data_loader 159 | 160 | class SetDataManager(DataManager): 161 | def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide = 100, split=None): 162 | super(SetDataManager, self).__init__() 163 | self.image_size = image_size 164 | self.n_way = n_way 165 | self.batch_size = n_support + n_query 166 | self.n_eposide = n_eposide 167 | 168 | self.split = split 169 | 170 | self.trans_loader = TransformLoader(image_size) 171 | 172 | def get_data_loader(self, aug, num_workers=12): #parameters that would change on train/val set 173 | transform = self.trans_loader.get_composed_transform(aug) 174 | dataset = SetDataset(self.batch_size, transform, self.split) 175 | sampler = EpisodicBatchSampler(len(dataset), self.n_way, self.n_eposide) 176 | data_loader_params = dict(batch_sampler=sampler, num_workers=num_workers, pin_memory=True) 177 | data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params) 178 | return data_loader 179 | 180 | if __name__ == '__main__': 181 | pass 182 | 
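`construct_subset` above relies on a simple contract for the split CSVs (the files under `datasets/split_seed_1/`): a single `img_path` column whose entries are paths relative to the `ImageFolder` root, with the class directory as the first path component. A small sketch of that contract, using hypothetical file names:

```python
import os
import pandas as pd

# Hypothetical rows in the style of the split_seed_1/*.csv files
toy = pd.DataFrame({'img_path': [
    'n01532829/n01532829_1000.jpg',  # <class directory>/<image file>
    'n01532829/n01532829_1001.jpg',
    'n01558993/n01558993_1000.jpg',
]})
toy.to_csv('toy_split.csv', index=False)

split = pd.read_csv('toy_split.csv')['img_path'].values
# construct_subset derives each target from the directory part of the path,
# then maps it to an integer label via the ImageFolder's class_to_idx:
print([os.path.dirname(p) for p in split])  # ['n01532829', 'n01532829', 'n01558993']
```

Rebuilding `samples`, `imgs`, and `targets` directly on a deep copy, instead of searching the original image list as the commented-out `np.where` code did, keeps subset construction linear in the size of the split.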
-------------------------------------------------------------------------------- /datasets/tiered_ImageNet_few_shot.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/facebookresearch/low-shot-shrink-hallucinate 2 | 3 | import configs 4 | import sys 5 | import torch 6 | from PIL import Image 7 | import numpy as np 8 | import pandas as pd 9 | import torchvision.transforms as transforms 10 | import datasets.additional_transforms as add_transforms 11 | from torch.utils.data import Dataset, DataLoader 12 | from abc import abstractmethod 13 | from torchvision.datasets import ImageFolder 14 | 15 | import copy 16 | import os 17 | 18 | from PIL import ImageFile 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | 21 | sys.path.append("../") 22 | 23 | 24 | def identity(x): return x 25 | 26 | 27 | def construct_subset(dataset, split): 28 | split = pd.read_csv(split)['img_path'].values 29 | root = dataset.root 30 | 31 | class_to_idx = dataset.class_to_idx 32 | 33 | # create targets 34 | targets = [class_to_idx[os.path.dirname(i)] for i in split] 35 | 36 | # image_names = np.array([i[0] for i in dataset.imgs]) 37 | 38 | # # ind 39 | # ind = np.concatenate( 40 | # [np.where(image_names == os.path.join(root, j))[0] for j in split]) 41 | 42 | image_names = [os.path.join(root, j) for j in split] 43 | dataset_subset = copy.deepcopy(dataset) 44 | 45 | dataset_subset.samples = [j for j in zip(image_names, targets)] 46 | dataset_subset.imgs = dataset_subset.samples 47 | dataset_subset.targets = targets 48 | return dataset_subset 49 | 50 | 51 | class SimpleDataset: 52 | def __init__(self, transform, target_transform=identity, split=None): 53 | self.transform = transform 54 | self.target_transform = target_transform 55 | 56 | self.d = ImageFolder( 57 | configs.tiered_ImageNet_path, transform=transform, target_transform=target_transform) 58 | self.split = split 59 | if split is not None: 60 | print("Using Split: ", split) 61 | self.d = construct_subset(self.d, split) 62 | 63 | def __getitem__(self, i): 64 | return self.d[i] 65 | 66 | def __len__(self): 67 | return len(self.d) 68 | 69 | 70 | class SetDataset: 71 | def __init__(self, batch_size, transform, split=None): 72 | self.d = ImageFolder(configs.tiered_ImageNet_path, transform=transform) 73 | self.split = split 74 | if split is not None: 75 | print("Using Split: ", split) 76 | self.d = construct_subset(self.d, split) 77 | self.cl_list = range(len(self.d.classes)) 78 | 79 | self.sub_dataloader = [] 80 | sub_data_loader_params = dict(batch_size=batch_size, 81 | shuffle=True, 82 | num_workers=0, 83 | pin_memory=False) 84 | for cl in self.cl_list: 85 | ind = np.where(np.array(self.d.targets) == cl)[0].tolist() 86 | sub_dataset = torch.utils.data.Subset(self.d, ind) 87 | self.sub_dataloader.append(torch.utils.data.DataLoader( 88 | sub_dataset, **sub_data_loader_params)) 89 | 90 | def __getitem__(self, i): 91 | return next(iter(self.sub_dataloader[i])) 92 | 93 | def __len__(self): 94 | return len(self.sub_dataloader) 95 | 96 | 97 | class EpisodicBatchSampler(object): 98 | def __init__(self, n_classes, n_way, n_episodes): 99 | self.n_classes = n_classes 100 | self.n_way = n_way 101 | self.n_episodes = n_episodes 102 | 103 | def __len__(self): 104 | return self.n_episodes 105 | 106 | def __iter__(self): 107 | for i in range(self.n_episodes): 108 | yield torch.randperm(self.n_classes)[:self.n_way] 109 | 110 | 111 | class TransformLoader: 112 | def __init__(self, image_size, 113 | 
normalize_param=dict(mean=[0.485, 0.456, 0.406], std=[ 114 | 0.229, 0.224, 0.225]), 115 | jitter_param=dict(Brightness=0.4, Contrast=0.4, Color=0.4)): 116 | self.image_size = image_size 117 | self.normalize_param = normalize_param 118 | self.jitter_param = jitter_param 119 | 120 | def parse_transform(self, transform_type): 121 | if transform_type == 'ImageJitter': 122 | method = add_transforms.ImageJitter(self.jitter_param) 123 | return method 124 | 125 | if transform_type == 'Scale_original' or transform_type == 'Resize_original': 126 | return transforms.Resize([int(self.image_size), int(self.image_size)]) 127 | 128 | method = getattr(transforms, transform_type) 129 | if transform_type == 'RandomSizedCrop' or transform_type == 'RandomResizedCrop': 130 | return method(self.image_size) 131 | elif transform_type == 'CenterCrop': 132 | return method(self.image_size) 133 | elif transform_type == 'Scale' or transform_type == 'Resize': 134 | return method([int(self.image_size*1.15), int(self.image_size*1.15)]) 135 | elif transform_type == 'Normalize': 136 | return method(**self.normalize_param) 137 | else: 138 | return method() 139 | 140 | def get_composed_transform(self, aug=False): 141 | if aug: 142 | transform_list = ['RandomResizedCrop', 'ImageJitter', 143 | 'RandomHorizontalFlip', 'ToTensor', 'Normalize'] 144 | else: 145 | transform_list = ['Resize_original', 'ToTensor', 'Normalize'] 146 | 147 | transform_funcs = [self.parse_transform(x) for x in transform_list] 148 | transform = transforms.Compose(transform_funcs) 149 | return transform 150 | 151 | 152 | class DataManager(object): 153 | @abstractmethod 154 | def get_data_loader(self, data_file, aug): 155 | pass 156 | 157 | 158 | class SimpleDataManager(DataManager): 159 | def __init__(self, image_size, batch_size, split=None): 160 | super(SimpleDataManager, self).__init__() 161 | self.batch_size = batch_size 162 | self.trans_loader = TransformLoader(image_size) 163 | self.split = split 164 | 165 | # parameters that would change on train/val set 166 | def get_data_loader(self, aug, num_workers=12): 167 | transform = self.trans_loader.get_composed_transform(aug) 168 | dataset = SimpleDataset(transform, split=self.split) 169 | 170 | data_loader_params = dict( 171 | batch_size=self.batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 172 | data_loader = torch.utils.data.DataLoader( 173 | dataset, **data_loader_params) 174 | 175 | return data_loader 176 | 177 | 178 | class SetDataManager(DataManager): 179 | def __init__(self, image_size, n_way=5, n_support=5, n_query=16, n_eposide=100, split=None): 180 | super(SetDataManager, self).__init__() 181 | self.image_size = image_size 182 | self.n_way = n_way 183 | self.batch_size = n_support + n_query 184 | self.n_eposide = n_eposide 185 | 186 | self.trans_loader = TransformLoader(image_size) 187 | self.split = split 188 | 189 | # parameters that would change on train/val set 190 | def get_data_loader(self, aug, num_workers=12): 191 | transform = self.trans_loader.get_composed_transform(aug) 192 | dataset = SetDataset(self.batch_size, transform, split=self.split) 193 | sampler = EpisodicBatchSampler( 194 | len(dataset), self.n_way, self.n_eposide) 195 | data_loader_params = dict( 196 | batch_sampler=sampler, num_workers=num_workers, pin_memory=True) 197 | data_loader = torch.utils.data.DataLoader( 198 | dataset, **data_loader_params) 199 | return data_loader 200 | 201 | 202 | if __name__ == '__main__': 203 | pass 204 | 
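One behavioural difference from the other dataset files is worth noting: at evaluation time this loader uses `Resize_original`, an exact `image_size`×`image_size` resize with no crop, whereas e.g. the ISIC and miniImageNet loaders resize to 1.15× and then center-crop back to `image_size`. Expanded by hand, the two evaluation pipelines that `get_composed_transform(aug=False)` builds look as follows; this is a sketch using the same torchvision calls and ImageNet statistics as the code above:

```python
import torchvision.transforms as transforms

image_size = 84  # the resolution run.sh uses for tiered-ImageNet
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# tiered-ImageNet eval: 'Resize_original' -> 'ToTensor' -> 'Normalize'
tiered_eval = transforms.Compose([
    transforms.Resize([image_size, image_size]),
    transforms.ToTensor(),
    normalize,
])

# ISIC/miniImageNet-style eval: 'Resize' (1.15x) -> 'CenterCrop' -> 'ToTensor' -> 'Normalize'
default_eval = transforms.Compose([
    transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    normalize,
])
```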
-------------------------------------------------------------------------------- /evaluation/compile_result.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import numpy as np 4 | 5 | import argparse 6 | 7 | def main(args): 8 | data = pd.read_csv(args.result_file) 9 | 10 | mean = data.mean() 11 | CI = data.std() * 1.96 / np.sqrt(len(data)) 12 | 13 | compiled_result = (pd.concat([mean, CI], axis=1)) 14 | compiled_result.columns = ['Mean', '95CI'] 15 | print(compiled_result) 16 | compiled_result.to_csv(args.result_file[:-4] + '_compiled.csv') 17 | 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser(description="Construct the mean and 95 CI") 21 | parser.add_argument('--result_file', type=str, help='result file') 22 | args = parser.parse_args() 23 | main(args) -------------------------------------------------------------------------------- /evaluation/configs.py: -------------------------------------------------------------------------------- 1 | 2 | # TODO: Please set the directory to the target datasets accordingly 3 | ISIC_path = "/scratch/datasets/CD-FSL/ISIC" 4 | ChestX_path = "/scratch/datasets/CD-FSL/chestX" 5 | CropDisease_path = "/scratch/datasets/CD-FSL/CropDiseases" 6 | EuroSAT_path = "/scratch/datasets/CD-FSL/EuroSAT/2750" 7 | miniImageNet_path = '/scratch/datasets/CD-FSL/miniImageNet_test' 8 | tiered_ImageNet_path = '/scratch/datasets/tiered_imagenet/tiered_imagenet/original_split/test' 9 | -------------------------------------------------------------------------------- /evaluation/data: -------------------------------------------------------------------------------- 1 | ../data -------------------------------------------------------------------------------- /evaluation/datasets: -------------------------------------------------------------------------------- 1 | ../datasets -------------------------------------------------------------------------------- /evaluation/finetune.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim 5 | import torch.nn.functional as F 6 | import os 7 | 8 | import models 9 | from datasets import ISIC_few_shot, EuroSAT_few_shot, CropDisease_few_shot, Chest_few_shot, miniImageNet_few_shot, tiered_ImageNet_few_shot 10 | 11 | from tqdm import tqdm 12 | import pandas as pd 13 | import argparse 14 | import random 15 | import copy 16 | import warnings 17 | 18 | 19 | class Classifier(nn.Module): 20 | def __init__(self, dim, n_way): 21 | super(Classifier, self).__init__() 22 | 23 | self.fc = nn.Linear(dim, n_way) 24 | 25 | def forward(self, x): 26 | x = self.fc(x) 27 | return x 28 | 29 | def finetune(novel_loader, params, n_shot): 30 | 31 | print("Loading Model: ", params.embedding_load_path) 32 | if params.embedding_load_path_version == 0: 33 | state = torch.load(params.embedding_load_path)['state'] 34 | state_keys = list(state.keys()) 35 | for _, key in enumerate(state_keys): 36 | if "feature." 
in key: 37 | # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx' 38 | newkey = key.replace("feature.", "") 39 | state[newkey] = state.pop(key) 40 | else: 41 | state.pop(key) 42 | sd = state 43 | elif params.embedding_load_path_version == 1: 44 | sd = torch.load(params.embedding_load_path) 45 | 46 | if 'epoch' in sd: 47 | print("Model checkpointed at epoch: ", sd['epoch']) 48 | sd = sd['model'] 49 | # elif params.embedding_load_path_version == 3: 50 | # state = torch.load(params.embedding_load_path) 51 | # print("Model checkpointed at epoch: ", state['epoch']) 52 | # state = state['model'] 53 | # state_keys = list(state.keys()) 54 | # for _, key in enumerate(state_keys): 55 | # if "module." in key: 56 | # # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx' 57 | # newkey = key.replace("module.", "") 58 | # state[newkey] = state.pop(key) 59 | # else: 60 | # state.pop(key) 61 | # sd = state 62 | else: 63 | raise ValueError("Invalid load path version!") 64 | 65 | if params.model == 'resnet10': 66 | pretrained_model_template = models.ResNet10() 67 | feature_dim = pretrained_model_template.final_feat_dim 68 | elif params.model == 'resnet12': 69 | pretrained_model_template = models.Resnet12(width=1, dropout=0.1) 70 | feature_dim = pretrained_model_template.output_size 71 | elif params.model == 'resnet18': 72 | pretrained_model_template = models.resnet18(remove_last_relu=False, 73 | input_high_res=True) 74 | feature_dim = 512 75 | else: 76 | raise ValueError("Invalid model!") 77 | 78 | pretrained_model_template.load_state_dict(sd) 79 | 80 | n_query = params.n_query 81 | n_way = params.n_way 82 | n_support = n_shot 83 | 84 | acc_all = [] 85 | 86 | for i, (x, y) in tqdm(enumerate(novel_loader)): 87 | 88 | pretrained_model = copy.deepcopy(pretrained_model_template) 89 | classifier = Classifier(feature_dim, params.n_way) 90 | 91 | pretrained_model.cuda() 92 | classifier.cuda() 93 | 94 | ############################################################################################### 95 | x = x.cuda() 96 | x_var = x 97 | 98 | assert len(torch.unique(y)) == n_way 99 | 100 | batch_size = 4 101 | support_size = n_way * n_support 102 | 103 | y_a_i = torch.from_numpy(np.repeat(range(n_way), n_support)).cuda() 104 | 105 | # split into support and query 106 | x_b_i = x_var[:, n_support:,: ,: ,:].contiguous().view(n_way*n_query, *x.size()[2:]).cuda() 107 | x_a_i = x_var[:, :n_support,: ,: ,:].contiguous().view(n_way*n_support, *x.size()[2:]).cuda() # (25, 3, 224, 224) 108 | 109 | if params.freeze_backbone: 110 | pretrained_model.eval() 111 | with torch.no_grad(): 112 | f_a_i = pretrained_model(x_a_i) 113 | else: 114 | pretrained_model.train() 115 | 116 | ############################################################################################### 117 | loss_fn = nn.CrossEntropyLoss().cuda() 118 | classifier_opt = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9, dampening=0.9, weight_decay=0.001) 119 | 120 | 121 | if not params.freeze_backbone: 122 | delta_opt = torch.optim.SGD(filter(lambda p: p.requires_grad, pretrained_model.parameters()), lr=0.01) 123 | 124 | ############################################################################################### 125 | total_epoch = 100 126 | 127 | classifier.train() 128 | 129 | for epoch in range(total_epoch): 130 | rand_id = np.random.permutation(support_size) 131 | 132 | for 
j in range(0, support_size, batch_size): 133 | classifier_opt.zero_grad() 134 | if not params.freeze_backbone: 135 | delta_opt.zero_grad() 136 | 137 | 138 | ##################################### 139 | selected_id = torch.from_numpy( rand_id[j: min(j+batch_size, support_size)]).cuda() 140 | 141 | y_batch = y_a_i[selected_id] 142 | 143 | if params.freeze_backbone: 144 | output = f_a_i[selected_id] 145 | else: 146 | z_batch = x_a_i[selected_id] 147 | output = pretrained_model(z_batch) 148 | 149 | output = classifier(output) 150 | loss = loss_fn(output, y_batch) 151 | 152 | ##################################### 153 | loss.backward() 154 | 155 | classifier_opt.step() 156 | if not params.freeze_backbone: 157 | delta_opt.step() 158 | 159 | pretrained_model.eval() 160 | classifier.eval() 161 | 162 | with torch.no_grad(): 163 | output = pretrained_model(x_b_i) 164 | scores = classifier(output) 165 | 166 | y_query = np.repeat(range( n_way ), n_query ) 167 | topk_scores, topk_labels = scores.data.topk(1, 1, True, True) 168 | topk_ind = topk_labels.cpu().numpy() 169 | 170 | top1_correct = np.sum(topk_ind[:,0] == y_query) 171 | correct_this, count_this = float(top1_correct), len(y_query) 172 | # print (correct_this/ count_this *100) 173 | acc_all.append((correct_this/ count_this *100)) 174 | 175 | if (i+1) % 100 == 0: 176 | acc_all_np = np.asarray(acc_all) 177 | acc_mean = np.mean(acc_all_np) 178 | acc_std = np.std(acc_all_np) 179 | print('Test Acc (%d episodes) = %4.2f%% +- %4.2f%%' % 180 | (len(acc_all), acc_mean, 1.96 * acc_std/np.sqrt(len(acc_all)))) 181 | 182 | ############################################################################################### 183 | 184 | acc_all = np.asarray(acc_all) 185 | acc_mean = np.mean(acc_all) 186 | acc_std = np.std(acc_all) 187 | print('%d Test Acc = %4.2f%% +- %4.2f%%' % 188 | (len(acc_all), acc_mean, 1.96 * acc_std/np.sqrt(len(acc_all)))) 189 | 190 | return acc_all 191 | 192 | def main(params): 193 | 194 | if not os.path.isdir(params.save_dir): 195 | os.makedirs(params.save_dir) 196 | 197 | if params.target_dataset == 'ISIC': 198 | datamgr = ISIC_few_shot 199 | elif params.target_dataset == 'EuroSAT': 200 | datamgr = EuroSAT_few_shot 201 | elif params.target_dataset == 'CropDisease': 202 | datamgr = CropDisease_few_shot 203 | elif params.target_dataset == 'ChestX': 204 | datamgr = Chest_few_shot 205 | elif params.target_dataset == 'miniImageNet_test': 206 | datamgr = miniImageNet_few_shot 207 | elif params.target_dataset == 'tiered_ImageNet_test': 208 | if params.image_size != 84: 209 | warnings.warn("Tiered ImageNet: The image size for is not 84x84") 210 | datamgr = tiered_ImageNet_few_shot 211 | else: 212 | raise ValueError("Invalid Dataset!") 213 | 214 | results = {} 215 | shot_done = [] 216 | print(params.target_dataset) 217 | for shot in params.n_shot: 218 | print(f"{params.n_way}-way {shot}-shot") 219 | torch.backends.cudnn.deterministic = True 220 | torch.backends.cudnn.benchmark = False 221 | np.random.seed(params.seed) 222 | torch.random.manual_seed(params.seed) 223 | torch.cuda.manual_seed(params.seed) 224 | random.seed(params.seed) 225 | novel_loader = datamgr.SetDataManager(params.image_size, n_eposide=params.n_episode, 226 | n_query=params.n_query, n_way=params.n_way, 227 | n_support=shot, split=params.subset_split).get_data_loader( 228 | aug=params.train_aug) 229 | acc_all = finetune(novel_loader, params, n_shot=shot) 230 | results[shot] = acc_all 231 | shot_done.append(shot) 232 | 233 | if params.save_suffix is None: 234 | 
pd.DataFrame(results).to_csv(os.path.join(params.save_dir, 235 | params.source_dataset + '_' + params.target_dataset + '_' + 236 | str(params.n_way) + 'way' + '.csv'), index=False) 237 | else: 238 | pd.DataFrame(results).to_csv(os.path.join(params.save_dir, 239 | params.source_dataset + '_' + params.target_dataset + '_' + 240 | str(params.n_way) + 'way_' + params.save_suffix + '.csv'), index=False) 241 | 242 | return 243 | 244 | if __name__=='__main__': 245 | parser = argparse.ArgumentParser( 246 | description='few-shot Evaluation script') 247 | parser.add_argument('--save_dir', default='.', type=str, help='Directory to save the result csv') 248 | parser.add_argument('--source_dataset', default='miniImageNet', help='source dataset') 249 | parser.add_argument('--target_dataset', default='miniImageNet_test', 250 | help='test target dataset') 251 | parser.add_argument('--subset_split', type=str, 252 | help='path to the csv file that contains the split of the data') 253 | parser.add_argument('--image_size', type=int, default=224, 254 | help='Resolution of the input image') 255 | parser.add_argument('--n_way', default=5, type=int, 256 | help='number of classes per few-shot task') 257 | parser.add_argument('--n_shot', nargs='+', default=[5], type=int, 258 | help='number of labeled data in each class, same as n_support') 259 | parser.add_argument('--n_episode', default=600, type=int, 260 | help='Number of episodes') 261 | parser.add_argument('--n_query', default=15, type=int, 262 | help='Number of query examples per class') 263 | parser.add_argument('--train_aug', action='store_true', 264 | help='perform data augmentation during training') 265 | parser.add_argument('--model', default='resnet10', 266 | help='backbone architecture') 267 | parser.add_argument('--freeze_backbone', action='store_true', 268 | help='Freeze the backbone network for finetuning') 269 | parser.add_argument('--seed', default=1, type=int, help='random seed') 270 | parser.add_argument('--embedding_load_path', type=str, 271 | help='path to load embedding') 272 | parser.add_argument('--embedding_load_path_version', type=int, default=1, 273 | help='how to load the embedding') 274 | parser.add_argument('--save_suffix', type=str, help='suffix added to the csv file') 275 | 276 | 277 | params = parser.parse_args() 278 | main(params) 279 | 280 | -------------------------------------------------------------------------------- /evaluation/methods: -------------------------------------------------------------------------------- 1 | ../methods -------------------------------------------------------------------------------- /evaluation/models: -------------------------------------------------------------------------------- 1 | ../models/
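For every episode, `finetune()` above receives `x` with shape `[n_way, n_support + n_query, C, H, W]`, flattens it into a support set (used to train the linear classifier, optionally updating the backbone) and a query set (used only to measure accuracy), and ignores the incoming labels in favour of fresh `0..n_way-1` targets. A shape-only sketch of that split, with toy values standing in for a real episode:

```python
import numpy as np
import torch

n_way, n_support, n_query = 5, 5, 15
# one episode as delivered by SetDataManager
x = torch.randn(n_way, n_support + n_query, 3, 224, 224)

x_a_i = x[:, :n_support].contiguous().view(n_way * n_support, *x.size()[2:])  # support
x_b_i = x[:, n_support:].contiguous().view(n_way * n_query, *x.size()[2:])    # query
y_a_i = torch.from_numpy(np.repeat(range(n_way), n_support))                  # 0,0,...,4,4

print(x_a_i.shape, x_b_i.shape, y_a_i.shape)
# torch.Size([25, 3, 224, 224]) torch.Size([75, 3, 224, 224]) torch.Size([25])
```

The relabeling via `np.repeat` is why `finetune()` only asserts that `y` contains `n_way` distinct classes and never uses the original dataset labels.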
-------------------------------------------------------------------------------- /evaluation/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # bash script to evaluate different representations.
4 | # finetune.py learns a linear classifier on the features extracted from the support set
5 | # and evaluates on the query set
6 | # compile_result.py computes the averages and the 95% confidence intervals from the results generated by finetune.py
7 | export CUDA_VISIBLE_DEVICES=0
8 | 
9 | ##############################################################################################
10 | # Evaluate Representations trained on miniImageNet
11 | ##############################################################################################
12 | 
13 | # Before running the commands, please take care of the TODO appropriately
14 | for source in "miniImageNet"
15 | do
16 |     for target in "ChestX" "ISIC" "EuroSAT" "CropDisease" "miniImageNet_test"
17 |     do
18 |         # TODO: Please set the following argument appropriately
19 |         # --save_dir: directory to save the results from evaluation
20 |         # --embedding_load_path: representation to be evaluated
21 |         # --embedding_load_path_version: either 0 or 1. This is 1 most of the time. Only set this to 0 when
22 |         # evaluating a teacher model trained using teacher_miniImageNet/train.py
23 |         # E.g. the following command evaluates the STARTUP representation on 600 tasks
24 |         # and saves the results of the 600 tasks at results/STARTUP_miniImageNet/$source\_$target\_5way.csv
25 |         python finetune.py \
26 |             --image_size 224 \
27 |             --n_way 5 \
28 |             --n_shot 1 5 \
29 |             --n_episode 600 \
30 |             --n_query 15 \
31 |             --seed 1 \
32 |             --freeze_backbone \
33 |             --save_dir results/STARTUP_miniImageNet \
34 |             --source_dataset $source \
35 |             --target_dataset $target \
36 |             --subset_split datasets/split_seed_1/$target\_labeled_80.csv \
37 |             --model resnet10 \
38 |             --embedding_load_path ../student_STARTUP/miniImageNet_source/$target\_unlabeled_20/checkpoint_best.pkl \
39 |             --embedding_load_path_version 1
40 | 
41 |         # TODO: Please set --result_file appropriately. The prefix of the argument should be the same as
42 |         # the --save_dir from the previous command
43 |         # E.g. the following command computes the mean and 95% CI from results/STARTUP_miniImageNet/$source\_$target\_5way.csv
44 |         # and saves them to results/STARTUP_miniImageNet/$source\_$target\_5way_compiled.csv
45 |         python compile_result.py --result_file results/STARTUP_miniImageNet/$source\_$target\_5way.csv
46 |     done
47 | done
48 | 
49 | ##############################################################################################
50 | # Evaluate Representations trained on ImageNet
51 | ##############################################################################################
52 | 
53 | # Before running the commands, please take care of the TODO appropriately
54 | for source in "ImageNet"
55 | do
56 |     for target in "ChestX" "ISIC" "EuroSAT" "CropDisease"
57 |     do
58 |         # TODO: Please set the following argument appropriately
59 |         # --save_dir: directory to save the results from evaluation
60 |         # --embedding_load_path: representation to be evaluated
61 |         # E.g. the following command evaluates the STARTUP representation on 600 tasks
62 |         # and saves the results of the 600 tasks at results/STARTUP_ImageNet/$source\_$target\_5way.csv
63 |         python finetune.py \
64 |             --image_size 224 \
65 |             --n_way 5 \
66 |             --n_shot 1 5 20 50 \
67 |             --n_episode 600 \
68 |             --n_query 15 \
69 |             --seed 1 \
70 |             --freeze_backbone \
71 |             --source_dataset $source \
72 |             --target_dataset $target \
73 |             --subset_split datasets/split_seed_1/$target\_labeled_80.csv \
74 |             --model resnet18 \
75 |             --save_dir results/STARTUP_ImageNet \
76 |             --embedding_load_path ../student_STARTUP/ImageNet_source/$target\_unlabeled_20/checkpoint_best.pkl \
77 |             --embedding_load_path_version 1
78 | 
79 |         # TODO: Please set --result_file appropriately. The prefix of the argument should be the same as
80 |         # the --save_dir from the previous command
81 |         # E.g. the following command computes the mean and 95% CI from results/STARTUP_ImageNet/$source\_$target\_5way.csv
82 |         # and saves them to results/STARTUP_ImageNet/$source\_$target\_5way_compiled.csv
83 |         python compile_result.py --result_file results/STARTUP_ImageNet/$source\_$target\_5way.csv
84 |     done
85 | done
86 | 
87 | ##############################################################################################
88 | # Evaluate Representations trained on tieredImageNet
89 | ##############################################################################################
90 | 
91 | # Before running the commands, please take care of the TODO appropriately
92 | for source in "tiered_ImageNet_test"
93 | do
94 |     for target in "tiered_ImageNet_test"
95 |     do
96 |         # TODO: Please set the following argument appropriately
97 |         # --save_dir: directory to save the results from evaluation
98 |         # --embedding_load_path: representation to be evaluated
99 |         # --embedding_load_path_version: either 0 or 1. This is 1 most of the time. Only set this to 0 when
100 |         # evaluating a teacher model trained using teacher_miniImageNet/train.py
101 |         # --subset_split: Either datasets/split_seed_1/$target\_labeled_90.csv (for the less unlabeled data setup)
102 |         # or datasets/split_seed_1/$target\_labeled_50.csv (for the more unlabeled data setup)
103 |         # E.g. the following command evaluates the STARTUP representation on 600 tasks
104 |         # and saves the results of the 600 tasks at results/STARTUP_tiered_ImageNet_less/$source\_$target\_5way.csv
105 |         python finetune.py \
106 |             --image_size 84 \
107 |             --n_way 5 \
108 |             --n_shot 1 5 \
109 |             --n_episode 600 \
110 |             --n_query 15 \
111 |             --seed 1 \
112 |             --freeze_backbone \
113 |             --save_dir results/STARTUP_tiered_ImageNet_less \
114 |             --source_dataset $source \
115 |             --target_dataset $target \
116 |             --subset_split datasets/split_seed_1/$target\_labeled_90.csv \
117 |             --model resnet12 \
118 |             --embedding_load_path ../student_STARTUP/tiered_ImageNet_source/$target\_unlabeled_10/checkpoint_best.pkl \
119 |             --embedding_load_path_version 1
120 | 
121 |         # TODO: Please set --result_file appropriately. The prefix of the argument should be the same as
122 |         # the --save_dir from the previous command
123 |         # E.g. the following command computes the mean and 95% CI from results/STARTUP_tiered_ImageNet_less/$source\_$target\_5way.csv
124 |         # and saves them to results/STARTUP_tiered_ImageNet_less/$source\_$target\_5way_compiled.csv
125 |         python compile_result.py --result_file results/STARTUP_tiered_ImageNet_less/$source\_$target\_5way.csv
126 |     done
127 | done
-------------------------------------------------------------------------------- /evaluation/utils: --------------------------------------------------------------------------------
1 | ../utils
-------------------------------------------------------------------------------- /methods/__init__.py: --------------------------------------------------------------------------------
1 | from . import meta_template
2 | from . import protonet
3 | 
4 | 
5 | 
-------------------------------------------------------------------------------- /methods/baselinefinetune.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from methods.meta_template import MetaTemplate
7 | 
8 | class BaselineFinetune(MetaTemplate):
9 |     def __init__(self, model_func, n_way, n_support, loss_type = "softmax"):
10 |         super(BaselineFinetune, self).__init__( model_func, n_way, n_support)
11 |         self.loss_type = loss_type
12 | 
13 |     def set_forward(self, x, is_feature = True):
14 |         return self.set_forward_adaptation(x, is_feature)  # Baseline always does adaptation
15 | 
16 |     def set_forward_adaptation(self, x, is_feature = True):
17 |         assert is_feature == True, 'Baseline only supports testing with features'
18 |         z_support, z_query = self.parse_feature(x,is_feature)
19 | 
20 |         z_support = z_support.contiguous().view(self.n_way* self.n_support, -1 )
21 |         z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 )
22 | 
23 |         y_support = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support ))
24 |         y_support = Variable(y_support.cuda())
25 | 
26 |         if self.loss_type == 'softmax':
27 |             linear_clf = nn.Linear(self.feat_dim, self.n_way)
28 | 
29 |         elif self.loss_type == 'dist':
30 |             linear_clf = backbone.distLinear(self.feat_dim, self.n_way)  # note: distLinear lives in the `backbone` module of the original CloserLookFewShot codebase, which is not shipped in this repo; only loss_type == 'softmax' is usable here as-is
31 | 
32 | 
33 |         linear_clf = linear_clf.cuda()
34 |         set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001)
35 | 
36 |         loss_function = nn.CrossEntropyLoss()
37 |         loss_function = loss_function.cuda()
38 | 
39 |         batch_size = 4
40 |         support_size = self.n_way* self.n_support
41 |         for epoch in range(100):
42 |             rand_id = np.random.permutation(support_size)
43 |             for i in range(0, support_size , batch_size):
44 |                 set_optimizer.zero_grad()
45 |                 selected_id = torch.from_numpy( rand_id[i: min(i+batch_size, support_size) ]).cuda()
46 |                 z_batch = z_support[selected_id]
47 | 
48 |                 scores = linear_clf(z_batch)
49 | 
50 |                 y_batch = y_support[selected_id]
51 | 
52 |                 loss = loss_function(scores,y_batch)
53 |                 loss.backward()
54 |                 set_optimizer.step()
55 | 
56 |         scores = linear_clf(z_query)
57 |         return scores
58 | 
59 |     def set_forward_loss(self,x):
60 |         raise ValueError('Baseline predicts on pretrained features and does not support finetuning the backbone')
-------------------------------------------------------------------------------- /methods/baselinetrain.py: --------------------------------------------------------------------------------
1 | import utils
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import torch.nn.functional
as F 7 | 8 | import time 9 | 10 | class BaselineTrain(nn.Module): 11 | def __init__(self, model_func, num_class, loss_type = 'softmax'): 12 | super(BaselineTrain, self).__init__() 13 | self.feature = model_func() 14 | 15 | if loss_type == 'softmax': 16 | self.classifier = nn.Linear(self.feature.final_feat_dim, num_class) 17 | self.classifier.bias.data.fill_(0) 18 | elif loss_type == 'dist': #Baseline ++ 19 | self.classifier = backbone.distLinear(self.feature.final_feat_dim, num_class) 20 | 21 | self.loss_type = loss_type #'softmax' #'dist' 22 | self.num_class = num_class 23 | self.loss_fn = nn.CrossEntropyLoss() 24 | self.top1 = utils.AverageMeter() 25 | 26 | def forward(self, x): 27 | x = x.cuda() 28 | out = self.feature.forward(x) 29 | scores = self.classifier.forward(out) 30 | return scores 31 | 32 | def forward_loss(self, x, y): 33 | y = y.cuda() 34 | 35 | scores = self.forward(x) 36 | 37 | _, predicted = torch.max(scores.data, 1) 38 | correct = predicted.eq(y.data).cpu().sum() 39 | self.top1.update(correct.item()*100 / (y.size(0)+0.0), y.size(0)) 40 | 41 | return self.loss_fn(scores, y ) 42 | 43 | def train_loop(self, epoch, train_loader, optimizer, logger): 44 | print_freq = 10 45 | # avg_loss=0 46 | 47 | self.train() 48 | 49 | meters = utils.AverageMeterSet() 50 | 51 | end = time.time() 52 | for i, (X,y) in enumerate(train_loader): 53 | meters.update('Data_time', time.time() - end) 54 | 55 | optimizer.zero_grad() 56 | logits = self.forward(X) 57 | 58 | y = y.cuda() 59 | loss = self.loss_fn(logits, y) 60 | loss.backward() 61 | optimizer.step() 62 | 63 | perf = utils.accuracy(logits.data, 64 | y.data, topk=(1, 5)) 65 | 66 | meters.update('Loss', loss.item(), 1) 67 | meters.update('top1', perf['average'][0].item(), len(X)) 68 | meters.update('top5', perf['average'][1].item(), len(X)) 69 | 70 | meters.update('top1_per_class', perf['per_class_average'][0].item(), 1) 71 | meters.update('top5_per_class', perf['per_class_average'][1].item(), 1) 72 | 73 | meters.update('Batch_time', time.time() - end) 74 | end = time.time() 75 | 76 | # avg_loss = avg_loss+loss.item() 77 | if (i+1) % print_freq==0: 78 | #print(optimizer.state_dict()['param_groups'][0]['lr']) 79 | # print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} | Top1 Val {:f} | Top1 Avg {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1), self.top1.val, self.top1.avg)) 80 | logger_string = ('Training Epoch: [{epoch}] Step: [{step} / {steps}] Batch Time: {meters[Batch_time]:.4f} ' 81 | 'Data Time: {meters[Data_time]:.4f} Average Loss: {meters[Loss]:.4f} ' 82 | 'Top1: {meters[top1]:.4f} Top5: {meters[top5]:.4f} ' 83 | 'Top1_per_class: {meters[top1_per_class]:.4f} ' 84 | 'Top5_per_class: {meters[top5_per_class]:.4f} ').format( 85 | epoch=epoch, step=i+1, steps=len(train_loader), meters=meters) 86 | 87 | logger.info(logger_string) 88 | 89 | logger_string = ('Training Epoch: [{epoch}] Step: [{step}] Batch Time: {meters[Batch_time]:.4f} ' 90 | 'Data Time: {meters[Data_time]:.4f} Average Loss: {meters[Loss]:.4f} ' 91 | 'Top1: {meters[top1]:.4f} Top5: {meters[top5]:.4f} ' 92 | 'Top1_per_class: {meters[top1_per_class]:.4f} ' 93 | 'Top5_per_class: {meters[top5_per_class]:.4f} ').format( 94 | epoch=epoch+1, step=0, meters=meters) 95 | 96 | logger.info(logger_string) 97 | 98 | return meters.averages() 99 | 100 | 101 | def test_loop(self, val_loader): 102 | return -1 #no validation, just save model during iteration 103 | 104 | -------------------------------------------------------------------------------- /methods/meta_template.py: 
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 | import torch.nn.functional as F
6 | import utils
7 | from abc import abstractmethod
8 | 
9 | class MetaTemplate(nn.Module):
10 |     def __init__(self, model_func, n_way, n_support, change_way = True):
11 |         super(MetaTemplate, self).__init__()
12 |         self.n_way = n_way
13 |         self.n_support = n_support
14 |         self.n_query = -1 #(change depends on input)
15 |         self.feature = model_func()
16 |         self.feat_dim = self.feature.final_feat_dim
17 |         self.change_way = change_way #some methods allow a different n_way during training and testing
18 | 
19 |     @abstractmethod
20 |     def set_forward(self,x,is_feature):
21 |         pass
22 | 
23 |     @abstractmethod
24 |     def set_forward_loss(self, x):
25 |         pass
26 | 
27 |     def forward(self,x):
28 |         out = self.feature.forward(x)
29 |         return out
30 | 
31 |     def parse_feature(self,x,is_feature):
32 |         x = Variable(x.cuda())
33 |         if is_feature:
34 |             z_all = x
35 |         else:
36 |             x = x.contiguous().view( self.n_way * (self.n_support + self.n_query), *x.size()[2:])
37 |             z_all = self.feature.forward(x)
38 |             z_all = z_all.view( self.n_way, self.n_support + self.n_query, -1)
39 |         z_support = z_all[:, :self.n_support]
40 |         z_query = z_all[:, self.n_support:]
41 | 
42 |         return z_support, z_query
43 | 
44 |     def correct(self, x):
45 |         scores = self.set_forward(x)
46 |         y_query = np.repeat(range( self.n_way ), self.n_query )
47 | 
48 |         topk_scores, topk_labels = scores.data.topk(1, 1, True, True)
49 |         topk_ind = topk_labels.cpu().numpy()
50 |         top1_correct = np.sum(topk_ind[:,0] == y_query)
51 |         return float(top1_correct), len(y_query)
52 | 
53 |     def train_loop(self, epoch, train_loader, optimizer ):
54 |         print_freq = 10
55 | 
56 |         avg_loss=0
57 |         for i, (x,_ ) in enumerate(train_loader):
58 |             self.n_query = x.size(1) - self.n_support
59 |             if self.change_way:
60 |                 self.n_way = x.size(0)
61 |             optimizer.zero_grad()
62 |             loss = self.set_forward_loss( x )
63 |             loss.backward()
64 |             optimizer.step()
65 |             avg_loss = avg_loss+loss.item()
66 | 
67 |             if i % print_freq==0:
68 |                 #print(optimizer.state_dict()['param_groups'][0]['lr'])
69 |                 print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1)))
70 | 
71 |     def test_loop(self, test_loader, record = None):
72 |         correct =0
73 |         count = 0
74 |         acc_all = []
75 | 
76 |         iter_num = len(test_loader)
77 |         for i, (x,_) in enumerate(test_loader):
78 |             self.n_query = x.size(1) - self.n_support
79 |             if self.change_way:
80 |                 self.n_way = x.size(0)
81 |             correct_this, count_this = self.correct(x)
82 |             acc_all.append(correct_this/ count_this*100 )
83 | 
84 |         acc_all = np.asarray(acc_all)
85 |         acc_mean = np.mean(acc_all)
86 |         acc_std = np.std(acc_all)
87 |         print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, acc_mean, 1.96* acc_std/np.sqrt(iter_num)))
88 | 
89 |         return acc_mean
90 | 
91 |     def set_forward_adaptation(self, x, is_feature = True): #further adaptation, default is fixing the feature and training a new softmax classifier
92 |         assert is_feature == True, 'Feature is fixed in further adaptation'
93 |         z_support, z_query = self.parse_feature(x,is_feature)
94 | 
95 |         z_support = z_support.contiguous().view(self.n_way* self.n_support, -1 )
96 |         z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 )
97 | 
98 |         y_support = torch.from_numpy(np.repeat(range( self.n_way ), self.n_support ))
99 |         y_support = Variable(y_support.cuda())
100 | 
101 |         linear_clf
= nn.Linear(self.feat_dim, self.n_way) 102 | linear_clf = linear_clf.cuda() 103 | 104 | set_optimizer = torch.optim.SGD(linear_clf.parameters(), lr = 0.01, momentum=0.9, dampening=0.9, weight_decay=0.001) 105 | 106 | loss_function = nn.CrossEntropyLoss() 107 | loss_function = loss_function.cuda() 108 | 109 | batch_size = 4 110 | support_size = self.n_way* self.n_support 111 | for epoch in range(100): 112 | rand_id = np.random.permutation(support_size) 113 | for i in range(0, support_size , batch_size): 114 | set_optimizer.zero_grad() 115 | selected_id = torch.from_numpy( rand_id[i: min(i+batch_size, support_size) ]).cuda() 116 | z_batch = z_support[selected_id] 117 | y_batch = y_support[selected_id] 118 | scores = linear_clf(z_batch) 119 | loss = loss_function(scores,y_batch) 120 | loss.backward() 121 | set_optimizer.step() 122 | 123 | scores = linear_clf(z_query) 124 | return scores 125 | -------------------------------------------------------------------------------- /methods/models: -------------------------------------------------------------------------------- 1 | ../models/ -------------------------------------------------------------------------------- /methods/protonet.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/jakesnell/prototypical-networks 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.nn.functional as F 8 | from methods.meta_template import MetaTemplate 9 | 10 | class ProtoNet(MetaTemplate): 11 | def __init__(self, model_func, n_way, n_support): 12 | super(ProtoNet, self).__init__( model_func, n_way, n_support) 13 | self.loss_fn = nn.CrossEntropyLoss() 14 | 15 | 16 | def set_forward(self,x,is_feature = False): 17 | z_support, z_query = self.parse_feature(x,is_feature) 18 | 19 | z_support = z_support.contiguous() 20 | z_proto = z_support.view(self.n_way, self.n_support, -1 ).mean(1) #the shape of z is [n_data, n_dim] 21 | z_query = z_query.contiguous().view(self.n_way* self.n_query, -1 ) 22 | 23 | dists = euclidean_dist(z_query, z_proto) 24 | scores = -dists 25 | return scores 26 | 27 | 28 | def set_forward_loss(self, x): 29 | y_query = torch.from_numpy(np.repeat(range( self.n_way ), self.n_query )) 30 | y_query = Variable(y_query.cuda()) 31 | 32 | scores = self.set_forward(x) 33 | 34 | return self.loss_fn(scores, y_query ) 35 | 36 | 37 | def euclidean_dist( x, y): 38 | # x: N x D 39 | # y: M x D 40 | n = x.size(0) 41 | m = y.size(0) 42 | d = x.size(1) 43 | assert d == y.size(1) 44 | 45 | x = x.unsqueeze(1).expand(n, m, d) 46 | y = y.unsqueeze(0).expand(n, m, d) 47 | 48 | return torch.pow(x - y, 2).sum(2) 49 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .dataparallel_wrapper import * 3 | from .resnet10 import ResNet10 4 | from .resnet12 import Resnet12 -------------------------------------------------------------------------------- /models/dataparallel_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class dataparallel_wrapper(nn.Module): 5 | def __init__(self, module): 6 | super(dataparallel_wrapper, self).__init__() 7 | self.module = module 8 | 9 | def forward(self, mode, *args, **kwargs): 10 | return getattr(self.module, mode)(*args, **kwargs) 11 | 
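# For reference, a minimal usage sketch of dataparallel_wrapper above: the
# first positional argument names the method to run on each replica, so
# methods other than forward() can be dispatched under nn.DataParallel.
# The Encoder module here is hypothetical, for illustration only.
import torch
import torch.nn as nn
from models.dataparallel_wrapper import dataparallel_wrapper

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def encode(self, x):
        return self.fc(x)

wrapped = nn.DataParallel(dataparallel_wrapper(Encoder()))
x = torch.randn(16, 8)
out = wrapped('encode', x)  # every replica calls .encode on its shard of x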
-------------------------------------------------------------------------------- /models/resnet.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | 
4 | 
5 | __all__ = ['ResNet', 'resnet18', 'resnet20', 'resnet34', 'resnet50', 'resnet101',
6 |            'resnet152']
7 | 
8 | 
9 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
10 |     """3x3 convolution with padding"""
11 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 |                      padding=1, groups=groups, bias=False)
13 | 
14 | 
15 | def conv1x1(in_planes, out_planes, stride=1):
16 |     """1x1 convolution"""
17 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
18 | 
19 | 
20 | class BasicBlock(nn.Module):
21 |     expansion = 1
22 | 
23 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
24 |                  base_width=64, norm_layer=None, remove_last_relu=False):
25 |         super(BasicBlock, self).__init__()
26 |         if norm_layer is None:
27 |             norm_layer = nn.BatchNorm2d
28 |         if groups != 1 or base_width != 64:
29 |             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
30 |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
31 |         self.conv1 = conv3x3(inplanes, planes, stride)
32 |         self.bn1 = norm_layer(planes)
33 |         self.relu = nn.ReLU(inplace=True)
34 |         self.conv2 = conv3x3(planes, planes)
35 |         self.bn2 = norm_layer(planes)
36 |         self.downsample = downsample
37 |         self.stride = stride
38 |         self.remove_last_relu = remove_last_relu
39 | 
40 |     def forward(self, x):
41 |         identity = x
42 | 
43 |         out = self.conv1(x)
44 |         out = self.bn1(out)
45 |         out = self.relu(out)
46 | 
47 |         out = self.conv2(out)
48 |         out = self.bn2(out)
49 | 
50 |         if self.downsample is not None:
51 |             identity = self.downsample(x)
52 | 
53 |         out += identity
54 | 
55 |         if self.remove_last_relu:
56 |             out = out
57 |         else:
58 |             out = self.relu(out)
59 | 
60 |         return out
61 | 
62 | 
63 | class Bottleneck(nn.Module):
64 |     expansion = 4
65 | 
66 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
67 |                  base_width=64, norm_layer=None, remove_last_relu=False):
68 |         super(Bottleneck, self).__init__()
69 |         if norm_layer is None:
70 |             norm_layer = nn.BatchNorm2d
71 |         width = int(planes * (base_width / 64.)) * groups
72 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
73 |         self.conv1 = conv1x1(inplanes, width)
74 |         self.bn1 = norm_layer(width)
75 |         self.conv2 = conv3x3(width, width, stride, groups)
76 |         self.bn2 = norm_layer(width)
77 |         self.conv3 = conv1x1(width, planes * self.expansion)
78 |         self.bn3 = norm_layer(planes * self.expansion)
79 |         self.relu = nn.ReLU(inplace=True)
80 |         self.downsample = downsample
81 |         self.stride = stride
82 |         self.remove_last_relu = remove_last_relu
83 | 
84 |     def forward(self, x):
85 |         identity = x
86 | 
87 |         out = self.conv1(x)
88 |         out = self.bn1(out)
89 |         out = self.relu(out)
90 | 
91 |         out = self.conv2(out)
92 |         out = self.bn2(out)
93 |         out = self.relu(out)
94 | 
95 |         out = self.conv3(out)
96 |         out = self.bn3(out)
97 | 
98 |         if self.downsample is not None:
99 |             identity = self.downsample(x)
100 | 
101 |         out += identity
102 | 
103 |         if self.remove_last_relu:
104 |             out = out
105 |         else:
106 |             out = self.relu(out)
107 | 
108 |         return out
109 | 
110 | class ResNet(nn.Module):
111 | 
112 |     def __init__(self, block_type, layers, zero_init_residual=False,
113 |                  groups=1, width_per_group=64, norm_layer=None,
114 |                  remove_last_relu=False, input_high_res=True):
115 | 
super(ResNet, self).__init__() 116 | if norm_layer is None: 117 | norm_layer = nn.BatchNorm2d 118 | 119 | assert block_type in ['basic', 'bottleneck'] 120 | 121 | self.inplanes = 64 122 | self.groups = groups 123 | self.base_width = width_per_group 124 | 125 | if not input_high_res: 126 | self.layer0 = nn.Sequential( 127 | nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, 128 | bias=False), 129 | norm_layer(self.inplanes), 130 | nn.ReLU(inplace=True) 131 | ) 132 | else: 133 | self.layer0 = nn.Sequential( 134 | nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 135 | bias=False), 136 | norm_layer(self.inplanes), 137 | nn.ReLU(inplace=True), 138 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 139 | ) 140 | 141 | 142 | 143 | self.layer1 = self._make_layer( 144 | block_type, 64, layers[0], norm_layer=norm_layer, remove_last_relu=False) 145 | self.layer2 = self._make_layer( 146 | block_type, 128, layers[1], stride=2, norm_layer=norm_layer, remove_last_relu=False) 147 | self.layer3 = self._make_layer( 148 | block_type, 256, layers[2], stride=2, norm_layer=norm_layer, remove_last_relu=False) 149 | self.layer4 = self._make_layer( 150 | block_type, 512, layers[3], stride=2, norm_layer=norm_layer, remove_last_relu=remove_last_relu) 151 | self.remove_last_relu = remove_last_relu 152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 153 | 154 | # this variable is added for compatibility reason 155 | self.pool = self.avgpool 156 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 162 | nn.init.constant_(m.weight, 1) 163 | nn.init.constant_(m.bias, 0) 164 | 165 | # Zero-initialize the last BN in each residual branch, 166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
167 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
168 |         if zero_init_residual:
169 |             for m in self.modules():
170 |                 if isinstance(m, Bottleneck):
171 |                     nn.init.constant_(m.bn3.weight, 0)
172 |                 elif isinstance(m, BasicBlock):
173 |                     nn.init.constant_(m.bn2.weight, 0)
174 | 
175 |     def _make_layer(self, block_type, planes, blocks, stride=1, norm_layer=None, remove_last_relu=False):
176 |         if block_type == 'basic':
177 |             block = BasicBlock
178 |         elif block_type == 'bottleneck':
179 |             block = Bottleneck
180 | 
181 |         if norm_layer is None:
182 |             norm_layer = nn.BatchNorm2d
183 |         downsample = None
184 |         if stride != 1 or self.inplanes != planes * block.expansion:
185 |             downsample = nn.Sequential(
186 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
187 |                 norm_layer(planes * block.expansion),
188 |             )
189 | 
190 |         layers = []
191 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
192 |                             self.base_width, norm_layer, remove_last_relu=False))
193 |         self.inplanes = planes * block.expansion
194 | 
195 |         for _ in range(1, blocks-1):
196 |             layers.append(block(self.inplanes, planes, groups=self.groups,
197 |                                 base_width=self.base_width, norm_layer=norm_layer,
198 |                                 remove_last_relu=False))
199 | 
200 |         layers.append(block(self.inplanes, planes, groups=self.groups,
201 |                             base_width=self.base_width, norm_layer=norm_layer,
202 |                             remove_last_relu=remove_last_relu))
203 | 
204 |         return nn.Sequential(*layers)
205 | 
206 |     def feature_maps(self, x):
207 |         x = self.layer0(x)
208 |         x = self.layer1(x)
209 |         x = self.layer2(x)
210 |         x = self.layer3(x)
211 |         x = self.layer4(x)
212 |         return x
213 | 
214 |     def feature(self, x):
215 |         x = self.feature_maps(x)
216 |         x = self.avgpool(x)
217 |         x = x.view(x.size(0), -1)
218 |         return x
219 | 
220 |     def forward(self, x):
221 |         x = self.feature(x)
222 |         return x
223 | 
224 | 
225 | def resnet18(**kwargs):
226 |     """Constructs a ResNet-18 model.
227 |     """
228 |     model = ResNet('basic', [2, 2, 2, 2], **kwargs)
229 |     return model
230 | 
231 | 
232 | def resnet20(**kwargs):
233 |     """Constructs a ResNet-20 model.
234 |     """
235 |     model = ResNet('basic', [2, 2, 2, 3], **kwargs)
236 |     return model
237 | 
238 | 
239 | def resnet34(**kwargs):
240 |     """Constructs a ResNet-34 model.
241 |     """
242 |     model = ResNet('basic', [3, 4, 6, 3], **kwargs)
243 |     return model
244 | 
245 | 
246 | def resnet50(**kwargs):
247 |     """Constructs a ResNet-50 model.
248 |     """
249 |     model = ResNet('bottleneck', [3, 4, 6, 3], **kwargs)
250 |     return model
251 | 
252 | 
253 | def resnet101(**kwargs):
254 |     """Constructs a ResNet-101 model.
255 |     """
256 |     model = ResNet('bottleneck', [3, 4, 23, 3], **kwargs)
257 |     return model
258 | 
259 | 
260 | def resnet152(**kwargs):
261 |     """Constructs a ResNet-152 model.
262 | """ 263 | model = ResNet('bottleneck', [3, 8, 36, 3], **kwargs) 264 | return model 265 | -------------------------------------------------------------------------------- /models/resnet10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from torch.autograd import Variable 3 | import torch.nn as nn 4 | import math 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch.nn.utils.weight_norm import WeightNorm 8 | 9 | def init_layer(L): 10 | # Initialization using fan-in 11 | if isinstance(L, nn.Conv2d): 12 | n = L.kernel_size[0]*L.kernel_size[1]*L.out_channels 13 | L.weight.data.normal_(0,math.sqrt(2.0/float(n))) 14 | elif isinstance(L, nn.BatchNorm2d): 15 | L.weight.data.fill_(1) 16 | L.bias.data.fill_(0) 17 | 18 | class Flatten(nn.Module): 19 | def __init__(self): 20 | super(Flatten, self).__init__() 21 | 22 | def forward(self, x): 23 | return x.view(x.size(0), -1) 24 | 25 | # Simple ResNet Block 26 | class SimpleBlock(nn.Module): 27 | maml = False #Default 28 | def __init__(self, indim, outdim, half_res): 29 | super(SimpleBlock, self).__init__() 30 | self.indim = indim 31 | self.outdim = outdim 32 | 33 | self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False) 34 | self.BN1 = nn.BatchNorm2d(outdim) 35 | 36 | self.C2 = nn.Conv2d(outdim, outdim,kernel_size=3, padding=1,bias=False) 37 | self.BN2 = nn.BatchNorm2d(outdim) 38 | 39 | self.relu1 = nn.ReLU(inplace=True) 40 | self.relu2 = nn.ReLU(inplace=True) 41 | 42 | self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2] 43 | 44 | self.half_res = half_res 45 | 46 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 47 | if indim!=outdim: 48 | 49 | self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False) 50 | self.BNshortcut = nn.BatchNorm2d(outdim) 51 | 52 | self.parametrized_layers.append(self.shortcut) 53 | self.parametrized_layers.append(self.BNshortcut) 54 | self.shortcut_type = '1x1' 55 | else: 56 | self.shortcut_type = 'identity' 57 | 58 | for layer in self.parametrized_layers: 59 | init_layer(layer) 60 | 61 | def forward(self, x): 62 | out = self.C1(x) 63 | out = self.BN1(out) 64 | out = self.relu1(out) 65 | 66 | out = self.C2(out) 67 | out = self.BN2(out) 68 | short_out = x if self.shortcut_type == 'identity' else self.BNshortcut(self.shortcut(x)) 69 | out = out + short_out 70 | out = self.relu2(out) 71 | return out 72 | 73 | # Bottleneck block 74 | class BottleneckBlock(nn.Module): 75 | def __init__(self, indim, outdim, half_res): 76 | super(BottleneckBlock, self).__init__() 77 | bottleneckdim = int(outdim/4) 78 | self.indim = indim 79 | self.outdim = outdim 80 | 81 | self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False) 82 | self.BN1 = nn.BatchNorm2d(bottleneckdim) 83 | self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1,padding=1) 84 | self.BN2 = nn.BatchNorm2d(bottleneckdim) 85 | self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False) 86 | self.BN3 = nn.BatchNorm2d(outdim) 87 | 88 | self.relu = nn.ReLU() 89 | self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3] 90 | self.half_res = half_res 91 | 92 | 93 | # if the input number of channels is not equal to the output, then need a 1x1 convolution 94 | if indim!=outdim: 95 | 96 | self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False) 97 | 98 | 
self.parametrized_layers.append(self.shortcut) 99 | self.shortcut_type = '1x1' 100 | else: 101 | self.shortcut_type = 'identity' 102 | 103 | for layer in self.parametrized_layers: 104 | init_layer(layer) 105 | 106 | 107 | def forward(self, x): 108 | 109 | short_out = x if self.shortcut_type == 'identity' else self.shortcut(x) 110 | out = self.C1(x) 111 | out = self.BN1(out) 112 | out = self.relu(out) 113 | out = self.C2(out) 114 | out = self.BN2(out) 115 | out = self.relu(out) 116 | out = self.C3(out) 117 | out = self.BN3(out) 118 | out = out + short_out 119 | 120 | out = self.relu(out) 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten = False): 126 | # list_of_num_layers specifies number of layers in each stage 127 | # list_of_out_dims specifies number of output channel for each stage 128 | super(ResNet,self).__init__() 129 | assert len(list_of_num_layers)==4, 'Can have only four stages' 130 | 131 | conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | bn1 = nn.BatchNorm2d(64) 134 | 135 | relu = nn.ReLU() 136 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 137 | 138 | init_layer(conv1) 139 | init_layer(bn1) 140 | 141 | trunk = [conv1, bn1, relu, pool1] 142 | 143 | indim = 64 144 | for i in range(4): 145 | 146 | for j in range(list_of_num_layers[i]): 147 | half_res = (i>=1) and (j==0) 148 | B = block(indim, list_of_out_dims[i], half_res) 149 | trunk.append(B) 150 | indim = list_of_out_dims[i] 151 | 152 | if flatten: 153 | # avgpool = nn.AvgPool2d(7) 154 | avgpool = nn.AdaptiveAvgPool2d((1, 1)) 155 | trunk.append(avgpool) 156 | trunk.append(Flatten()) 157 | self.final_feat_dim = indim 158 | else: 159 | self.final_feat_dim = [ indim, 7, 7] 160 | 161 | self.trunk = nn.Sequential(*trunk) 162 | 163 | def forward(self,x): 164 | out = self.trunk(x) 165 | return out 166 | 167 | def ResNet10( flatten = True): 168 | return ResNet(SimpleBlock, [1,1,1,1],[64,128,256,512], flatten) 169 | 170 | 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /models/resnet12.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # https://github.com/ElementAI/embedding-propagation/blob/master/src/models/backbones/resnet12.py 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | class Block(torch.nn.Module): 9 | def __init__(self, ni, no, stride, dropout=0, groups=1): 10 | super().__init__() 11 | self.dropout = torch.nn.Dropout2d(dropout) if dropout > 0 else lambda x: x 12 | self.conv0 = torch.nn.Conv2d(ni, no, 3, stride, padding=1, bias=False) 13 | self.bn0 = torch.nn.BatchNorm2d(no) 14 | self.conv1 = torch.nn.Conv2d(no, no, 3, 1, padding=1, bias=False) 15 | self.bn1 = torch.nn.BatchNorm2d(no) 16 | self.conv2 = torch.nn.Conv2d(no, no, 3, 1, padding=1, bias=False) 17 | self.bn2 = torch.nn.BatchNorm2d(no) 18 | if stride == 2 or ni != no: 19 | self.shortcut = torch.nn.Conv2d(ni, no, 1, stride=1, padding=0) 20 | 21 | def get_parameters(self): 22 | return self.parameters() 23 | 24 | def forward(self, x): 25 | y = F.relu(self.bn0(self.conv0(x)), True) 26 | y = self.dropout(y) 27 | y = F.relu(self.bn1(self.conv1(y)), True) 28 | y = self.dropout(y) 29 | y = self.bn2(self.conv2(y)) 30 | return F.relu(y + self.shortcut(x), True) 31 | 32 | 33 | class Resnet12(torch.nn.Module): 34 | def __init__(self, width, dropout): 35 | super().__init__() 36 | 
self.output_size = 512
37 |         assert(width == 1)  # comment this assert out for wider variants of this model
38 |         self.widths = [x * int(width) for x in [64, 128, 256]]
39 |         self.widths.append(self.output_size * width)
40 |         self.bn_out = torch.nn.BatchNorm1d(self.output_size)
41 | 
42 |         start_width = 3
43 |         for i in range(len(self.widths)):
44 |             setattr(self, "group_%d" %i, Block(start_width, self.widths[i], 1, dropout))
45 |             start_width = self.widths[i]
46 | 
47 |     def add_classifier(self, nclasses, name="classifier", modalities=None):
48 |         setattr(self, name, torch.nn.Linear(self.output_size, nclasses))
49 | 
50 |     def up_to_embedding(self, x):
51 |         """ Applies the four residual groups, max-pooling after each group
52 |         Args:
53 |             x: input images of shape (batch, 3, h, w)
54 |         Returns:
55 |             the feature maps produced by the last group
56 |         """
57 |         for i in range(len(self.widths)):
58 |             x = getattr(self, "group_%d" % i)(x)
59 |             x = F.max_pool2d(x, 3, 2, 1)
60 |         return x
61 | 
62 |     def forward(self, x):
63 |         """Main Pytorch forward function
64 | 
65 |         Returns: the pooled image embedding of size output_size
66 | 
67 |         Args:
68 |             x: input images
69 |         """
70 |         *args, c, h, w = x.size()
71 |         x = x.view(-1, c, h, w)
72 |         x = self.up_to_embedding(x)
73 |         return F.relu(self.bn_out(x.mean(3).mean(2)), True)
-------------------------------------------------------------------------------- /student_STARTUP/configs.py: --------------------------------------------------------------------------------
1 | 
2 | # TODO: Please set the directory to the target datasets accordingly
3 | miniImageNet_path = '/scratch/datasets/CD-FSL/miniImageNet_test'
4 | tiered_ImageNet_path = '/scratch/datasets/tiered_imagenet/tiered_imagenet/original_split/test'
5 | 
6 | ISIC_path = "/scratch/datasets/CD-FSL/ISIC"
7 | ChestX_path = "/scratch/datasets/CD-FSL/chestX"
8 | CropDisease_path = "/scratch/datasets/CD-FSL/CropDiseases"
9 | EuroSAT_path = "/scratch/datasets/CD-FSL/EuroSAT/2750"
-------------------------------------------------------------------------------- /student_STARTUP/data: --------------------------------------------------------------------------------
1 | ../data
-------------------------------------------------------------------------------- /student_STARTUP/datasets: --------------------------------------------------------------------------------
1 | ../datasets
-------------------------------------------------------------------------------- /student_STARTUP/methods: --------------------------------------------------------------------------------
1 | ../methods
-------------------------------------------------------------------------------- /student_STARTUP/models: --------------------------------------------------------------------------------
1 | ../models/
-------------------------------------------------------------------------------- /student_STARTUP/nx_ent.py: --------------------------------------------------------------------------------
1 | # ported from https://github.com/sthalles/SimCLR/blob/master/loss/nt_xent.py
2 | 
3 | import torch
4 | import numpy as np
5 | 
6 | 
7 | class NTXentLoss(torch.nn.Module):
8 | 
9 |     def __init__(self, device, batch_size, temperature, use_cosine_similarity):
10 |         super(NTXentLoss, self).__init__()
11 |         self.batch_size = batch_size
12 |         self.temperature = temperature
13 |         self.device = device
14 |         self.softmax = torch.nn.Softmax(dim=-1)
15 |         self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
16 |         self.similarity_function = self._get_similarity_function(
17 |             use_cosine_similarity)
18 |         self.criterion =
torch.nn.CrossEntropyLoss(reduction="sum")
19 | 
20 |     def _get_similarity_function(self, use_cosine_similarity):
21 |         if use_cosine_similarity:
22 |             self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
23 |             return self._cosine_simililarity
24 |         else:
25 |             return self._dot_simililarity
26 | 
27 |     def _get_correlated_mask(self):
28 |         diag = np.eye(2 * self.batch_size)
29 |         l1 = np.eye((2 * self.batch_size), 2 *
30 |                     self.batch_size, k=-self.batch_size)
31 |         l2 = np.eye((2 * self.batch_size), 2 *
32 |                     self.batch_size, k=self.batch_size)
33 |         mask = torch.from_numpy((diag + l1 + l2))
34 |         mask = (1 - mask).type(torch.bool)
35 |         return mask.to(self.device)
36 | 
37 |     @staticmethod
38 |     def _dot_simililarity(x, y):
39 |         v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
40 |         # x shape: (N, 1, C)
41 |         # y shape: (1, C, 2N)
42 |         # v shape: (N, 2N)
43 |         return v
44 | 
45 |     def _cosine_simililarity(self, x, y):  # sic: renaming this would clash with the self._cosine_similarity module assigned above
46 |         # x shape: (N, 1, C)
47 |         # y shape: (1, 2N, C)
48 |         # v shape: (N, 2N)
49 |         v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
50 |         return v
51 | 
52 |     def forward(self, zis, zjs):
53 |         representations = torch.cat([zjs, zis], dim=0)
54 | 
55 |         similarity_matrix = self.similarity_function(
56 |             representations, representations)
57 | 
58 |         # filter out the scores from the positive samples
59 |         l_pos = torch.diag(similarity_matrix, self.batch_size)
60 |         r_pos = torch.diag(similarity_matrix, -self.batch_size)
61 |         positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)
62 | 
63 |         negatives = similarity_matrix[self.mask_samples_from_same_repr].view(
64 |             2 * self.batch_size, -1)
65 | 
66 |         logits = torch.cat((positives, negatives), dim=1)
67 |         logits /= self.temperature
68 | 
69 |         labels = torch.zeros(2 * self.batch_size).to(self.device).long()
70 |         loss = self.criterion(logits, labels)
71 | 
72 |         return loss / (2 * self.batch_size)
73 | 
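# For reference, a minimal usage sketch of NTXentLoss above; the batch size
# and projection dimension are illustrative only. zis and zjs stand in for
# the projection-head outputs of two augmented views of the same batch.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size, proj_dim = 8, 128
criterion = NTXentLoss(device, batch_size, temperature=0.5, use_cosine_similarity=True)

zis = torch.randn(batch_size, proj_dim, device=device)
zjs = torch.randn(batch_size, proj_dim, device=device)
loss = criterion(zis, zjs)  # scalar: NT-Xent loss averaged over the 2N views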
-------------------------------------------------------------------------------- /student_STARTUP/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # bash script to train STARTUP representation with SimCLR self-supervision
4 | export CUDA_VISIBLE_DEVICES=0
5 | 
6 | ##############################################################################################
7 | # Train student representation using MiniImageNet as the source
8 | ##############################################################################################
9 | # Before running the commands, please take care of the TODO appropriately
10 | for target_testset in "ChestX" "ISIC" "EuroSAT" "CropDisease" "miniImageNet_test"
11 | do
12 |     # TODO: Please set the following argument appropriately
13 |     # --teacher_path: filename for the teacher model
14 |     # --base_path: path to find base dataset
15 |     # --dir: directory to save the student representation.
16 |     # E.g. the following command trains a STARTUP representation based on the teacher specified at
17 |     # ../teacher_miniImageNet/logs_deterministic/checkpoints/miniImageNet/ResNet10_baseline_256_aug/399.tar
18 |     # The student representation is saved at miniImageNet_source/$target_testset\_unlabeled_20/checkpoint_best.pkl
19 |     python STARTUP.py \
20 |         --dir miniImageNet_source/$target_testset\_unlabeled_20 \
21 |         --target_dataset $target_testset \
22 |         --image_size 224 \
23 |         --target_subset_split datasets/split_seed_1/$target_testset\_unlabeled_20.csv \
24 |         --bsize 256 \
25 |         --epochs 1000 \
26 |         --save_freq 50 \
27 |         --print_freq 10 \
28 |         --seed 1 \
29 |         --wd 1e-4 \
30 |         --num_workers 4 \
31 |         --model resnet10 \
32 |         --teacher_path ../teacher_miniImageNet/logs_deterministic/checkpoints/miniImageNet/ResNet10_baseline_256_aug/399.tar \
33 |         --teacher_path_version 0 \
34 |         --base_dataset miniImageNet \
35 |         --base_path /scratch/datasets/miniImageNet_full_resolution/train \
36 |         --base_no_color_jitter \
37 |         --base_val_ratio 0.05 \
38 |         --eval_freq 2 \
39 |         --batch_validate \
40 |         --resume_latest
41 | done
42 | 
43 | ##############################################################################################
44 | # Train student representation using ImageNet as the source
45 | ##############################################################################################
46 | # Before running the commands, please take care of the TODO appropriately
47 | 
48 | for target_testset in "ChestX" "ISIC" "EuroSAT" "CropDisease"
49 | do
50 |     # TODO: Please set the following argument appropriately
51 |     # --teacher_path: filename for the teacher model
52 |     # --base_path: path to find base dataset
53 |     # --dir: directory to save the student representation.
54 |     # E.g. the following command trains a STARTUP representation based on the teacher specified at
55 |     # ../teacher_ImageNet/resnet18/checkpoint.pkl
56 |     # The student representation is saved at ImageNet_source/$target_testset\_unlabeled_20/checkpoint_best.pkl
57 |     python STARTUP.py \
58 |         --dir ImageNet_source/$target_testset\_unlabeled_20 \
59 |         --target_dataset $target_testset \
60 |         --image_size 224 \
61 |         --target_subset_split datasets/split_seed_1/$target_testset\_unlabeled_20.csv \
62 |         --bsize 256 \
63 |         --epochs 1000 \
64 |         --save_freq 50 \
65 |         --print_freq 10 \
66 |         --seed 1 \
67 |         --wd 1e-4 \
68 |         --num_workers 4 \
69 |         --model resnet18 \
70 |         --teacher_path ../teacher_ImageNet/resnet18/checkpoint.pkl \
71 |         --teacher_path_version 1 \
72 |         --base_dataset ImageNet \
73 |         --base_path /scratch/datasets/imagenet/train \
74 |         --base_no_color_jitter \
75 |         --base_val_ratio 0.01 \
76 |         --eval_freq 2 \
77 |         --batch_validate \
78 |         --resume_latest
79 | done
80 | 
81 | 
82 | ##############################################################################################
83 | # Train student representation using tieredImageNet as the source
84 | ##############################################################################################
85 | # Before running the commands, please take care of the TODO appropriately
86 | 
87 | # source tieredImageNet
88 | for target_testset in "tiered_ImageNet_test"
89 | do
90 |     # TODO: Please set the following argument appropriately
91 |     # --teacher_path: filename for the teacher model
92 |     # --base_path: path to find base dataset
93 |     # --dir: directory to save the student representation.
94 |     # --target_subset_split Either datasets/split_seed_1/$target\_unlabeled_10.csv (for the less unlabeled data setup)
95 |     # or datasets/split_seed_1/$target\_unlabeled_50.csv (for the more unlabeled data setup)
96 |     # E.g. the following command trains a STARTUP representation based on the teacher specified at
97 |     # ../teacher_miniImageNet/logs_deterministic/checkpoints/tiered_ImageNet/ResNet12_baseline_256/89.tar
98 |     # The student representation is saved at tiered_ImageNet_source/$target_testset\_unlabeled_50/checkpoint_best.pkl
99 |     python STARTUP.py \
100 |         --dir tiered_ImageNet_source/$target_testset\_unlabeled_50 \
101 |         --target_dataset $target_testset \
102 |         --image_size 84 \
103 |         --target_subset_split datasets/split_seed_1/$target_testset\_unlabeled_50.csv \
104 |         --bsize 256 \
105 |         --epochs 100 \
106 |         --save_freq 50 \
107 |         --print_freq 10 \
108 |         --seed 1 \
109 |         --wd 1e-4 \
110 |         --num_workers 2 \
111 |         --model resnet12 \
112 |         --teacher_path ../teacher_miniImageNet/logs_deterministic/checkpoints/tiered_ImageNet/ResNet12_baseline_256/89.tar \
113 |         --teacher_path_version 0 \
114 |         --base_dataset tiered_ImageNet \
115 |         --base_path /scratch/datasets/tiered_imagenet/tiered_imagenet/original_split/train \
116 |         --base_no_color_jitter \
117 |         --base_val_ratio 0.05 \
118 |         --eval_freq 2 \
119 |         --batch_validate \
120 |         --resume_latest
121 | done
-------------------------------------------------------------------------------- /student_STARTUP/utils: --------------------------------------------------------------------------------
1 | ../utils
-------------------------------------------------------------------------------- /student_STARTUP_no_self_supervision/configs.py: --------------------------------------------------------------------------------
1 | 
2 | # TODO: Please set the directory to the target datasets accordingly
3 | miniImageNet_path = '/scratch/datasets/CD-FSL/miniImageNet_test'
4 | tiered_ImageNet_path = '/scratch/datasets/tiered_imagenet/tiered_imagenet/original_split/test'
5 | 
6 | ISIC_path = "/scratch/datasets/CD-FSL/ISIC"
7 | ChestX_path = "/scratch/datasets/CD-FSL/chestX"
8 | CropDisease_path = "/scratch/datasets/CD-FSL/CropDiseases"
9 | EuroSAT_path = "/scratch/datasets/CD-FSL/EuroSAT/2750"
-------------------------------------------------------------------------------- /student_STARTUP_no_self_supervision/data: --------------------------------------------------------------------------------
1 | ../data
-------------------------------------------------------------------------------- /student_STARTUP_no_self_supervision/datasets: --------------------------------------------------------------------------------
1 | ../datasets
-------------------------------------------------------------------------------- /student_STARTUP_no_self_supervision/methods: --------------------------------------------------------------------------------
1 | ../methods
-------------------------------------------------------------------------------- /student_STARTUP_no_self_supervision/models: --------------------------------------------------------------------------------
1 | ../models/
-------------------------------------------------------------------------------- /student_STARTUP_no_self_supervision/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # bash script to train STARTUP representation without SimCLR self-supervision
4 | export CUDA_VISIBLE_DEVICES=1
5 | 
6 | ##############################################################################################
7 | # Train student representation using MiniImageNet as the source
8 | ##############################################################################################
9 | # Before running the commands, please take care of the TODO appropriately
10 | for target_testset in "ChestX" "ISIC" "EuroSAT" "CropDisease" "miniImageNet_test"
11 | do
12 |     # TODO: Please set the following argument appropriately
13 |     # --teacher_path: filename for the teacher model
14 |     # --base_path: path to find base dataset
15 |     # --dir: directory to save the student representation.
16 |     # E.g. the following command trains a STARTUP representation based on the teacher specified at
17 |     # ../teacher_miniImageNet/logs_deterministic/checkpoints/miniImageNet/ResNet10_baseline_256_aug/399.tar
18 |     # The student representation is saved at miniImageNet_source/$target_testset\_unlabeled_20/checkpoint_best.pkl
19 |     python STARTUP_no_SS.py \
20 |         --dir miniImageNet_source/$target_testset\_unlabeled_20 \
21 |         --target_dataset $target_testset \
22 |         --image_size 224 \
23 |         --target_subset_split datasets/split_seed_1/$target_testset\_unlabeled_20.csv \
24 |         --bsize 256 \
25 |         --epochs 1000 \
26 |         --save_freq 50 \
27 |         --print_freq 10 \
28 |         --seed 1 \
29 |         --wd 1e-4 \
30 |         --num_workers 4 \
31 |         --model resnet10 \
32 |         --teacher_path ../teacher_miniImageNet/logs_deterministic/checkpoints/miniImageNet/ResNet10_baseline_256_aug/399.tar \
33 |         --teacher_path_version 0 \
34 |         --base_dataset miniImageNet \
35 |         --base_path /scratch/datasets/miniImageNet_full_resolution/train \
36 |         --base_no_color_jitter \
37 |         --base_val_ratio 0.05 \
38 |         --eval_freq 2 \
39 |         --resume_latest
40 | done
41 | 
42 | 
43 | ##############################################################################################
44 | # Train student representation using ImageNet as the source
45 | ##############################################################################################
46 | # Before running the commands, please take care of the TODO appropriately
47 | 
48 | for target_testset in "ChestX" "ISIC" "EuroSAT" "CropDisease"
49 | do
50 |     # TODO: Please set the following argument appropriately
51 |     # --teacher_path: filename for the teacher model
52 |     # --base_path: path to find base dataset
53 |     # --dir: directory to save the student representation.
54 |     # E.g. the following command trains a STARTUP representation based on the teacher specified at
55 |     # ../teacher_ImageNet/resnet18/checkpoint.pkl
56 |     # The student representation is saved at ImageNet_source/$target_testset\_unlabeled_20/checkpoint_best.pkl
57 |     python STARTUP_no_SS.py \
58 |         --dir ImageNet_source/$target_testset\_unlabeled_20 \
59 |         --target_dataset $target_testset \
60 |         --image_size 224 \
61 |         --target_subset_split datasets/split_seed_1/$target_testset\_unlabeled_20.csv \
62 |         --bsize 256 \
63 |         --epochs 1000 \
64 |         --save_freq 50 \
65 |         --print_freq 10 \
66 |         --seed 1 \
67 |         --wd 1e-4 \
68 |         --num_workers 4 \
69 |         --model resnet18 \
70 |         --teacher_path ../teacher_ImageNet/resnet18/checkpoint.pkl \
71 |         --teacher_path_version 1 \
72 |         --base_dataset ImageNet \
73 |         --base_path /scratch/datasets/imagenet/train \
74 |         --base_no_color_jitter \
75 |         --base_val_ratio 0.01 \
76 |         --eval_freq 2 \
77 |         --resume_latest
78 | done
79 | 
80 | ##############################################################################################
81 | # Train student representation using tieredImageNet as the source
82 | ##############################################################################################
83 | # Before running the commands, please take care of the TODO appropriately
84 | 
85 | for target_testset in "tiered_ImageNet_test"
86 | do
87 |     # TODO: Please set the following argument appropriately
88 |     # --teacher_path: filename for the teacher model
89 |     # --base_path: path to find base dataset
90 |     # --dir: directory to save the student representation.
91 |     # --target_subset_split Either datasets/split_seed_1/$target\_unlabeled_10.csv (for the less unlabeled data setup)
92 |     # or datasets/split_seed_1/$target\_unlabeled_50.csv (for the more unlabeled data setup)
93 |     # E.g. the following command trains a STARTUP representation based on the teacher specified at
94 |     # ../teacher_miniImageNet/logs_deterministic/checkpoints/tiered_ImageNet/ResNet12_baseline_256/89.tar
95 |     # The student representation is saved at tiered_ImageNet_source/$target_testset\_unlabeled_50/checkpoint_best.pkl
96 |     python STARTUP_no_SS.py \
97 |         --dir tiered_ImageNet_source/$target_testset\_unlabeled_50 \
98 |         --target_dataset $target_testset \
99 |         --image_size 84 \
100 |         --target_subset_split datasets/split_seed_1/$target_testset\_unlabeled_50.csv \
101 |         --bsize 256 \
102 |         --epochs 100 \
103 |         --save_freq 50 \
104 |         --print_freq 10 \
105 |         --seed 1 \
106 |         --wd 1e-4 \
107 |         --num_workers 2 \
108 |         --model resnet12 \
109 |         --teacher_path ../teacher_miniImageNet/logs_deterministic/checkpoints/tiered_ImageNet/ResNet12_baseline_256/89.tar \
110 |         --teacher_path_version 0 \
111 |         --base_dataset tiered_ImageNet \
112 |         --base_path /scratch/datasets/tiered_imagenet/tiered_imagenet/original_split/train \
113 |         --base_no_color_jitter \
114 |         --base_val_ratio 0.05 \
115 |         --eval_freq 2 \
116 |         --resume_latest
117 | done
-------------------------------------------------------------------------------- /student_STARTUP_no_self_supervision/utils: --------------------------------------------------------------------------------
1 | ../utils
-------------------------------------------------------------------------------- /teacher_ImageNet/convert_imagenet_weight.py: --------------------------------------------------------------------------------
1 | import torchvision
2 | import torchvision.models
3 | import argparse
4 | 
5 | import models
6 | 
7 | import torch.nn as nn
8 | 
9 | import os
10 | import torch
11 | 
12 | def main(args):
13 |     if args.model == 'resnet18':
14 |         backbone = models.resnet18(remove_last_relu=False, input_high_res=True).cuda()
15 |     else:
16 |         raise ValueError("Invalid backbone!")
17 | 
18 |     pretrained_model = getattr(
19 |         torchvision.models, args.model)(pretrained=True).cuda()
20 | 
21 |     # load the backbone parameters
22 |     for i in range(5):
23 |         layer_name = f"layer{i}"
24 | 
25 |         # the first layer requires special handling
26 |         # If the input is low resolution, then the first conv1 will have kernel size 3x3
27 |         # instead of 7x7
28 |         if layer_name == 'layer0':
29 |             mod = getattr(backbone, layer_name)
30 |             mod[0].load_state_dict(
31 |                 getattr(pretrained_model, 'conv1').state_dict())
32 |             mod[1].load_state_dict(
33 |                 getattr(pretrained_model, 'bn1').state_dict())
34 |         else:
35 |             getattr(backbone, layer_name).load_state_dict(
36 |                 getattr(pretrained_model, layer_name).state_dict())
37 | 
38 |     if not os.path.isdir(args.save_dir):
39 |         os.mkdir(args.save_dir)
40 | 
41 |     sd = {
42 |         'model': backbone.state_dict(),
43 |         'clf': pretrained_model.fc.state_dict()
44 |     }
45 | 
46 |     torch.save(sd, os.path.join(args.save_dir, 'checkpoint.pkl'))
47 | 
48 | 
49 | 
50 |     return
51 | 
52 | 
53 | if __name__ == "__main__":
54 |     parser = argparse.ArgumentParser(description="Convert the pretrained ImageNet ResNet weight")
55 |     parser.add_argument('--save_dir', type=str, default='.', help='Directory to save the pretrained weights')
56 |     parser.add_argument('--model', type=str, default='resnet18', help='which resnet model')
57 |     args = parser.parse_args()
58 |     main(args)
59 | 
-------------------------------------------------------------------------------- /teacher_ImageNet/models: --------------------------------------------------------------------------------
1 | ../models/
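# For reference, a minimal sketch of loading the converted checkpoint back,
# assuming convert_imagenet_weight.py above has been run with --save_dir
# resnet18 and this is executed from within teacher_ImageNet/ (where the
# models symlink points to ../models). The key layout matches the `sd` dict
# saved above.
import torch
import models

backbone = models.resnet18(remove_last_relu=False, input_high_res=True)
sd = torch.load('resnet18/checkpoint.pkl', map_location='cpu')
backbone.load_state_dict(sd['model'])  # converted ImageNet backbone
# sd['clf'] additionally holds the original 1000-way fc head.
backbone.eval()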
-------------------------------------------------------------------------------- /teacher_ImageNet/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # convert a pretrained imagenet model from PyTorch to a weight format that will be used
4 | # for other experiments.
5 | # TODO: Set --save_dir to specify the directory to save the converted model weights.
6 | python convert_imagenet_weight.py --save_dir resnet18 --model resnet18
-------------------------------------------------------------------------------- /teacher_miniImageNet/configs.py: --------------------------------------------------------------------------------
1 | # TODO: Set the directory to save the model
2 | save_dir = './logs_deterministic'
3 | 
4 | # TODO: Set the directory to the miniImageNet/tieredImageNet dataset
5 | miniImageNet_path = '/scratch/datasets/miniImageNet_full_resolution/train'
6 | tiered_ImageNet_path = '/scratch/datasets/tiered_imagenet/tiered_imagenet/original_split/train'
-------------------------------------------------------------------------------- /teacher_miniImageNet/data: --------------------------------------------------------------------------------
1 | ../data
-------------------------------------------------------------------------------- /teacher_miniImageNet/datasets: --------------------------------------------------------------------------------
1 | ../datasets
-------------------------------------------------------------------------------- /teacher_miniImageNet/io_utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import argparse
5 | import models
6 | 
7 | def create_resnet12():
8 |     m = models.Resnet12(width=1, dropout=0.1)
9 |     m.final_feat_dim = m.output_size
10 |     return m
11 | 
12 | def create_resnet18():
13 |     m = models.resnet18(remove_last_relu=False,
14 |                         input_high_res=True)
15 |     m.final_feat_dim = 512
16 |     return m
17 | 
18 | model_dict = {
19 |     'ResNet10': models.ResNet10,
20 |     'ResNet12': create_resnet12,
21 |     'ResNet18': create_resnet18,
22 | }
23 | 
24 | def parse_args(script):
25 |     parser = argparse.ArgumentParser(description= 'few-shot script %s' %(script))
26 |     parser.add_argument('--dataset' , default='miniImageNet', help='training base dataset')
27 |     parser.add_argument('--subset_split', help='split for dataset')
28 |     parser.add_argument('--model' , default='ResNet10', help='backbone architecture')
29 |     parser.add_argument('--method' , default='baseline', help='baseline/protonet')
30 |     parser.add_argument('--train_n_way' , default=5, type=int, help='class num to classify for training')
31 |     parser.add_argument('--test_n_way' , default=5, type=int, help='class num to classify for testing (validation)')
32 |     parser.add_argument('--n_shot' , default=5, type=int, help='number of labeled data in each class, same as n_support')
33 |     parser.add_argument('--train_aug' , action='store_true', help='perform data augmentation or not during training')
34 |     parser.add_argument('--freeze_backbone' , action='store_true', help='Freeze the backbone network for finetuning')
35 |     parser.add_argument('--seed', default=1, type=int, help='random seed')
36 |     parser.add_argument('--bsize', default=256, type=int, help='batchsize for supervised training')
37 | 
38 |     parser.add_argument('--models_to_use', '--names-list', nargs='+', default=['miniImageNet', 'caltech256', 'DTD', 'cifar100', 'CUB'], help='pretrained models to use')
39 | 
parser.add_argument('--fine_tune_all_models' , action='store_true', help='fine-tune each model before selection') #still required for save_features.py and test.py to find the model path correctly 40 | 41 | if script == 'train': 42 | parser.add_argument('--num_classes' , default=200, type=int, help='total number of classes in softmax, only used in baseline') #make it larger than the maximum label value in base class 43 | parser.add_argument('--save_freq' , default=50, type=int, help='Save frequency') 44 | parser.add_argument('--start_epoch' , default=0, type=int,help ='Starting epoch') 45 | parser.add_argument('--stop_epoch' , default=400, type=int, help ='Stopping epoch') # for meta-learning methods, each epoch contains 100 episodes 46 | 47 | elif script == 'save_features': 48 | parser.add_argument('--split' , default='novel', help='base/val/novel') #default novel, but you can also test base/val class accuracy if you want 49 | parser.add_argument('--save_iter', default=-1, type=int,help ='save feature from the model trained in x epoch, use the best model if x is -1') 50 | elif script == 'test': 51 | parser.add_argument('--split' , default='novel', help='base/val/novel') #default novel, but you can also test base/val class accuracy if you want 52 | parser.add_argument('--save_iter', default=-1, type=int,help ='saved feature from the model trained in x epoch, use the best model if x is -1') 53 | parser.add_argument('--adaptation' , action='store_true', help='further adaptation in test time or not') 54 | else: 55 | raise ValueError('Unknown script') 56 | 57 | return parser.parse_args() 58 | 59 | def get_assigned_file(checkpoint_dir,num): 60 | assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num)) 61 | return assign_file 62 | 63 | def get_resume_file(checkpoint_dir): 64 | filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar')) 65 | if len(filelist) == 0: 66 | return None 67 | 68 | filelist = [ x for x in filelist if os.path.basename(x) != 'best_model.tar' ] 69 | epochs = np.array([int(os.path.splitext(os.path.basename(x))[0]) for x in filelist]) 70 | max_epoch = np.max(epochs) 71 | resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(max_epoch)) 72 | return resume_file 73 | 74 | def get_best_file(checkpoint_dir): 75 | best_file = os.path.join(checkpoint_dir, 'best_model.tar') 76 | if os.path.isfile(best_file): 77 | return best_file 78 | else: 79 | return get_resume_file(checkpoint_dir) 80 | -------------------------------------------------------------------------------- /teacher_miniImageNet/methods: -------------------------------------------------------------------------------- 1 | ../methods -------------------------------------------------------------------------------- /teacher_miniImageNet/models: -------------------------------------------------------------------------------- 1 | ../models/ -------------------------------------------------------------------------------- /teacher_miniImageNet/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | # training a classification model on miniImageNet 5 | python train.py --dataset miniImageNet --model ResNet10 --method baseline --bsize 256 --start_epoch 0 --stop_epoch 400 --train_aug 6 | 7 | # training a classification model on tieredImageNet 8 | # python train.py --dataset tiered_ImageNet --model ResNet12 --method baseline --bsize 256 --start_epoch 0 --stop_epoch 90 
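A minimal sketch of how the checkpoint helpers in io_utils.py fit together when resuming. The directory below follows the pattern train.py (next file) builds from configs.save_dir with --train_aug enabled; the 'epoch' key matches the dict train.py saves into each '<epoch>.tar'.

import torch

from io_utils import get_resume_file

ckpt_dir = 'logs_deterministic/checkpoints/miniImageNet/ResNet10_baseline_256_aug'
resume_file = get_resume_file(ckpt_dir)  # latest '<epoch>.tar', or None if none exist
if resume_file is not None:
    state = torch.load(resume_file, map_location='cpu')
    print('resuming from epoch', state['epoch'])  # dict also holds 'state' and 'optimizer'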
-------------------------------------------------------------------------------- /teacher_miniImageNet/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim 5 | import torch.optim.lr_scheduler as lr_scheduler 6 | # import time 7 | import os 8 | # import glob 9 | 10 | import configs 11 | from data.datamgr import SimpleDataManager, SetDataManager 12 | from methods.baselinetrain import BaselineTrain 13 | from methods.protonet import ProtoNet 14 | 15 | from io_utils import model_dict, parse_args 16 | from datasets import miniImageNet_few_shot, tiered_ImageNet_few_shot, ImageNet_few_shot 17 | 18 | import utils 19 | import wandb 20 | 21 | from tqdm import tqdm 22 | import random 23 | 24 | 25 | def train(base_loader, model, optimization, start_epoch, stop_epoch, params, logger): 26 | if optimization == 'Adam': 27 | optimizer = torch.optim.Adam(model.parameters()) 28 | else: 29 | raise ValueError('Unknown optimization; please define it yourself') 30 | 31 | for epoch in tqdm(range(start_epoch, stop_epoch)): 32 | model.train() 33 | perf = model.train_loop(epoch, base_loader, optimizer, logger) 34 | 35 | if not os.path.isdir(params.checkpoint_dir): 36 | os.makedirs(params.checkpoint_dir) 37 | 38 | if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1): 39 | outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch)) 40 | torch.save({'epoch':epoch, 'state':model.state_dict(), 41 | 'optimizer': optimizer.state_dict()}, outfile) 42 | 43 | wandb.log({'loss': perf['Loss/avg']}, step=epoch+1) 44 | wandb.log({'top1': perf['top1/avg'], 45 | 'top5': perf['top5/avg'], 46 | 'top1_per_class': perf['top1_per_class/avg'], 47 | 'top5_per_class': perf['top5_per_class/avg']}, step=epoch+1) 48 | 49 | return model 50 | 51 | if __name__ == '__main__': 52 | params = parse_args('train') 53 | image_size = 224 54 | bsize = params.bsize 55 | optimization = 'Adam' 56 | 57 | torch.backends.cudnn.deterministic = True 58 | torch.backends.cudnn.benchmark = False 59 | np.random.seed(params.seed) 60 | torch.random.manual_seed(params.seed) 61 | torch.cuda.manual_seed(params.seed) 62 | random.seed(params.seed) 63 | 64 | save_dir = configs.save_dir 65 | params.checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s' % ( 66 | save_dir, params.dataset, params.model, params.method, bsize) 67 | if params.train_aug: 68 | params.checkpoint_dir += '_aug' 69 | 70 | if params.method not in ['baseline', 'baseline++']: 71 | params.checkpoint_dir += '_%dway_%dshot' % ( 72 | params.train_n_way, params.n_shot) 73 | 74 | if not os.path.isdir(params.checkpoint_dir): 75 | os.makedirs(params.checkpoint_dir) 76 | 77 | logger = utils.create_logger(os.path.join(params.checkpoint_dir, 'checkpoint.log'), __name__) 78 | 79 | if params.method in ['baseline']: 80 | 81 | if params.dataset == "miniImageNet": 82 | # Original batch size was 16 83 | datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size=bsize, split=params.subset_split) 84 | base_loader = datamgr.get_data_loader(aug=params.train_aug, num_workers=8) 85 | params.num_classes = 64 86 | elif params.dataset == 'tiered_ImageNet': 87 | image_size = 84 88 | # Do no augmentation for tiered ImageNet to be consistent with the literature 89 | datamgr = tiered_ImageNet_few_shot.SimpleDataManager( 90 | image_size, batch_size=bsize, split=params.subset_split) 91 | base_loader = datamgr.get_data_loader( 92 | aug=False, num_workers=8) 93 | print("Number of images",
len(base_loader.dataset)) 94 | params.num_classes = 351 95 | elif params.dataset == 'ImageNet': 96 | datamgr = ImageNet_few_shot.SimpleDataManager( 97 | image_size, batch_size=bsize, split=params.subset_split) 98 | base_loader = datamgr.get_data_loader( 99 | aug=params.train_aug, num_workers=8) 100 | print("Number of images", len(base_loader.dataset)) 101 | params.num_classes = 1000 102 | else: 103 | raise ValueError('Unknown dataset') 104 | 105 | model = BaselineTrain(model_dict[params.model], params.num_classes) 106 | 107 | elif params.method in ['protonet']: 108 | n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small 109 | train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot) 110 | test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot) 111 | 112 | if params.dataset == "miniImageNet": 113 | 114 | datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_query = n_query, **train_few_shot_params) 115 | base_loader = datamgr.get_data_loader(aug = params.train_aug) 116 | 117 | else: 118 | raise ValueError('Unknown dataset') 119 | 120 | if params.method == 'protonet': 121 | model = ProtoNet( model_dict[params.model], **train_few_shot_params ) 122 | 123 | else: 124 | raise ValueError('Unknown method') 125 | 126 | for arg in vars(params): 127 | logger.info(f"{arg}: {getattr(params, arg)}") 128 | 129 | logger.info(f"Image_size: {image_size}") 130 | logger.info(f"Optimization: {optimization}") 131 | 132 | wandb.init(project='cross_task_distillation', 133 | group=__file__, 134 | name=f'{__file__}_{params.checkpoint_dir}') 135 | 136 | wandb.config.update(params) 137 | 138 | model = model.cuda() 139 | 140 | start_epoch = params.start_epoch 141 | stop_epoch = params.stop_epoch 142 | 143 | model = train(base_loader, model, optimization, start_epoch, stop_epoch, params, logger=logger) 144 | -------------------------------------------------------------------------------- /teacher_miniImageNet/utils: -------------------------------------------------------------------------------- 1 | ../utils -------------------------------------------------------------------------------- /utils/AverageMeterSet.py: -------------------------------------------------------------------------------- 1 | # Specify classes or functions that will be exported 2 | __all__ = ['AverageMeter', 'AverageMeterSet'] 3 | 4 | class AverageMeterSet: 5 | def __init__(self): 6 | self.meters = {} 7 | 8 | def __getitem__(self, key): 9 | return self.meters[key] 10 | 11 | def update(self, name, value, n=1): 12 | if not name in self.meters: 13 | self.meters[name] = AverageMeter() 14 | self.meters[name].update(value, n) 15 | 16 | def reset(self): 17 | for meter in self.meters.values(): 18 | meter.reset() 19 | 20 | def values(self, postfix=''): 21 | return {name + postfix: meter.val for name, meter in self.meters.items()} 22 | 23 | def averages(self, postfix='/avg'): 24 | return {name + postfix: meter.avg for name, meter in self.meters.items()} 25 | 26 | def sums(self, postfix='/sum'): 27 | return {name + postfix: meter.sum for name, meter in self.meters.items()} 28 | 29 | def counts(self, postfix='/count'): 30 | return {name + postfix: meter.count for name, meter in self.meters.items()} 31 | 32 | class AverageMeter: 33 | """Computes and stores the average and current value""" 34 | 35 | def __init__(self): 36 | self.reset() 37 | 38 | def reset(self): 39 | self.val = 0 40 | self.avg = 0 41 | 
self.sum = 0 42 | self.count = 0 43 | 44 | def update(self, val, n=1): 45 | ''' 46 | val is the average value 47 | n : the number of items used to calculate the average 48 | ''' 49 | self.val = val 50 | self.sum += val * n 51 | self.count += n 52 | self.avg = self.sum / self.count 53 | 54 | def __format__(self, format): 55 | return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # this adds the __all__ of each module below to the utils namespace 2 | from .AverageMeterSet import * 3 | from .create_logger import * 4 | from .savelog import * 5 | from .count_paramters import * 6 | from .accuracy import * 7 | from .average_model import * 8 | from .cdfsl_utils import * -------------------------------------------------------------------------------- /utils/accuracy.py: -------------------------------------------------------------------------------- 1 | # Compute accuracy 2 | import torch 3 | 4 | def accuracy(logits, ground_truth, topk=[1, ]): 5 | assert len(logits) == len(ground_truth) 6 | # this function calculates per-class accuracy, 7 | # average per-class accuracy, and overall accuracy 8 | 9 | n, d = logits.shape 10 | 11 | label_unique = torch.unique(ground_truth) 12 | acc = {} 13 | acc['average'] = torch.zeros(len(topk)) 14 | acc['per_class_average'] = torch.zeros(len(topk)) 15 | acc['per_class'] = [[] for _ in label_unique] 16 | acc['gt_unique'] = label_unique 17 | acc['topk'] = topk 18 | acc['num_classes'] = d 19 | 20 | max_k = max(topk) 21 | argsort = torch.argsort(logits, dim=1, descending=True)[:, :min([max_k, d])] 22 | correct = (argsort == ground_truth.view(-1, 1)).float() 23 | 24 | for indi, i in enumerate(label_unique): 25 | ind = torch.nonzero(ground_truth == i, as_tuple=False).view(-1) 26 | correct_target = correct[ind] 27 | 28 | # calculate topk 29 | for indj, j in enumerate(topk): 30 | num_correct_partial = torch.sum(correct_target[:, :j]).item() 31 | acc_partial = num_correct_partial / len(correct_target) 32 | acc['average'][indj] += num_correct_partial 33 | acc['per_class_average'][indj] += acc_partial 34 | acc['per_class'][indi].append(acc_partial * 100) 35 | 36 | acc['average'] = acc['average'] / n * 100 37 | acc['per_class_average'] = acc['per_class_average'] / len(label_unique) * 100 38 | 39 | return acc 40 | -------------------------------------------------------------------------------- /utils/average_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import copy 5 | import warnings 6 | 7 | 8 | class running_ensemble(nn.Module): 9 | def __init__(self, model): 10 | super(running_ensemble, self).__init__() 11 | self.model = copy.deepcopy(model) 12 | self.model.eval() 13 | 14 | for p in self.model.parameters(): 15 | p.requires_grad_(False) 16 | 17 | self.register_buffer('num_models', torch.zeros(1)) 18 | self.bn_updated = False 19 | return 20 | 21 | def update(self, model): 22 | alpha = 1 / (self.num_models + 1) 23 | for p1, p2 in zip(self.model.parameters(), model.parameters()): 24 | p1.data *= (1 - alpha) 25 | p1.data += p2.data * alpha 26 | 27 | self.num_models += 1 28 | self.bn_updated = False # averaging the parameters invalidates the BN running statistics 29 | 30 | @staticmethod 31 | def _reset_bn(module): 32 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 33 | module.running_mean = torch.zeros_like(module.running_mean) 34 |
module.running_var = torch.ones_like(module.running_var) 35 | 36 | @staticmethod 37 | def _get_momenta(module, momenta): 38 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 39 | momenta[module] = module.momentum 40 | 41 | 42 | @staticmethod 43 | def _set_momenta(module, momenta): 44 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 45 | module.momentum = momenta[module] 46 | 47 | def update_bn(self, loader): 48 | self.model.train() 49 | self.model.apply(running_ensemble._reset_bn) 50 | is_cuda = next(self.model.parameters()).is_cuda 51 | 52 | momenta = {} 53 | self.model.apply(lambda module: running_ensemble._get_momenta(module, momenta)) 54 | n = 0 55 | for X, _ in loader: 56 | if is_cuda: 57 | X = X.cuda() 58 | 59 | b = len(X) 60 | momentum = b / (n + b) 61 | 62 | for module in momenta.keys(): 63 | module.momentum = momentum 64 | 65 | self.model(X) 66 | 67 | n += b 68 | 69 | self.model.apply(lambda module: running_ensemble._set_momenta(module, momenta)) 70 | self.model.eval() 71 | self.bn_updated = True 72 | return 73 | 74 | def forward(self, x): 75 | if not self.bn_updated: 76 | warnings.warn('Running mean and variance of BatchNorm are not updated! Use with care!') 77 | return self.model(x) 78 | -------------------------------------------------------------------------------- /utils/cdfsl_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def adjust_learning_rate(optimizer, epoch, lr=0.01, step1=30, step2=60, step3=90): 5 | """Sets the learning rate to the initial LR decayed by 10 at each of step1, step2, and step3""" 6 | if epoch >= step3: 7 | lr = lr * 0.001 8 | elif epoch >= step2: 9 | lr = lr * 0.01 10 | elif epoch >= step1: 11 | lr = lr * 0.1 12 | else: 13 | lr = lr 14 | for param_group in optimizer.param_groups: 15 | param_group['lr'] = lr 16 | 17 | def one_hot(y, num_class): 18 | return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1) 19 | 20 | def sparsity(cl_data_file): 21 | class_list = cl_data_file.keys() 22 | cl_sparsity = [] 23 | for cl in class_list: 24 | cl_sparsity.append(np.mean([np.sum(x!=0) for x in cl_data_file[cl] ]) ) 25 | 26 | return np.mean(cl_sparsity) -------------------------------------------------------------------------------- /utils/count_paramters.py: -------------------------------------------------------------------------------- 1 | # ported from https://github.com/CuriousAI/mean-teacher/blob/master/pytorch/mean_teacher/utils.py 2 | __all__ = ['parameter_count'] 3 | 4 | def parameter_count(module, verbose=False): 5 | params = list(module.named_parameters()) 6 | total_count = sum(int(param.numel()) for name, param in params) 7 | if verbose: 8 | lines = [ 9 | "", 10 | "List of model parameters:", 11 | "=========================", 12 | ] 13 | 14 | row_format = "{name:<40} {shape:>20} ={total_size:>12,d}" 15 | 16 | for name, param in params: 17 | lines.append(row_format.format( 18 | name=name, 19 | shape=" * ".join(str(p) for p in param.size()), 20 | total_size=param.numel() 21 | )) 22 | lines.append("=" * 75) 23 | lines.append(row_format.format( 24 | name="all parameters", 25 | shape="sum of above", 26 | total_size=total_count 27 | )) 28 | lines.append("") 29 | print("\n".join(lines)) 30 | return total_count -------------------------------------------------------------------------------- /utils/create_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 |
__all__ = ['create_logger'] 4 | 5 | def create_logger(fname, logger_name): 6 | # Get a logger with name logger_name 7 | logger = logging.getLogger(logger_name) 8 | 9 | # File handler for log 10 | hdlr = logging.FileHandler(fname) 11 | # Format of the logging information 12 | formatter = logging.Formatter('%(levelname)s %(message)s') 13 | 14 | hdlr.setFormatter(formatter) 15 | logger.addHandler(hdlr) 16 | 17 | # Set the level to logging.INFO, meaning any message 18 | # logged at level INFO or above will be recorded 19 | logger.setLevel(logging.INFO) 20 | 21 | return logger -------------------------------------------------------------------------------- /utils/savelog.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | import os 4 | 5 | import pandas as pd 6 | 7 | __all__ = ['savelog'] 8 | 9 | class savelog: 10 | ''' Saves training log to csv''' 11 | INCREMENTAL_UPDATE_TIME = 0 12 | 13 | def __init__(self, directory, name): 14 | self.file_path = os.path.join(directory, "{}_{:%Y-%m-%d_%H:%M:%S}.csv".format(name, datetime.datetime.now())) 15 | self.data = {} 16 | self.last_update_time = time.time() - self.INCREMENTAL_UPDATE_TIME 17 | 18 | def record(self, step, value_dict): 19 | self.data[step] = value_dict 20 | if time.time() - self.last_update_time >= self.INCREMENTAL_UPDATE_TIME: 21 | self.last_update_time = time.time() 22 | self.save() 23 | 24 | def save(self): 25 | pd.DataFrame.from_dict(self.data, orient='index').to_csv(self.file_path)  # to_csv writes the file and returns None --------------------------------------------------------------------------------
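A minimal sketch of how the utilities above compose into a training-style logging loop; the tensors and the 'demo' log name are illustrative only.

import torch

from utils import AverageMeterSet, accuracy, savelog

meters = AverageMeterSet()
log = savelog('.', 'demo')  # creates ./demo_<timestamp>.csv

logits = torch.randn(8, 5)            # a batch of 8 examples over 5 classes
labels = torch.randint(0, 5, (8,))
acc = accuracy(logits, labels, topk=[1, 5])

meters.update('top1', acc['average'][0].item(), n=len(labels))
log.record(0, meters.averages())      # writes {'top1/avg': ...} for step 0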