├── .gitignore
├── README.md
├── dataloader.py
├── dataset
    ├── cifar.py
    ├── dataset_selfsupervision.py
    ├── mini_imagenet.py
    ├── tiered_imagenet.py
    └── transform_cfg.py
├── distill
    ├── NCEAverage.py
    ├── NCECriterion.py
    ├── __init__.py
    ├── alias_multinomial.py
    ├── criterion.py
    └── util.py
├── eval
    ├── __init__.py
    ├── cls_eval.py
    ├── meta_eval.py
    └── util.py
├── eval_fewshot.py
├── models
    ├── __init__.py
    ├── convnet.py
    ├── resnet.py
    ├── resnet_new.py
    ├── resnet_sd.py
    ├── resnet_selfdist.py
    ├── resnet_ssl.py
    ├── util.py
    └── wresnet.py
├── requirements.txt
├── run.sh
├── train_distillation.py
├── train_selfsupervison.py
├── util.py
└── utils
    └── figs
        ├── main.png
        ├── results1.png
        ├── results2.png
        └── training.png

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | output*/
4 | ckpts/
5 | *.pth
6 | *.t7
7 | *.jpg
8 | tmp*.py
9 | # run*.sh
10 | *.pdf
11 | *.npy
12 | models_distilled/
13 | tensorboard/
14 | models_pretrained/
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 |
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 |
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 |
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SKD: Self-supervised Knowledge Distillation for Few-shot Learning
2 | Official implementation of "SKD: Self-supervised Knowledge Distillation for Few-shot Learning" [(paper link)](https://arxiv.org/abs/2006.09785). The paper reports state-of-the-art results on four popular few-shot learning benchmarks.
3 |
4 | The real world contains an overwhelmingly large number of object classes, and learning all of them at once is impossible. Few-shot learning is a promising paradigm because it can quickly learn out-of-distribution data from only a few samples. Recent works show that simply learning a good feature embedding can outperform more sophisticated meta-learning and metric-learning algorithms. In this paper, we propose a simple approach to improve the representation capacity of deep neural networks for few-shot learning tasks. We follow a two-stage learning process: first, we train a neural network to maximize the entropy of the feature embedding, using self-supervision as an auxiliary loss, thus creating an optimal output manifold. In the second stage, we minimize the entropy of the feature embedding by bringing self-supervised twins together, while constraining the manifold with student-teacher distillation. Our experiments show that, even in the first stage, auxiliary self-supervision can outperform current state-of-the-art methods, with further gains achieved by our second-stage distillation process.
5 |
6 | This repository provides the official PyTorch implementation of SKD, including code for running the few-shot learning experiments on **CIFAR-FS**, **FC-100**, **miniImageNet** and **tieredImageNet**.
7 |
8 |
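The two stages described in the abstract can be illustrated with a small, dependency-free sketch of the per-sample objectives. This is not the repository's training code (`train_selfsupervison.py` and `train_distillation.py` are not shown here): treating `--gamma` as the weight of the auxiliary rotation loss, the temperature `T`, the Gen-1 weight `beta`, and the squared-error distance between twins are all illustrative assumptions.

```python
import math


def softmax(logits, T=1.0):
    # Numerically stable softmax with temperature T
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    # Negative log-probability assigned to the target index
    return -math.log(softmax(logits)[target])


def kl_div(p, q):
    # KL(p || q) for two discrete distributions over the same classes
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def gen0_loss(cls_logits, cls_target, rot_logits, rot_target, gamma=1.0):
    # Gen-0: classification loss plus a weighted auxiliary rotation-prediction loss.
    # Using `gamma` as this weight is an assumption, not confirmed by the source.
    return cross_entropy(cls_logits, cls_target) + gamma * cross_entropy(rot_logits, rot_target)


def gen1_loss(student_logits, teacher_logits, student_twin_logits, T=4.0, beta=1.0):
    # Gen-1: keep the student's output on the original sample close to the
    # frozen Gen-0 teacher (soft-label distillation, scaled by T^2) ...
    distill = kl_div(softmax(teacher_logits, T), softmax(student_logits, T)) * T * T
    # ... while pulling the output for the augmented (rotated) twin towards it.
    twin = sum((a - b) ** 2 for a, b in zip(student_logits, student_twin_logits)) / len(student_logits)
    return distill + beta * twin
```

With `gamma=0` the Gen-0 objective reduces to plain cross-entropy; the `T * T` factor is the usual scaling applied when distilling temperature-softened distributions.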

9 | (a) SKD has a two-stage learning process. In Gen-0, self-supervision is used to estimate the true prediction manifold, equivariant to input transformations. Specifically, the model is trained to predict the amount of input rotation using only the output logits. In Gen-1, we force the outputs for the original samples to match those of Gen-0 (dotted lines), while reducing their distance to their augmented versions to enhance discriminability.
10 |
11 |
12 | (b) SKD training pipeline.
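The rotation pretext task from caption (a) can be illustrated without PyTorch on a plain 2-D grid. The hypothetical helpers below follow the same recipe as `tensor_rot_90` in `dataset/dataset_selfsupervision.py` (flip the width axis, then transpose), and `make_rotation_sample` returns the same `(original, rotated), (target, rotation_label)` layout that `SSDatasetWrapper` produces when `ssl_rot` is enabled, minus the sample index.

```python
import random


def rot90(grid):
    # Rotate a 2-D list 90 degrees counter-clockwise:
    # reverse each row (flip width), then transpose, as tensor_rot_90 does
    return [list(row) for row in zip(*[r[::-1] for r in grid])]


def rotate(grid, k):
    # Apply k quarter-turns; k in {0, 1, 2, 3} is the self-supervised label
    for _ in range(k % 4):
        grid = rot90(grid)
    return [list(r) for r in grid]


def make_rotation_sample(grid, target, rng=random):
    # Draw a rotation label uniformly, as np.random.randint(4) does in the wrapper
    label = rng.randrange(4)
    return (grid, rotate(grid, label)), (target, label)


print(rot90([[1, 2], [3, 4]]))  # [[2, 4], [1, 3]]
```

Because two and three quarter-turns equal `tensor_rot_180` and `tensor_rot_270`, a single `rot90` primitive is enough to generate all four rotation classes.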

13 |
14 | ## Dependencies
15 | This code requires the following:
16 | * matplotlib==3.2.1
17 | * mkl==2019.0
18 | * numpy==1.18.4
19 | * Pillow==7.1.2
20 | * scikit_learn==0.23.1
21 | * scipy==1.4.1
22 | * torch==1.5.0
23 | * torchvision==0.6.0
24 | * tqdm==4.46.0
25 | * wandb==0.8.36
26 |
27 | Run `pip3 install -r requirements.txt` to install all the dependencies.
28 |
29 | ## Download Data
30 | The data used here was preprocessed by the [MetaOptNet](https://github.com/kjunelee/MetaOptNet) repository. Renamed versions of the files, provided by [RFS](https://github.com/WangYueFt/rfs), are available at the link below.
31 |
32 | [[DropBox Data Packages Link]](https://www.dropbox.com/sh/6yd1ygtyc3yd981/AABVeEqzC08YQv4UZk7lNHvya?dl=0)
33 |
34 | ## Training
35 |
36 | ### Generation Zero
37 | To perform the Generation Zero experiment, run:
38 |
39 | `python3 train_selfsupervison.py --tags cifarfs,may30 --model resnet12_ssl --model_path save/backup --dataset CIFAR-FS --data_root ../../Datasets/CIFAR_FS/ --n_aug_support_samples 5 --n_ways 5 --n_shots 1 --epochs 65 --lr_decay_epochs 60 --gamma 2.0`
40 |
41 | Weights & Biases (`wandb`) creates a unique name for each run and names the saved model accordingly (e.g. `model_firm-sun-1.pth`). Use this name for the teacher (`--path_t`) in the next experiment.
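The teacher checkpoint in the Generation One command appears to follow the pattern `<model_path>/<model>_<dataset>_lr_<lr>_decay_<decay>_trans_<transform>_trial_<trial>/model_<run-name>.pth`; the evaluation command uses the same shape. This pattern is inferred from the example paths in this README rather than from the training script, so treat the hypothetical `teacher_path` helper below as a convenience sketch and check the directory your run actually writes.

```python
import posixpath


def teacher_path(model_path, model, dataset, run_name,
                 lr=0.05, decay=0.0005, transform="D", trial=1):
    # Hypothetical helper: the directory naming scheme is an assumption
    # reconstructed from the example paths, not read from the training code.
    run_dir = f"{model}_{dataset}_lr_{lr}_decay_{decay}_trans_{transform}_trial_{trial}"
    return posixpath.join(model_path, run_dir, f"model_{run_name}.pth")


print(teacher_path("save/backup", "resnet12_ssl", "CIFAR-FS", "firm-sun-1"))
# save/backup/resnet12_ssl_CIFAR-FS_lr_0.05_decay_0.0005_trans_D_trial_1/model_firm-sun-1.pth
```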
42 |
43 |
44 | ### Generation One
45 | To perform the Generation One experiment, run:
46 |
47 | `python3 train_distillation.py --tags cifarfs,gen1,may30 --model_s resnet12_ssl --model_t resnet12_ssl --path_t save/backup/resnet12_ssl_CIFAR-FS_lr_0.05_decay_0.0005_trans_D_trial_1/model_firm-sun-1.pth --model_path save/backup --dataset CIFAR-FS --data_root ../../Datasets/CIFAR_FS/ --n_aug_support_samples 5 --n_ways 5 --n_shots 1 --epochs 65 --lr_decay_epochs 60 --gamma 0.1`
48 |
49 |
50 | ### Evaluation
51 |
52 | `python3 eval_fewshot.py --model resnet12_ssl --model_path save/backup2/resnet12_ssl_toy_lr_0.05_decay_0.0005_trans_A_trial_1/model_firm-sun-1.pth --dataset toy --data_root ../../Datasets/CIFAR_FS/ --n_aug_support_samples 5 --n_ways 5 --n_shots 1`
53 |
54 |
55 | ## Results
56 |
57 | We perform extensive experiments on four datasets in a few-shot learning setting, leading to significant improvements over state-of-the-art methods.
58 |
59 |
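The few-shot setting is evaluated on episodes: `--n_ways` classes, `--n_shots` labelled support samples per class, and a set of query samples per class. The sketch below shows episode construction in plain Python, following the same idea as `MetaCIFAR100.__getitem__`, plus a nearest-centroid classifier as a simple stand-in for the evaluation classifier. `eval/meta_eval.py` is not shown here and may use a different classifier; both function names are hypothetical, and the `n_queries=12` default is taken from the `__main__` block in `dataset/cifar.py`.

```python
import random


def sample_episode(data, n_ways=5, n_shots=1, n_queries=12, seed=None):
    """Sample one N-way K-shot episode from a {class_id: [samples]} dict.

    Picks n_ways classes, splits each into disjoint support/query samples,
    and relabels the chosen classes 0..n_ways-1. Images and transforms are omitted.
    """
    rng = random.Random(seed)  # plays the role of np.random.seed(item) with fix_seed=True
    classes = rng.sample(sorted(data), n_ways)
    support, query = [], []
    for new_label, cls in enumerate(classes):
        # Sampling without replacement keeps support and query disjoint
        picked = rng.sample(range(len(data[cls])), n_shots + n_queries)
        support += [(data[cls][i], new_label) for i in picked[:n_shots]]
        query += [(data[cls][i], new_label) for i in picked[n_shots:]]
    return support, query


def nearest_centroid_predict(support_feats, support_labels, query_feats):
    # Illustrative classifier only: assign each query to the closest class mean
    sums, counts = {}, {}
    for feat, label in zip(support_feats, support_labels):
        acc = sums.setdefault(label, [0.0] * len(feat))
        for i, v in enumerate(feat):
            acc[i] += v
        counts[label] = counts.get(label, 0) + 1
    centroids = {y: [v / counts[y] for v in acc] for y, acc in sums.items()}

    def sq_dist(a, b):
        return sum((x - z) ** 2 for x, z in zip(a, b))

    return [min(centroids, key=lambda y: sq_dist(f, centroids[y])) for f in query_feats]
```

In a full pipeline the support and query samples would first be passed through the trained backbone to obtain feature vectors before classification.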

60 | (c) SKD performance on miniImageNet and tieredImageNet.
61 |
62 |
63 |
64 | (d) SKD performance on CIFAR-FS and FC100 datasets.

65 |
66 |
67 | ## Acknowledgements
68 | Thanks to https://github.com/WangYueFt/rfs for the preliminary implementations.
69 |
70 | ## Contact
71 | Jathushan Rajasegaran - jathushan.rajasegaran@inceptioniai.org or brjathu@gmail.com
72 |
73 | To ask questions or report issues, please open an issue on the [issues tracker](https://github.com/brjathu/SKD/issues).
74 |
75 | Discussions, suggestions and questions are welcome!
76 |
77 |
78 | ## Citation
79 | ```
80 | @article{rajasegaran2020self,
81 |   title={Self-supervised Knowledge Distillation for Few-shot Learning},
82 |   author={Rajasegaran, Jathushan and Khan, Salman and Hayat, Munawar and Khan, Fahad Shahbaz and Shah, Mubarak},
83 |   journal={https://arxiv.org/abs/2006.09785},
84 |   year = {2020}
85 | }
86 | ```
87 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import socket
6 | import time
7 | import sys
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.optim as optim
12 | import torch.nn as nn
13 | import torch.backends.cudnn as cudnn
14 | from torch.utils.data import DataLoader
15 | import torch.nn.functional as F
16 |
17 | from dataset.mini_imagenet import ImageNet, MetaImageNet
18 | from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet
19 | from dataset.cifar import CIFAR100, MetaCIFAR100, CIFAR100_toy
20 | from dataset.transform_cfg import transforms_options, transforms_test_options, transforms_list
21 |
22 | from dataset.dataset_selfsupervision import SSDatasetWrapper
23 |
24 | import numpy as np
25 |
26 |
27 | def get_dataloaders(opt):
28 |     # dataloader
29 |     train_partition = 'trainval' if opt.use_trainval else 'train'
30 |
31 |     if opt.dataset == 'toy':
32 |
33 |         train_trans, test_trans = transforms_options['D']
34 |
35 |         train_loader = DataLoader(CIFAR100_toy(args=opt, partition=train_partition, transform=train_trans),
36 |                                   batch_size=opt.batch_size, shuffle=True, drop_last=True,
37 |                                   num_workers=opt.num_workers)
38 |         val_loader = DataLoader(CIFAR100_toy(args=opt, partition='train', transform=test_trans),
39 |                                 batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
40 |                                 num_workers=opt.num_workers // 2)
41 |
42 |         # train_trans, test_trans = transforms_test_options[opt.transform]
43 |
44 |         # meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
45 |         #                                           train_transform=train_trans,
46 |         #                                           test_transform=test_trans),
47 |         #                              batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
48 |         #                              num_workers=opt.num_workers)
49 |         # meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
50 |         #                                          train_transform=train_trans,
51 |         #                                          test_transform=test_trans),
52 |         #                             batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
53 |         #                             num_workers=opt.num_workers)
54 |         n_cls = 5
55 |
56 |         return train_loader, val_loader, 5, 5, n_cls
57 |
58 |
59 |     if opt.dataset == 'miniImageNet':
60 |
61 |         train_trans, test_trans = transforms_options[opt.transform]
62 |         train_loader = DataLoader(ImageNet(args=opt, partition=train_partition, transform=train_trans),
63 |                                   batch_size=opt.batch_size, shuffle=True, drop_last=True,
64 |                                   num_workers=opt.num_workers)
65 |         val_loader = DataLoader(ImageNet(args=opt, partition='val', transform=test_trans),
66 |                                 batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
67 |                                 num_workers=opt.num_workers // 2)
68 |
69 |         train_trans, test_trans = transforms_test_options[opt.transform]
70 |         meta_testloader = DataLoader(MetaImageNet(args=opt, partition='test',
71 |                                                   train_transform=train_trans,
72 |                                                   test_transform=test_trans),
73 |                                      batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
74 |                                      num_workers=opt.num_workers)
75 |         meta_valloader = DataLoader(MetaImageNet(args=opt, partition='val',
76 |                                                  train_transform=train_trans,
77 |                                                  test_transform=test_trans),
78 |                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
79 |                                     num_workers=opt.num_workers)
80 |
81 |         if opt.use_trainval:
82 |             n_cls = 80
83 |         else:
84 |             n_cls = 64
85 |     elif opt.dataset == 'tieredImageNet':
86 |         train_trans, test_trans = transforms_options[opt.transform]
87 |         train_loader = DataLoader(TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
88 |                                   batch_size=opt.batch_size, shuffle=True, drop_last=True,
89 |                                   num_workers=opt.num_workers)
90 |         val_loader = DataLoader(TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
91 |                                 batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
92 |                                 num_workers=opt.num_workers // 2)
93 |
94 |         train_trans, test_trans = transforms_test_options[opt.transform]
95 |         meta_testloader = DataLoader(MetaTieredImageNet(args=opt, partition='test',
96 |                                                         train_transform=train_trans,
97 |                                                         test_transform=test_trans),
98 |                                      batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
99 |                                      num_workers=opt.num_workers)
100 |         meta_valloader = DataLoader(MetaTieredImageNet(args=opt, partition='val',
101 |                                                        train_transform=train_trans,
102 |                                                        test_transform=test_trans),
103 |                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
104 |                                     num_workers=opt.num_workers)
105 |         if opt.use_trainval:
106 |             n_cls = 448
107 |         else:
108 |             n_cls = 351
109 |     elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
110 |         train_trans, test_trans = transforms_options['D']
111 |
112 |         train_loader = DataLoader(CIFAR100(args=opt, partition=train_partition, transform=train_trans),
113 |                                   batch_size=opt.batch_size, shuffle=True, drop_last=True,
114 |                                   num_workers=opt.num_workers)
115 |         val_loader = DataLoader(CIFAR100(args=opt, partition='train', transform=test_trans),
116 |                                 batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
117 |                                 num_workers=opt.num_workers // 2)
118 |
119 |         train_trans, test_trans = transforms_test_options[opt.transform]
120 |
121 |
122 |         # ns = [opt.n_shots].copy()
123 |         # opt.n_ways = 32
124 |         # opt.n_shots = 5
125 |         # opt.n_aug_support_samples = 2
126 |         meta_trainloader = DataLoader(MetaCIFAR100(args=opt, partition='train',
127 |                                                    train_transform=train_trans,
128 |                                                    test_transform=test_trans),
129 |                                       batch_size=1, shuffle=True, drop_last=False,
130 |                                       num_workers=opt.num_workers)
131 |
132 |         # opt.n_ways = 5
133 |         # opt.n_shots = ns[0]
134 |         # print(opt.n_shots)
135 |         # opt.n_aug_support_samples = 5
136 |         meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
137 |                                                   train_transform=train_trans,
138 |                                                   test_transform=test_trans),
139 |                                      batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
140 |                                      num_workers=opt.num_workers)
141 |         meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
142 |                                                  train_transform=train_trans,
143 |                                                  test_transform=test_trans),
144 |                                     batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
145 |                                     num_workers=opt.num_workers)
146 |         if opt.use_trainval:
147 |             n_cls = 80
148 |         else:
149 |             if opt.dataset == 'CIFAR-FS':
150 |                 n_cls = 64
151 |             elif opt.dataset == 'FC100':
152 |                 n_cls = 60
153 |             else:
154 |                 raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
155 |         # return train_loader, val_loader, meta_trainloader, meta_testloader, meta_valloader, n_cls
156 |     else:
157 |         raise NotImplementedError(opt.dataset)
158 |
159 |     return train_loader, val_loader, meta_testloader, meta_valloader, n_cls
--------------------------------------------------------------------------------
/dataset/cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import pickle
5 | from PIL import Image
6 | import numpy as np
7 |
8 | import torch
9 | import torchvision.transforms as transforms
10 | from torch.utils.data import Dataset
11 |
12 |
13 | class CIFAR100(Dataset):
14 |     """support FC100 and CIFAR-FS"""
15 |     def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
16 |                  transform=None):
17 |         super(Dataset, self).__init__()
18 |         self.data_root = args.data_root
19 |         self.partition = partition
20 |         self.data_aug = args.data_aug
21 |         self.mean = [0.5071, 0.4867, 0.4408]
22 |         self.std = [0.2675, 0.2565, 0.2761]
23 |         self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
24 |         self.pretrain = pretrain
25 |         self.simclr = args.simclr
26 |
27 |
28 |         if transform is None:
29 |             if self.partition == 'train' and self.data_aug:
30 |                 self.transform = transforms.Compose([
31 |                     lambda x: Image.fromarray(x),
32 |                     transforms.RandomCrop(32, padding=4),
33 |                     transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
34 |                     transforms.RandomHorizontalFlip(),
35 |                     lambda x: np.asarray(x),
36 |                     transforms.ToTensor(),
37 |                     self.normalize
38 |                 ])
39 |             else:
40 |                 self.transform = transforms.Compose([
41 |                     lambda x: Image.fromarray(x),
42 |                     transforms.ToTensor(),
43 |                     self.normalize
44 |                 ])
45 |         else:
46 |             self.transform = transform
47 |
48 |         if self.pretrain:
49 |             self.file_pattern = '%s.pickle'
50 |         else:
51 |             self.file_pattern = '%s.pickle'
52 |         self.data = {}
53 |
54 |         with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
55 |             data = pickle.load(f, encoding='latin1')
56 |             self.imgs = data['data']
57 |             labels = data['labels']
58 |             # adjust sparse labels to labels from 0 to n.
59 |             cur_class = 0
60 |             label2label = {}
61 |             for idx, label in enumerate(labels):
62 |                 if label not in label2label:
63 |                     label2label[label] = cur_class
64 |                     cur_class += 1
65 |             new_labels = []
66 |             for idx, label in enumerate(labels):
67 |                 new_labels.append(label2label[label])
68 |             self.labels = new_labels
69 |
70 |
71 |         # pre-process for contrastive sampling
72 |         self.k = k
73 |         self.is_sample = is_sample
74 |         if self.is_sample:
75 |             self.labels = np.asarray(self.labels)
76 |             self.labels = self.labels - np.min(self.labels)
77 |             num_classes = np.max(self.labels) + 1
78 |
79 |             self.cls_positive = [[] for _ in range(num_classes)]
80 |             for i in range(len(self.imgs)):
81 |                 self.cls_positive[self.labels[i]].append(i)
82 |
83 |             self.cls_negative = [[] for _ in range(num_classes)]
84 |             for i in range(num_classes):
85 |                 for j in range(num_classes):
86 |                     if j == i:
87 |                         continue
88 |                     self.cls_negative[i].extend(self.cls_positive[j])
89 |
90 |             self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
91 |             self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
92 |             self.cls_positive = np.asarray(self.cls_positive)
93 |             self.cls_negative = np.asarray(self.cls_negative)
94 |
95 |     def __getitem__(self, item):
96 |         img = np.asarray(self.imgs[item]).astype('uint8')
97 |         target = self.labels[item] - min(self.labels)
98 |
99 |         if(self.simclr):
100 |             img1 = self.transform(img)
101 |             img2 = self.transform(img)
102 |             return (img1, img2), target, item
103 |
104 |         img = self.transform(img)
105 |         if not self.is_sample:
106 |             return img, target, item
107 |         else:
108 |             pos_idx = item
109 |             replace = True if self.k > len(self.cls_negative[target]) else False
110 |             neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
111 |             sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
112 |             return img, target, item, sample_idx
113 |
114 |     def __len__(self):
115 |         return len(self.labels)
116 |
117 |
118 |
119 |
120 | class CIFAR100_toy(Dataset):
121 |     """support FC100 and CIFAR-FS"""
122 |     def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
123 |                  transform=None):
124 |         super(Dataset, self).__init__()
125 |         self.data_root = args.data_root
126 |         self.partition = partition
127 |         self.data_aug = args.data_aug
128 |         self.mean = [0.5071, 0.4867, 0.4408]
129 |         self.std = [0.2675, 0.2565, 0.2761]
130 |         self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
131 |         self.pretrain = pretrain
132 |         self.simclr = args.simclr
133 |
134 |
135 |         if transform is None:
136 |             if self.partition == 'train' and self.data_aug:
137 |                 self.transform = transforms.Compose([
138 |                     lambda x: Image.fromarray(x),
139 |                     transforms.RandomCrop(32, padding=4),
140 |                     transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
141 |                     transforms.RandomHorizontalFlip(),
142 |                     lambda x: np.asarray(x),
143 |                     transforms.ToTensor(),
144 |                     self.normalize
145 |                 ])
146 |             else:
147 |                 self.transform = transforms.Compose([
148 |                     lambda x: Image.fromarray(x),
149 |                     transforms.ToTensor(),
150 |                     self.normalize
151 |                 ])
152 |         else:
153 |             self.transform = transform
154 |
155 |         if self.pretrain:
156 |             self.file_pattern = '%s.pickle'
157 |         else:
158 |             self.file_pattern = '%s.pickle'
159 |         self.data = {}
160 |
161 |         with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
162 |             data = pickle.load(f, encoding='latin1')
163 |             self.imgs = data['data']
164 |             labels = data['labels']
165 |             # adjust sparse labels to labels from 0 to n.
166 |             cur_class = 0
167 |             label2label = {}
168 |             for idx, label in enumerate(labels):
169 |                 if label not in label2label:
170 |                     label2label[label] = cur_class
171 |                     cur_class += 1
172 |             new_labels = []
173 |             for idx, label in enumerate(labels):
174 |                 new_labels.append(label2label[label])
175 |             self.labels = new_labels
176 |
177 |         self.labels = np.array(self.labels)
178 |         self.imgs = np.array(self.imgs)
179 |         print(self.labels.shape)
180 |         print(self.imgs.shape)
181 |
182 |         loc = np.where(self.labels<5)[0]
183 |         self.labels = self.labels[loc]
184 |         self.imgs = self.imgs[loc]
185 |
186 |
187 |         self.k = k
188 |         self.is_sample = is_sample
189 |
190 |     def __getitem__(self, item):
191 |         img = np.asarray(self.imgs[item]).astype('uint8')
192 |         target = self.labels[item] - min(self.labels)
193 |
194 |         if(self.simclr):
195 |             img1 = self.transform(img)
196 |             img2 = self.transform(img)
197 |             return (img1, img2), target, item
198 |
199 |         img = self.transform(img)
200 |         if not self.is_sample:
201 |             return img, target, item
202 |         else:
203 |             pos_idx = item
204 |             replace = True if self.k > len(self.cls_negative[target]) else False
205 |             neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
206 |             sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
207 |             return img, target, item, sample_idx
208 |
209 |     def __len__(self):
210 |         return len(self.labels)
211 |
212 |
213 |
214 |
215 |
216 | class MetaCIFAR100(CIFAR100):
217 |
218 |     def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True):
219 |         super(MetaCIFAR100, self).__init__(args, partition, False)
220 |         self.fix_seed = fix_seed
221 |         self.n_ways = args.n_ways
222 |         self.n_shots = args.n_shots
223 |         self.n_queries = args.n_queries
224 |         self.classes = list(self.data.keys())
225 |         self.n_test_runs = args.n_test_runs
226 |         self.n_aug_support_samples = args.n_aug_support_samples
227 |         if train_transform is None:
228 |             self.train_transform = transforms.Compose([
229 |                 lambda x: Image.fromarray(x),
230 |                 transforms.RandomCrop(32, padding=4),
231 |                 transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
232 |                 transforms.RandomHorizontalFlip(),
233 |                 lambda x: np.asarray(x),
234 |                 transforms.ToTensor(),
235 |                 self.normalize
236 |             ])
237 |         else:
238 |             self.train_transform = train_transform
239 |
240 |         if test_transform is None:
241 |             self.test_transform = transforms.Compose([
242 |                 lambda x: Image.fromarray(x),
243 |                 transforms.ToTensor(),
244 |                 self.normalize
245 |             ])
246 |         else:
247 |             self.test_transform = test_transform
248 |
249 |         self.data = {}
250 |         for idx in range(self.imgs.shape[0]):
251 |             if self.labels[idx] not in self.data:
252 |                 self.data[self.labels[idx]] = []
253 |             self.data[self.labels[idx]].append(self.imgs[idx])
254 |         self.classes = list(self.data.keys())
255 |
256 |     def __getitem__(self, item):
257 |         if self.fix_seed:
258 |             np.random.seed(item)
259 |         cls_sampled = np.random.choice(self.classes, self.n_ways, False)
260 |
261 |         support_xs = []
262 |         support_ys = []
263 |         support_ts = []
264 |         query_xs = []
265 |         query_ys = []
266 |         query_ts = []
267 |         for idx, cls in enumerate(cls_sampled):
268 |             imgs = np.asarray(self.data[cls]).astype('uint8')
269 |             support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False)
270 |             support_xs.append(imgs[support_xs_ids_sampled])
271 |             support_ys.append([idx] * self.n_shots)
272 |             support_ts.append([cls] * self.n_shots)
273 |             query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled)
274 |             query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False)
275 |             query_xs.append(imgs[query_xs_ids])
276 |             query_ys.append([idx] * query_xs_ids.shape[0])
277 |             query_ts.append([cls] * query_xs_ids.shape[0])
278 |         support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array(
279 |             query_xs), np.array(query_ys)
280 |         support_ts, query_ts = np.array(support_ts), np.array(query_ts)
281 |         num_ways, n_queries_per_way, height, width, channel = query_xs.shape
282 |         query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel))
283 |         query_ys = query_ys.reshape((num_ways * n_queries_per_way,))
284 |         query_ts = query_ts.reshape((num_ways * n_queries_per_way,))
285 |
286 |         support_xs = support_xs.reshape((-1, height, width, channel))
287 |         if self.n_aug_support_samples > 1:
288 |             support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1))
289 |             support_ys = np.tile(support_ys.reshape((-1,)), (self.n_aug_support_samples))
290 |             support_ts = np.tile(support_ts.reshape((-1,)), (self.n_aug_support_samples))
291 |         support_xs = np.split(support_xs, support_xs.shape[0], axis=0)
292 |
293 |
294 |
295 |         query_xs = query_xs.reshape((-1, height, width, channel))
296 |         if self.n_aug_support_samples > 1:
297 |             query_xs = np.tile(query_xs, (self.n_aug_support_samples, 1, 1, 1))
298 |             query_ys = np.tile(query_ys.reshape((-1,)), (self.n_aug_support_samples))
299 |             query_ts = np.tile(query_ts.reshape((-1,)), (self.n_aug_support_samples))
300 |         query_xs = np.split(query_xs, query_xs.shape[0], axis=0)
301 |
302 |         support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs)))
303 |         query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs)))
304 |
305 |         return support_xs, support_ys, query_xs, query_ys
306 |
307 |     def __len__(self):
308 |         return self.n_test_runs
309 |
310 |
311 | if __name__ == '__main__':
312 |     args = lambda x: None
313 |     args.n_ways = 5
314 |     args.n_shots = 1
315 |     args.n_queries = 12
316 |     # args.data_root = 'data'
317 |     args.data_root = '/home/yonglong/Downloads/FC100'
318 |     args.data_aug = True
319 |     args.n_test_runs = 5
320 |     args.n_aug_support_samples = 1
321 |     imagenet = CIFAR100(args, 'train')
322 |     print(len(imagenet))
323 |     print(imagenet.__getitem__(500)[0].shape)
324 |
325 |     metaimagenet = MetaCIFAR100(args, 'train')
326 |     print(len(metaimagenet))
327 |     print(metaimagenet.__getitem__(500)[0].size())
328 |     print(metaimagenet.__getitem__(500)[1].shape)
329 |     print(metaimagenet.__getitem__(500)[2].size())
330 |     print(metaimagenet.__getitem__(500)[3].shape)
--------------------------------------------------------------------------------
/dataset/dataset_selfsupervision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import numpy as np
4 |
5 |
6 | class SSDatasetWrapper(torch.utils.data.Dataset):
7 |     def __init__(self, dset, opt):
8 |         self.dset = dset
9 |         self.opt = opt
10 |
11 |     def __getitem__(self, index):
12 |         image, target, item = self.dset[index]
13 |
14 |         if(not(self.opt.ssl)):
15 |             return image, target, item
16 |         else:
17 |             if(self.opt.ssl_rot):
18 |                 label = np.random.randint(4)
19 |                 if label == 1:
20 |                     image_rot = tensor_rot_90(image)
21 |                 elif label == 2:
22 |                     image_rot = tensor_rot_180(image)
23 |                 elif label == 3:
24 |                     image_rot = tensor_rot_270(image)
25 |                 else:
26 |                     image_rot = image
27 |
28 |                 return (image, image_rot), (target, label), item
29 |
30 |             if(self.opt.ssl_quad):
31 |                 label = np.random.randint(4)
32 |
33 |                 horstr = image.size(1) // 2
34 |                 verstr = image.size(2) // 2
35 |                 horlab = label // 2
36 |                 verlab = label % 2
37 |
38 |                 image_quad = image[:, horlab*horstr:(horlab+1)*horstr, verlab*verstr:(verlab+1)*verstr,]
39 |                 return (image, image_quad), (target, label), item
40 |
41 |     def __len__(self):
42 |         return len(self.dset)
43 |
44 | # Assumes that tensor is (nchannels, height, width)
45 | def tensor_rot_90(x):
46 |     return x.flip(2).transpose(1,2)
47 | def tensor_rot_90_digit(x):
48 |     return x.transpose(1,2)
49 |
50 | def tensor_rot_180(x):
51 |     return x.flip(2).flip(1)
52 | def tensor_rot_180_digit(x):
53 |     return x.flip(2)
54 |
55 | def tensor_rot_270(x):
56 |     return x.transpose(1,2).flip(2)
--------------------------------------------------------------------------------
/dataset/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class ImageNet(Dataset):
11 |     def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
12 |                  transform=None):
13 |         super(Dataset, self).__init__()
14 |         self.data_root = args.data_root
15 |         self.partition = partition
16 |         self.data_aug = args.data_aug
17 |         self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
18 |         self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
19 |         self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
20 |         self.pretrain = pretrain
21 |
22 |         if transform is None:
23 |             if self.partition == 'train' and self.data_aug:
24 |                 self.transform = transforms.Compose([
25 |                     lambda x: Image.fromarray(x),
26 |                     transforms.RandomCrop(84, padding=8),
27 |                     transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
28 |                     transforms.RandomHorizontalFlip(),
29 |                     lambda x: np.asarray(x),
30 |                     transforms.ToTensor(),
31 |                     self.normalize
32 |                 ])
33 |             else:
34 |                 self.transform = transforms.Compose([
35 |                     lambda x: Image.fromarray(x),
36 |                     transforms.ToTensor(),
37 |                     self.normalize
38 |                 ])
39 |         else:
40 |             self.transform = transform
41 |
42 |         if self.pretrain:
43 |             self.file_pattern = 'miniImageNet_category_split_train_phase_%s.pickle'
44 |         else:
45 |             self.file_pattern = 'miniImageNet_category_split_%s.pickle'
46 |         self.data = {}
47 |         with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
48 |             data = pickle.load(f, encoding='latin1')
49 |             self.imgs = data['data']
50 |             self.labels = data['labels']
51 |
52 |
53 |
54 |
55 |
56 |         # pre-process for contrastive sampling
57 |         self.k = k
58 |         self.is_sample = is_sample
59 |         if self.is_sample:
60 |             self.labels = np.asarray(self.labels)
61 |             self.labels = self.labels - np.min(self.labels)
62 |             num_classes = np.max(self.labels) + 1
63 |
64 |             self.cls_positive = [[] for _ in range(num_classes)]
65 |             for i in range(len(self.imgs)):
66 |                 self.cls_positive[self.labels[i]].append(i)
67 |
68 |             self.cls_negative = [[] for _ in range(num_classes)]
69 |             for i in range(num_classes):
70 |                 for j in range(num_classes):
71 |                     if j == i:
72 |                         continue
73 |                     self.cls_negative[i].extend(self.cls_positive[j])
74 |
75 |             self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
76 |             self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
77 |             self.cls_positive = np.asarray(self.cls_positive)
78 |             self.cls_negative = np.asarray(self.cls_negative)
79 |
80 |     def __getitem__(self, item):
81 |         img = np.asarray(self.imgs[item]).astype('uint8')
82 |         img = self.transform(img)
83 |         target = self.labels[item] - min(self.labels)
84 |
85 |         if not self.is_sample:
86 |             return img, target, item
87 |         else:
88 |             pos_idx = item
89 |             replace = True if self.k > len(self.cls_negative[target]) else False
90 |             neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
91 |             sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
92 |             return img, target, item, sample_idx
93 |
94 |     def __len__(self):
95 |         return len(self.labels)
96 |
97 |
98 | class MetaImageNet(ImageNet):
99 |
100 |     def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True):
101 |         super(MetaImageNet, self).__init__(args, partition, False)
102 |         self.fix_seed = fix_seed
103 |         self.n_ways = args.n_ways
104 |         self.n_shots = args.n_shots
105 |         self.n_queries = args.n_queries
106 |         self.classes = list(self.data.keys())
107 |         self.n_test_runs = args.n_test_runs
108 |         self.n_aug_support_samples = args.n_aug_support_samples
109 |         if train_transform is None:
110 |             self.train_transform = transforms.Compose([
111 |                 lambda x: Image.fromarray(x),
112 |                 transforms.RandomCrop(84, padding=8),
113 |                 transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
114 |                 transforms.RandomHorizontalFlip(),
115 |                 lambda x: np.asarray(x),
116 |                 transforms.ToTensor(),
117 |                 self.normalize
118 |             ])
119 |         else:
120 |             self.train_transform = train_transform
121 |
122 |         if test_transform is None:
123 |             self.test_transform = transforms.Compose([
124 |                 lambda x: Image.fromarray(x),
125 |                 transforms.ToTensor(),
126 |                 self.normalize
127 |             ])
128 |         else:
129 |             self.test_transform = test_transform
130 |
131 |         self.data = {}
132 |         for idx in range(self.imgs.shape[0]):
133 |             if self.labels[idx] not in self.data:
134 |                 self.data[self.labels[idx]] = []
135 |             self.data[self.labels[idx]].append(self.imgs[idx])
136 |         self.classes = list(self.data.keys())
137 |
138 |     def __getitem__(self, item):
139 |         if self.fix_seed:
140 |             np.random.seed(item)
141 |         cls_sampled = np.random.choice(self.classes, self.n_ways, False)
142 |         support_xs = []
143 |         support_ys = []
144 |         query_xs = []
145 |         query_ys = []
146 |         for idx, cls in enumerate(cls_sampled):
147 |             imgs = np.asarray(self.data[cls]).astype('uint8')
148 |             support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False)
149 |             support_xs.append(imgs[support_xs_ids_sampled])
150 |             support_ys.append([idx] * self.n_shots)
151 |             query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled)
152 |             query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False)
153 |             query_xs.append(imgs[query_xs_ids])
154 |             query_ys.append([idx] * query_xs_ids.shape[0])
155 |         support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array(
156 |             query_xs), np.array(query_ys)
157 |         num_ways, n_queries_per_way, height, width, channel = query_xs.shape
158 |         query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel))
159 |         query_ys = query_ys.reshape((num_ways * n_queries_per_way, ))
160 |
161 |         support_xs = support_xs.reshape((-1, height, width, channel))
162 |         if self.n_aug_support_samples > 1:
163 |             support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1))
164 |             support_ys = np.tile(support_ys.reshape((-1, )), (self.n_aug_support_samples))
165 |         support_xs = np.split(support_xs, support_xs.shape[0], axis=0)
166 |         query_xs = query_xs.reshape((-1, height, width, channel))
167 |         query_xs = np.split(query_xs, query_xs.shape[0], axis=0)
168 |
169 |         support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs)))
170 |         query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs)))
171 |
172 |         return support_xs, support_ys, query_xs, query_ys
173 |
174 |     def __len__(self):
175 |         return self.n_test_runs
176 |
177 |
178 | if __name__ == '__main__':
179 |     args = lambda x: None
180 |     args.n_ways = 5
181 |     args.n_shots = 1
182 |     args.n_queries = 12
183 |     args.data_root = 'data'
184 |     args.data_aug = True
185 |     args.n_test_runs = 5
186 |     args.n_aug_support_samples = 1
187 |     imagenet = ImageNet(args, 'val')
188 |     print(len(imagenet))
189 |     print(imagenet.__getitem__(500)[0].shape)
190 |
191 |     metaimagenet = MetaImageNet(args)
192 |     print(len(metaimagenet))
193 |     print(metaimagenet.__getitem__(500)[0].size())
194 |     print(metaimagenet.__getitem__(500)[1].shape)
195 |     print(metaimagenet.__getitem__(500)[2].size())
196 |     print(metaimagenet.__getitem__(500)[3].shape)
197 |
--------------------------------------------------------------------------------
/dataset/tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class TieredImageNet(Dataset):
11 |     def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
12 |                  transform=None):
13 |         super(Dataset,
self).__init__() 14 | self.data_root = args.data_root 15 | self.partition = partition 16 | self.data_aug = args.data_aug 17 | self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 18 | self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0] 19 | 20 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std) 21 | self.pretrain = pretrain 22 | 23 | if transform is None: 24 | if self.partition == 'train' and self.data_aug: 25 | self.transform = transforms.Compose([ 26 | lambda x: Image.fromarray(x), 27 | transforms.RandomCrop(84, padding=8), 28 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 29 | transforms.RandomHorizontalFlip(), 30 | lambda x: np.asarray(x), 31 | transforms.ToTensor(), 32 | self.normalize 33 | ]) 34 | else: 35 | self.transform = transforms.Compose([ 36 | lambda x: Image.fromarray(x), 37 | transforms.ToTensor(), 38 | self.normalize 39 | ]) 40 | else: 41 | self.transform = transform 42 | 43 | if self.pretrain: 44 | self.image_file_pattern = '%s_images.npz' 45 | self.label_file_pattern = '%s_labels.pkl' 46 | else: 47 | self.image_file_pattern = '%s_images.npz' 48 | self.label_file_pattern = '%s_labels.pkl' 49 | 50 | self.data = {} 51 | 52 | # modified code to load tieredImageNet 53 | image_file = os.path.join(self.data_root, self.image_file_pattern % partition) 54 | self.imgs = np.load(image_file)['images'] 55 | label_file = os.path.join(self.data_root, self.label_file_pattern % partition) 56 | self.labels = self._load_labels(label_file)['labels'] 57 | 58 | # pre-process for contrastive sampling 59 | self.k = k 60 | self.is_sample = is_sample 61 | if self.is_sample: 62 | self.labels = np.asarray(self.labels) 63 | self.labels = self.labels - np.min(self.labels) 64 | num_classes = np.max(self.labels) + 1 65 | 66 | self.cls_positive = [[] for _ in range(num_classes)] 67 | for i in range(len(self.imgs)): 68 | self.cls_positive[self.labels[i]].append(i) 69 | 70 | self.cls_negative 
= [[] for _ in range(num_classes)] 71 | for i in range(num_classes): 72 | for j in range(num_classes): 73 | if j == i: 74 | continue 75 | self.cls_negative[i].extend(self.cls_positive[j]) 76 | 77 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 78 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 79 | # keep these as plain lists of per-class index arrays: tieredImageNet classes differ in size, 80 | # and np.asarray() on such ragged lists raises a ValueError on NumPy >= 1.24 81 | 82 | def __getitem__(self, item): 83 | img = np.asarray(self.imgs[item]).astype('uint8') 84 | img = self.transform(img) 85 | target = self.labels[item] - min(self.labels) 86 | 87 | if not self.is_sample: 88 | return img, target, item 89 | else: 90 | pos_idx = item 91 | replace = True if self.k > len(self.cls_negative[target]) else False 92 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 93 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 94 | return img, target, item, sample_idx 95 | 96 | def __len__(self): 97 | return len(self.labels) 98 | 99 | @staticmethod 100 | def _load_labels(file): 101 | try: 102 | with open(file, 'rb') as fo: 103 | data = pickle.load(fo) 104 | return data 105 | except Exception: 106 | with open(file, 'rb') as f: 107 | u = pickle._Unpickler(f) 108 | u.encoding = 'latin1' 109 | data = u.load() 110 | return data 111 | 112 | 113 | class MetaTieredImageNet(TieredImageNet): 114 | 115 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True): 116 | super(MetaTieredImageNet, self).__init__(args, partition, False) 117 | self.fix_seed = fix_seed 118 | self.n_ways = args.n_ways 119 | self.n_shots = args.n_shots 120 | self.n_queries = args.n_queries 121 | self.classes = list(self.data.keys()) 122 | self.n_test_runs = args.n_test_runs 123 | self.n_aug_support_samples = args.n_aug_support_samples 124 | if train_transform is None: 125 | 
self.train_transform = transforms.Compose([ 126 | lambda x: 
Image.fromarray(x), 127 | transforms.RandomCrop(84, padding=8), 128 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 129 | transforms.RandomHorizontalFlip(), 130 | lambda x: np.asarray(x), 131 | transforms.ToTensor(), 132 | self.normalize 133 | ]) 134 | else: 135 | self.train_transform = train_transform 136 | 137 | if test_transform is None: 138 | self.test_transform = transforms.Compose([ 139 | lambda x: Image.fromarray(x), 140 | transforms.ToTensor(), 141 | self.normalize 142 | ]) 143 | else: 144 | self.test_transform = test_transform 145 | 146 | self.data = {} 147 | for idx in range(self.imgs.shape[0]): 148 | if self.labels[idx] not in self.data: 149 | self.data[self.labels[idx]] = [] 150 | self.data[self.labels[idx]].append(self.imgs[idx]) 151 | self.classes = list(self.data.keys()) 152 | 153 | def __getitem__(self, item): 154 | if self.fix_seed: 155 | np.random.seed(item) 156 | cls_sampled = np.random.choice(self.classes, self.n_ways, False) 157 | support_xs = [] 158 | support_ys = [] 159 | query_xs = [] 160 | query_ys = [] 161 | for idx, cls in enumerate(cls_sampled): 162 | imgs = np.asarray(self.data[cls]).astype('uint8') 163 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False) 164 | support_xs.append(imgs[support_xs_ids_sampled]) 165 | support_ys.append([idx] * self.n_shots) 166 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled) 167 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False) 168 | query_xs.append(imgs[query_xs_ids]) 169 | query_ys.append([idx] * query_xs_ids.shape[0]) 170 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array( 171 | query_xs), np.array(query_ys) 172 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape 173 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel)) 174 | query_ys = query_ys.reshape((num_ways * n_queries_per_way,)) 175 | 176 | 
support_xs = support_xs.reshape((-1, height, width, channel)) 177 | if self.n_aug_support_samples > 1: 178 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1)) 179 | support_ys = np.tile(support_ys.reshape((-1,)), (self.n_aug_support_samples)) 180 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0) 181 | query_xs = query_xs.reshape((-1, height, width, channel)) 182 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0) 183 | 184 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs))) 185 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs))) 186 | 187 | return support_xs, support_ys, query_xs, query_ys 188 | 189 | def __len__(self): 190 | return self.n_test_runs 191 | 192 | 193 | if __name__ == '__main__': 194 | args = lambda x: None 195 | args.n_ways = 5 196 | args.n_shots = 1 197 | args.n_queries = 12 198 | # args.data_root = 'data' 199 | args.data_root = '/home/yonglong/Data/tiered-imagenet-kwon' 200 | args.data_aug = True 201 | args.n_test_runs = 5 202 | args.n_aug_support_samples = 1 203 | imagenet = TieredImageNet(args, 'train') 204 | print(len(imagenet)) 205 | print(imagenet.__getitem__(500)[0].shape) 206 | 207 | metaimagenet = MetaTieredImageNet(args) 208 | print(len(metaimagenet)) 209 | print(metaimagenet.__getitem__(500)[0].size()) 210 | print(metaimagenet.__getitem__(500)[1].shape) 211 | print(metaimagenet.__getitem__(500)[2].size()) 212 | print(metaimagenet.__getitem__(500)[3].shape) 213 | -------------------------------------------------------------------------------- /dataset/transform_cfg.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import torchvision.transforms as transforms 6 | 7 | 8 | mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0] 9 | std = [70.68188272 / 255.0, 68.27635443 
/ 255.0, 72.54505529 / 255.0] 10 | normalize = transforms.Normalize(mean=mean, std=std) 11 | 12 | 13 | transform_A = [ 14 | transforms.Compose([ 15 | lambda x: Image.fromarray(x), 16 | transforms.RandomCrop(84, padding=8), 17 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 18 | transforms.RandomHorizontalFlip(), 19 | lambda x: np.asarray(x), 20 | transforms.ToTensor(), 21 | normalize 22 | ]), 23 | 24 | transforms.Compose([ 25 | lambda x: Image.fromarray(x), 26 | transforms.ToTensor(), 27 | normalize 28 | ]) 29 | ] 30 | 31 | transform_A_test = [ 32 | transforms.Compose([ 33 | lambda x: Image.fromarray(x), 34 | transforms.RandomCrop(84, padding=8), 35 | transforms.RandomHorizontalFlip(), 36 | lambda x: np.asarray(x), 37 | transforms.ToTensor(), 38 | normalize 39 | ]), 40 | 41 | transforms.Compose([ 42 | lambda x: Image.fromarray(x), 43 | transforms.ToTensor(), 44 | normalize 45 | ]) 46 | ] 47 | 48 | 49 | 50 | 51 | 52 | 53 | # CIFAR style transformation 54 | mean = [0.5071, 0.4867, 0.4408] 55 | std = [0.2675, 0.2565, 0.2761] 56 | normalize_cifar100 = transforms.Normalize(mean=mean, std=std) 57 | transform_D = [ 58 | transforms.Compose([ 59 | lambda x: Image.fromarray(x), 60 | transforms.RandomCrop(32, padding=4), 61 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 62 | # transforms.RandomRotation(10), 63 | transforms.RandomHorizontalFlip(), 64 | lambda x: np.asarray(x), 65 | transforms.ToTensor(), 66 | normalize_cifar100 67 | ]), 68 | 69 | transforms.Compose([ 70 | lambda x: Image.fromarray(x), 71 | transforms.ToTensor(), 72 | normalize_cifar100 73 | ]) 74 | ] 75 | 76 | transform_D_test = [ 77 | transforms.Compose([ 78 | lambda x: Image.fromarray(x), 79 | transforms.RandomCrop(32, padding=4), 80 | transforms.RandomHorizontalFlip(), 81 | lambda x: np.asarray(x), 82 | transforms.ToTensor(), 83 | normalize_cifar100 84 | ]), 85 | 86 | transforms.Compose([ 87 | lambda x: Image.fromarray(x), 88 | transforms.ToTensor(), 89 | 
normalize_cifar100 90 | ]) 91 | ] 92 | 93 | 94 | transforms_list = ['A', 'D'] 95 | 96 | 97 | transforms_options = { 98 | 'A': transform_A, 99 | 'D': transform_D, 100 | } 101 | 102 | transforms_test_options = { 103 | 'A': transform_A_test, 104 | 'D': transform_D_test, 105 | } 106 | -------------------------------------------------------------------------------- /distill/NCEAverage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch import nn 4 | from .alias_multinomial import AliasMethod 5 | import math 6 | 7 | 8 | class NCESoftmax(nn.Module): 9 | 10 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5): 11 | super(NCESoftmax, self).__init__() 12 | self.nLem = outputSize 13 | self.unigrams = torch.ones(self.nLem) 14 | self.multinomial = AliasMethod(self.unigrams) 15 | self.multinomial.cuda() 16 | self.K = K 17 | 18 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum])) 19 | stdv = 1. 
/ math.sqrt(inputSize / 3) 20 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 21 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 22 | 23 | def forward(self, l, ab, y, idx=None): 24 | K = int(self.params[0].item()) 25 | T = self.params[1].item() 26 | Z_l = self.params[2].item() 27 | Z_ab = self.params[3].item() 28 | 29 | momentum = self.params[4].item() 30 | batchSize = l.size(0) 31 | outputSize = self.memory_l.size(0) 32 | inputSize = self.memory_l.size(1) 33 | 34 | # original score computation 35 | if idx is None: 36 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 37 | idx.select(1, 0).copy_(y.data) 38 | # sample 39 | weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach() 40 | weight_l = weight_l.view(batchSize, K + 1, inputSize) 41 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1)) 42 | # out_ab = torch.exp(torch.div(out_ab, T)) 43 | out_ab = torch.div(out_ab, T) 44 | # sample 45 | weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach() 46 | weight_ab = weight_ab.view(batchSize, K + 1, inputSize) 47 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1)) 48 | # out_l = torch.exp(torch.div(out_l, T)) 49 | out_l = torch.div(out_l, T) 50 | 51 | # set Z if haven't been set yet 52 | if Z_l < 0: 53 | # self.params[2] = out_l.mean() * outputSize 54 | self.params[2] = 1 55 | Z_l = self.params[2].clone().detach().item() 56 | print("normalization constant Z_l is set to {:.1f}".format(Z_l)) 57 | if Z_ab < 0: 58 | # self.params[3] = out_ab.mean() * outputSize 59 | self.params[3] = 1 60 | Z_ab = self.params[3].clone().detach().item() 61 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab)) 62 | 63 | # compute out_l, out_ab 64 | # out_l = torch.div(out_l, Z_l).contiguous() 65 | # out_ab = torch.div(out_ab, Z_ab).contiguous() 66 | out_l = out_l.contiguous() 67 | out_ab = 
out_ab.contiguous() 68 | 69 | # update memory 70 | with torch.no_grad(): 71 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1)) 72 | l_pos.mul_(momentum) 73 | l_pos.add_(torch.mul(l, 1 - momentum)) 74 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 75 | updated_l = l_pos.div(l_norm) 76 | self.memory_l.index_copy_(0, y, updated_l) 77 | 78 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1)) 79 | ab_pos.mul_(momentum) 80 | ab_pos.add_(torch.mul(ab, 1 - momentum)) 81 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 82 | updated_ab = ab_pos.div(ab_norm) 83 | self.memory_ab.index_copy_(0, y, updated_ab) 84 | 85 | return out_l, out_ab 86 | 87 | 88 | class NCEAverage(nn.Module): 89 | 90 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5): 91 | super(NCEAverage, self).__init__() 92 | self.nLem = outputSize 93 | self.unigrams = torch.ones(self.nLem) 94 | self.multinomial = AliasMethod(self.unigrams) 95 | self.multinomial.cuda() 96 | self.K = K 97 | 98 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum])) 99 | stdv = 1. 
/ math.sqrt(inputSize / 3) 100 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 101 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 102 | 103 | def forward(self, l, ab, y, idx=None): 104 | K = int(self.params[0].item()) 105 | T = self.params[1].item() 106 | Z_l = self.params[2].item() 107 | Z_ab = self.params[3].item() 108 | 109 | momentum = self.params[4].item() 110 | batchSize = l.size(0) 111 | outputSize = self.memory_l.size(0) 112 | inputSize = self.memory_l.size(1) 113 | 114 | # original score computation 115 | if idx is None: 116 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 117 | idx.select(1, 0).copy_(y.data) 118 | # sample 119 | weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach() 120 | weight_l = weight_l.view(batchSize, K + 1, inputSize) 121 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1)) 122 | out_ab = torch.exp(torch.div(out_ab, T)) 123 | # sample 124 | weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach() 125 | weight_ab = weight_ab.view(batchSize, K + 1, inputSize) 126 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1)) 127 | out_l = torch.exp(torch.div(out_l, T)) 128 | 129 | # set Z if haven't been set yet 130 | if Z_l < 0: 131 | self.params[2] = out_l.mean() * outputSize 132 | Z_l = self.params[2].clone().detach().item() 133 | print("normalization constant Z_l is set to {:.1f}".format(Z_l)) 134 | if Z_ab < 0: 135 | self.params[3] = out_ab.mean() * outputSize 136 | Z_ab = self.params[3].clone().detach().item() 137 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab)) 138 | 139 | # compute out_l, out_ab 140 | out_l = torch.div(out_l, Z_l).contiguous() 141 | out_ab = torch.div(out_ab, Z_ab).contiguous() 142 | 143 | # update memory 144 | with torch.no_grad(): 145 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1)) 146 | l_pos.mul_(momentum) 
147 | l_pos.add_(torch.mul(l, 1 - momentum)) 148 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 149 | updated_l = l_pos.div(l_norm) 150 | self.memory_l.index_copy_(0, y, updated_l) 151 | 152 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1)) 153 | ab_pos.mul_(momentum) 154 | ab_pos.add_(torch.mul(ab, 1 - momentum)) 155 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 156 | updated_ab = ab_pos.div(ab_norm) 157 | self.memory_ab.index_copy_(0, y, updated_ab) 158 | 159 | return out_l, out_ab 160 | 161 | 162 | class NCEAverageWithZ(nn.Module): 163 | 164 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5, z=None): 165 | super(NCEAverageWithZ, self).__init__() 166 | self.nLem = outputSize 167 | self.unigrams = torch.ones(self.nLem) 168 | self.multinomial = AliasMethod(self.unigrams) 169 | self.multinomial.cuda() 170 | self.K = K 171 | 172 | if z is None or z <= 0: 173 | z = -1 174 | else: 175 | pass 176 | self.register_buffer('params', torch.tensor([K, T, z, z, momentum])) 177 | stdv = 1. 
/ math.sqrt(inputSize / 3) 178 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 179 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 180 | 181 | def forward(self, l, ab, y, idx=None): 182 | K = int(self.params[0].item()) 183 | T = self.params[1].item() 184 | Z_l = self.params[2].item() 185 | Z_ab = self.params[3].item() 186 | 187 | momentum = self.params[4].item() 188 | batchSize = l.size(0) 189 | outputSize = self.memory_l.size(0) 190 | inputSize = self.memory_l.size(1) 191 | 192 | # original score computation 193 | if idx is None: 194 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 195 | idx.select(1, 0).copy_(y.data) 196 | # sample 197 | weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach() 198 | weight_l = weight_l.view(batchSize, K + 1, inputSize) 199 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1)) 200 | out_ab = torch.exp(torch.div(out_ab, T)) 201 | # sample 202 | weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach() 203 | weight_ab = weight_ab.view(batchSize, K + 1, inputSize) 204 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1)) 205 | out_l = torch.exp(torch.div(out_l, T)) 206 | 207 | # set Z if haven't been set yet 208 | if Z_l < 0: 209 | self.params[2] = out_l.mean() * outputSize 210 | Z_l = self.params[2].clone().detach().item() 211 | print("normalization constant Z_l is set to {:.1f}".format(Z_l)) 212 | if Z_ab < 0: 213 | self.params[3] = out_ab.mean() * outputSize 214 | Z_ab = self.params[3].clone().detach().item() 215 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab)) 216 | 217 | # compute out_l, out_ab 218 | out_l = torch.div(out_l, Z_l).contiguous() 219 | out_ab = torch.div(out_ab, Z_ab).contiguous() 220 | 221 | # update memory 222 | with torch.no_grad(): 223 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1)) 224 | l_pos.mul_(momentum) 
225 | l_pos.add_(torch.mul(l, 1 - momentum)) 226 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 227 | updated_l = l_pos.div(l_norm) 228 | self.memory_l.index_copy_(0, y, updated_l) 229 | 230 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1)) 231 | ab_pos.mul_(momentum) 232 | ab_pos.add_(torch.mul(ab, 1 - momentum)) 233 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 234 | updated_ab = ab_pos.div(ab_norm) 235 | self.memory_ab.index_copy_(0, y, updated_ab) 236 | 237 | return out_l, out_ab 238 | 239 | 240 | class NCEAverageFull(nn.Module): 241 | 242 | def __init__(self, inputSize, outputSize, T=0.07, momentum=0.5): 243 | super(NCEAverageFull, self).__init__() 244 | self.nLem = outputSize 245 | 246 | self.register_buffer('params', torch.tensor([T, -1, -1, momentum])) 247 | stdv = 1. / math.sqrt(inputSize / 3) 248 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 249 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 250 | 251 | def forward(self, l, ab, y): 252 | T = self.params[0].item() 253 | Z_l = self.params[1].item() 254 | Z_ab = self.params[2].item() 255 | 256 | momentum = self.params[3].item() 257 | batchSize = l.size(0) 258 | outputSize = self.memory_l.size(0) 259 | inputSize = self.memory_l.size(1) 260 | 261 | # score computation 262 | idx1 = y.unsqueeze(1).expand(-1, inputSize).unsqueeze(1).expand(-1, 1, -1) 263 | idx2 = torch.zeros(batchSize).long().cuda() 264 | idx2 = idx2.unsqueeze(1).expand(-1, inputSize).unsqueeze(1).expand(-1, 1, -1) 265 | # sample (clone after expand: the expanded view shares one memory row per sample, so in-place scatter_ on it raises an error) 266 | weight_l = self.memory_l.detach().unsqueeze(0).expand(batchSize, outputSize, inputSize).clone() 267 | weight_l_1 = weight_l.gather(dim=1, index=idx1) 268 | weight_l_2 = weight_l.gather(dim=1, index=idx2) 269 | weight_l.scatter_(1, idx1, weight_l_2) 270 | weight_l.scatter_(1, idx2, weight_l_1) 271 | out_ab = 
torch.bmm(weight_l, ab.view(batchSize, inputSize, 1)) 272 | out_ab = 
torch.exp(torch.div(out_ab, T)) 273 | # sample (clone after expand: the expanded view shares one memory row per sample, so in-place scatter_ on it raises an error) 274 | weight_ab = self.memory_ab.detach().unsqueeze(0).expand(batchSize, outputSize, inputSize).clone() 275 | weight_ab_1 = weight_ab.gather(dim=1, index=idx1) 276 | weight_ab_2 = weight_ab.gather(dim=1, index=idx2) 277 | weight_ab.scatter_(1, idx1, weight_ab_2) 278 | weight_ab.scatter_(1, idx2, weight_ab_1) 279 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1)) 280 | out_l = torch.exp(torch.div(out_l, T)) 281 | 282 | # set Z if haven't been set yet 283 | if Z_l < 0: 284 | self.params[1] = out_l.mean() * outputSize 285 | Z_l = self.params[1].clone().detach().item() 286 | print("normalization constant Z_l is set to {:.1f}".format(Z_l)) 287 | if Z_ab < 0: 288 | self.params[2] = out_ab.mean() * outputSize 289 | Z_ab = self.params[2].clone().detach().item() 290 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab)) 291 | 292 | # compute out_l, out_ab 293 | out_l = torch.div(out_l, Z_l).contiguous() 294 | out_ab = torch.div(out_ab, Z_ab).contiguous() 295 | 296 | # update memory 297 | with torch.no_grad(): 298 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1)) 299 | l_pos.mul_(momentum) 300 | l_pos.add_(torch.mul(l, 1 - momentum)) 301 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 302 | updated_l = l_pos.div(l_norm) 303 | self.memory_l.index_copy_(0, y, updated_l) 304 | 305 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1)) 306 | ab_pos.mul_(momentum) 307 | ab_pos.add_(torch.mul(ab, 1 - momentum)) 308 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 309 | updated_ab = ab_pos.div(ab_norm) 310 | self.memory_ab.index_copy_(0, y, updated_ab) 311 | 312 | return out_l, out_ab 313 | 314 | 315 | class NCEAverageFullSoftmax(nn.Module): 316 | 317 | def __init__(self, inputSize, outputSize, T=1, momentum=0.5): 318 | super(NCEAverageFullSoftmax, self).__init__() 319 | self.nLem = outputSize 320 | 321 | 
self.register_buffer('params', torch.tensor([T, momentum])) 322 | stdv = 1.
/ math.sqrt(inputSize / 3) 323 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 324 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 325 | 326 | def forward(self, l, ab, y): 327 | T = self.params[0].item() 328 | momentum = self.params[1].item() 329 | batchSize = l.size(0) 330 | outputSize = self.memory_l.size(0) 331 | inputSize = self.memory_l.size(1) 332 | 333 | # score computation 334 | # weight_l = self.memory_l.unsqueeze(0).expand(batchSize, outputSize, inputSize).detach() 335 | weight_l = self.memory_l.clone().unsqueeze(0).expand(batchSize, outputSize, inputSize).detach() 336 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1)) 337 | out_ab = out_ab.div(T) 338 | out_ab = out_ab.squeeze().contiguous() 339 | 340 | # weight_ab = self.memory_ab.unsqueeze(0).expand(batchSize, outputSize, inputSize).detach() 341 | weight_ab = self.memory_ab.clone().unsqueeze(0).expand(batchSize, outputSize, inputSize).detach() 342 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1)) 343 | out_l = out_l.div(T) 344 | out_l = out_l.squeeze().contiguous() 345 | 346 | # update memory 347 | with torch.no_grad(): 348 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1)) 349 | l_pos.mul_(momentum) 350 | l_pos.add_(torch.mul(l, 1 - momentum)) 351 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 352 | updated_l = l_pos.div(l_norm) 353 | self.memory_l.index_copy_(0, y, updated_l) 354 | 355 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1)) 356 | ab_pos.mul_(momentum) 357 | ab_pos.add_(torch.mul(ab, 1 - momentum)) 358 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 359 | updated_ab = ab_pos.div(ab_norm) 360 | self.memory_ab.index_copy_(0, y, updated_ab) 361 | 362 | return out_l, out_ab 363 | 364 | def update_memory(self, l, ab, y): 365 | momentum = self.params[1].item() 366 | # update memory 367 | with torch.no_grad(): 368 | l_pos = 
torch.index_select(self.memory_l, 0, y.view(-1)) 369 | l_pos.mul_(momentum) 370 | l_pos.add_(torch.mul(l, 1 - momentum)) 371 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 372 | updated_l = l_pos.div(l_norm) 373 | self.memory_l.index_copy_(0, y, updated_l) 374 | 375 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1)) 376 | ab_pos.mul_(momentum) 377 | ab_pos.add_(torch.mul(ab, 1 - momentum)) 378 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 379 | updated_ab = ab_pos.div(ab_norm) 380 | self.memory_ab.index_copy_(0, y, updated_ab) 381 | -------------------------------------------------------------------------------- /distill/NCECriterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | eps = 1e-7 5 | 6 | 7 | class NCECriterion(nn.Module): 8 | 9 | def __init__(self, nLem): 10 | super(NCECriterion, self).__init__() 11 | self.nLem = nLem 12 | 13 | def forward(self, x): 14 | batchSize = x.size(0) 15 | K = x.size(1) - 1 16 | Pnt = 1 / float(self.nLem) 17 | Pns = 1 / float(self.nLem) 18 | 19 | # eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt) 20 | Pmt = x.select(1, 0) 21 | Pmt_div = Pmt.add(K * Pnt + eps) 22 | lnPmt = torch.div(Pmt, Pmt_div) 23 | 24 | # eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns) 25 | Pon_div = x.narrow(1, 1, K).add(K * Pns + eps) 26 | Pon = Pon_div.clone().fill_(K * Pns) 27 | lnPon = torch.div(Pon, Pon_div) 28 | 29 | # equation 6 in ref. 
A 30 | lnPmt.log_() 31 | lnPon.log_() 32 | 33 | lnPmtsum = lnPmt.sum(0) 34 | lnPonsum = lnPon.view(-1, 1).sum(0) 35 | 36 | loss = - (lnPmtsum + lnPonsum) / batchSize 37 | 38 | return loss -------------------------------------------------------------------------------- /distill/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/distill/__init__.py -------------------------------------------------------------------------------- /distill/alias_multinomial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AliasMethod(object): 5 | ''' 6 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 7 | ''' 8 | def __init__(self, probs): 9 | 10 | if probs.sum() > 1: 11 | probs.div_(probs.sum()) 12 | K = len(probs) 13 | self.prob = torch.zeros(K) 14 | self.alias = torch.LongTensor([0]*K) 15 | 16 | # Sort the data into the outcomes with probabilities 17 | # that are larger and smaller than 1/K. 18 | smaller = [] 19 | larger = [] 20 | for kk, prob in enumerate(probs): 21 | self.prob[kk] = K*prob 22 | if self.prob[kk] < 1.0: 23 | smaller.append(kk) 24 | else: 25 | larger.append(kk) 26 | 27 | # Loop through and create little binary mixtures that 28 | # appropriately allocate the larger outcomes over the 29 | # overall uniform mixture.
30 | while len(smaller) > 0 and len(larger) > 0: 31 | small = smaller.pop() 32 | large = larger.pop() 33 | 34 | self.alias[small] = large 35 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] 36 | 37 | if self.prob[large] < 1.0: 38 | smaller.append(large) 39 | else: 40 | larger.append(large) 41 | 42 | for last_one in smaller+larger: 43 | self.prob[last_one] = 1 44 | 45 | def cuda(self): 46 | self.prob = self.prob.cuda() 47 | self.alias = self.alias.cuda() 48 | 49 | def draw(self, N): 50 | ''' 51 | Draw N samples from multinomial 52 | ''' 53 | K = self.alias.size(0) 54 | 55 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K) 56 | prob = self.prob.index_select(0, kk) 57 | alias = self.alias.index_select(0, kk) 58 | # b = 1 with probability q (keep outcome kk), else 0 (take its alias) 59 | b = torch.bernoulli(prob) 60 | oq = kk.mul(b.long()) 61 | oj = alias.mul((1-b).long()) 62 | 63 | return oq + oj 64 | -------------------------------------------------------------------------------- /distill/criterion.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from scipy.stats import norm 9 | 10 | from .NCEAverage import NCEAverage 11 | from .NCEAverage import NCESoftmax 12 | from .NCECriterion import NCECriterion 13 | 14 | 15 | class DistillKL(nn.Module): 16 | """KL divergence for distillation""" 17 | def __init__(self, T): 18 | super(DistillKL, self).__init__() 19 | self.T = T 20 | 21 | def forward(self, y_s, y_t): 22 | p_s = F.log_softmax(y_s/self.T, dim=1) 23 | p_t = F.softmax(y_t/self.T, dim=1) 24 | loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0] 25 | return loss 26 | 27 | 28 | class NCELoss(nn.Module): 29 | """NCE contrastive loss""" 30 | def __init__(self, opt, n_data): 31 | super(NCELoss, self).__init__() 32 | self.contrast
= NCEAverage(opt.feat_dim, n_data, opt.nce_k, opt.nce_t, opt.nce_m) 33 | self.criterion_t = NCECriterion(n_data) 34 | self.criterion_s = NCECriterion(n_data) 35 | 36 | def forward(self, f_s, f_t, idx, contrast_idx=None): 37 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx) 38 | s_loss = self.criterion_s(out_s) 39 | t_loss = self.criterion_t(out_t) 40 | loss = s_loss + t_loss 41 | return loss 42 | 43 | 44 | class NCESoftmaxLoss(nn.Module): 45 | """info NCE style loss, softmax""" 46 | def __init__(self, opt, n_data): 47 | super(NCESoftmaxLoss, self).__init__() 48 | self.contrast = NCESoftmax(opt.feat_dim, n_data, opt.nce_k, opt.nce_t, opt.nce_m) 49 | self.criterion_t = nn.CrossEntropyLoss() 50 | self.criterion_s = nn.CrossEntropyLoss() 51 | 52 | def forward(self, f_s, f_t, idx, contrast_idx=None): 53 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx) 54 | bsz = f_s.shape[0] 55 | label = torch.zeros([bsz, 1]).cuda().long() 56 | s_loss = self.criterion_s(out_s, label) 57 | t_loss = self.criterion_t(out_t, label) 58 | loss = s_loss + t_loss 59 | return loss 60 | 61 | 62 | class Attention(nn.Module): 63 | """attention transfer loss""" 64 | def __init__(self, p=2): 65 | super(Attention, self).__init__() 66 | self.p = p 67 | 68 | def forward(self, g_s, g_t): 69 | return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 70 | 71 | def at_loss(self, f_s, f_t): 72 | s_H, t_H = f_s.shape[2], f_t.shape[2] 73 | if s_H > t_H: 74 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 75 | elif s_H < t_H: 76 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 77 | else: 78 | pass 79 | return (self.at(f_s) - self.at(f_t)).pow(2).mean() 80 | 81 | def at(self, f): 82 | return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1)) 83 | 84 | 85 | class HintLoss(nn.Module): 86 | """regression loss from hints""" 87 | def __init__(self): 88 | super(HintLoss, self).__init__() 89 | self.crit = nn.MSELoss() 90 | 91 | def forward(self, f_s, f_t): 92 | loss = self.crit(f_s, f_t) 93 | 
return loss 94 | 95 | 96 | if __name__ == '__main__': 97 | pass 98 | -------------------------------------------------------------------------------- /distill/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class Embed(nn.Module): 7 | """Embedding module""" 8 | def __init__(self, dim_in=1024, dim_out=128): 9 | super(Embed, self).__init__() 10 | self.linear = nn.Linear(dim_in, dim_out) 11 | self.l2norm = Normalize(2) 12 | 13 | def forward(self, x): 14 | x = x.view(x.shape[0], -1) 15 | x = self.linear(x) 16 | x = self.l2norm(x) 17 | return x 18 | 19 | 20 | class LinearEmbed(nn.Module): 21 | """Linear Embedding""" 22 | def __init__(self, dim_in=1024, dim_out=128): 23 | super(LinearEmbed, self).__init__() 24 | self.linear = nn.Linear(dim_in, dim_out) 25 | 26 | def forward(self, x): 27 | x = x.view(x.shape[0], -1) 28 | x = self.linear(x) 29 | return x 30 | 31 | 32 | class MLPEmbed(nn.Module): 33 | """non-linear embed by MLP""" 34 | def __init__(self, dim_in=1024, dim_out=128): 35 | super(MLPEmbed, self).__init__() 36 | self.linear1 = nn.Linear(dim_in, 2 * dim_out) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.linear2 = nn.Linear(2 * dim_out, dim_out) 39 | self.l2norm = Normalize(2) 40 | 41 | def forward(self, x): 42 | x = x.view(x.shape[0], -1) 43 | x = self.relu(self.linear1(x)) 44 | x = self.l2norm(self.linear2(x)) 45 | return x 46 | 47 | 48 | class Normalize(nn.Module): 49 | """normalization layer""" 50 | def __init__(self, power=2): 51 | super(Normalize, self).__init__() 52 | self.power = power 53 | 54 | def forward(self, x): 55 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 56 | out = x.div(norm) 57 | return out 58 | 59 | 60 | if __name__ == '__main__': 61 | pass 62 | -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/eval/__init__.py -------------------------------------------------------------------------------- /eval/cls_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import time 5 | from tqdm import tqdm 6 | from .util import AverageMeter, accuracy 7 | import numpy as np 8 | 9 | def validate(val_loader, model, criterion, opt): 10 | """One epoch validation""" 11 | batch_time = AverageMeter() 12 | losses = AverageMeter() 13 | top1 = AverageMeter() 14 | top5 = AverageMeter() 15 | 16 | # switch to evaluate mode 17 | model.eval() 18 | 19 | with torch.no_grad(): 20 | with tqdm(val_loader, total=len(val_loader)) as pbar: 21 | end = time.time() 22 | for idx, (input, target, _) in enumerate(pbar): 23 | 24 | if(opt.simclr): 25 | input = input[0].float() 26 | else: 27 | input = input.float() 28 | 29 | if torch.cuda.is_available(): 30 | input = input.cuda() 31 | target = target.cuda() 32 | 33 | # compute output 34 | output = model(input) 35 | loss = criterion(output, target) 36 | 37 | # measure accuracy and record loss 38 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 39 | losses.update(loss.item(), input.size(0)) 40 | top1.update(acc1[0], input.size(0)) 41 | top5.update(acc5[0], input.size(0)) 42 | 43 | # measure elapsed time 44 | batch_time.update(time.time() - end) 45 | end = time.time() 46 | 47 | pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()), 48 | "Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy()), 49 | "Loss" :'{0:.2f}'.format(losses.avg), 50 | }) 51 | # if idx %
opt.print_freq == 0: 52 | # print('Test: [{0}/{1}]\t' 53 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 54 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 55 | # 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 56 | # 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 57 | # idx, len(val_loader), batch_time=batch_time, loss=losses, 58 | # top1=top1, top5=top5)) 59 | 60 | print('Val_Acc@1 {top1.avg:.3f} Val_Acc@5 {top5.avg:.3f}' 61 | .format(top1=top1, top5=top5)) 62 | 63 | return top1.avg, top5.avg, losses.avg 64 | 65 | 66 | 67 | 68 | def embedding(val_loader, model, opt): 69 | """Collect model outputs for the 0/90/180/270-degree rotations of each validation image and save them to .npy files""" 70 | batch_time = AverageMeter() 71 | losses = AverageMeter() 72 | top1 = AverageMeter() 73 | top5 = AverageMeter() 74 | 75 | # switch to evaluate mode 76 | model.eval() 77 | 78 | 79 | with torch.no_grad(): 80 | with tqdm(val_loader, total=len(val_loader)) as pbar: 81 | end = time.time() 82 | for idx, (input, target, _) in enumerate(pbar): 83 | 84 | if(opt.simclr): 85 | input = input[0].float() 86 | else: 87 | input = input.float() 88 | 89 | if torch.cuda.is_available(): 90 | input = input.cuda() 91 | target = target.cuda() 92 | 93 | batch_size = input.size()[0] 94 | x = input 95 | x_90 = x.transpose(2,3).flip(2) 96 | x_180 = x.flip(2).flip(3) 97 | x_270 = x.flip(2).transpose(2,3) 98 | generated_data = torch.cat((x, x_90, x_180, x_270),0) 99 | train_targets = target.repeat(4) 100 | 101 | # compute output 102 | # output = model(input) 103 | (_,_,_,_, feat), (output, rot_logits) = model(generated_data, rot=True) 104 | # loss = criterion(output, target) 105 | 106 | # measure accuracy and record loss 107 | acc1, acc5 = accuracy(output[:batch_size], target, topk=(1, 5)) 108 | # losses.update(loss.item(), input.size(0)) 109 | top1.update(acc1[0], input.size(0)) 110 | top5.update(acc5[0], input.size(0)) 111 | 112 | if(idx==0): 113 | embeddings = output 114 | classes = train_targets 115 | else: 116 | embeddings = torch.cat((embeddings, output),0) 117 | classes =
torch.cat((classes, train_targets),0) 118 | 119 | 120 | # measure elapsed time 121 | batch_time.update(time.time() - end) 122 | end = time.time() 123 | 124 | pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()), 125 | "Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy()) 126 | }) 127 | # if idx % opt.print_freq == 0: 128 | # print('Test: [{0}/{1}]\t' 129 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 130 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 131 | # 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 132 | # 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 133 | # idx, len(val_loader), batch_time=batch_time, loss=losses, 134 | # top1=top1, top5=top5)) 135 | 136 | print('Val_Acc@1 {top1.avg:.3f} Val_Acc@5 {top5.avg:.3f}' 137 | .format(top1=top1, top5=top5)) 138 | print(embeddings.size()) 139 | print(classes.size()) 140 | 141 | np.save("embeddings.npy", embeddings.detach().cpu().numpy()) 142 | np.save("classes.npy", classes.detach().cpu().numpy()) 143 | 144 | 145 | 146 | # with tqdm(val_loader, total=len(val_loader)) as pbar: 147 | # end = time.time() 148 | # for idx, (input, target, _) in enumerate(pbar): 149 | 150 | # if(opt.simclr): 151 | # input = input[0].float() 152 | # else: 153 | # input = input.float() 154 | 155 | # if torch.cuda.is_available(): 156 | # input = input.cuda() 157 | # target = target.cuda() 158 | 159 | # generated_data = torch.cat((x, x_180),0) 160 | # # compute output 161 | # # output = model(input) 162 | # (_,_,_,_, feat), (output, rot_logits) = model(input, rot=True) 163 | # # loss = criterion(output, target) 164 | 165 | # # measure accuracy and record loss 166 | # acc1, acc5 = accuracy(output, target, topk=(1, 5)) 167 | # # losses.update(loss.item(), input.size(0)) 168 | # top1.update(acc1[0], input.size(0)) 169 | # top5.update(acc5[0], input.size(0)) 170 | 171 | # if(idx==0): 172 | # embeddings = output 173 | # classes = target 174 | # else: 175 | # embeddings = torch.cat((embeddings, output),0) 176 | # classes =
torch.cat((classes, target),0) 177 | 178 | 179 | # # measure elapsed time 180 | # batch_time.update(time.time() - end) 181 | # end = time.time() 182 | 183 | # pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()), 184 | # "Acc@5":'{0:.2f}'.format(top1.avg.cpu().numpy(),2) 185 | # }) 186 | # # if idx % opt.print_freq == 0: 187 | # # print('Test: [{0}/{1}]\t' 188 | # # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 189 | # # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 190 | # # 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 191 | # # 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 192 | # # idx, len(val_loader), batch_time=batch_time, loss=losses, 193 | # # top1=top1, top5=top5)) 194 | 195 | # print('Val_Acc@1 {top1.avg:.3f} Val_Acc@5 {top5.avg:.3f}' 196 | # .format(top1=top1, top5=top5)) 197 | # print(embeddings.size()) 198 | # print(classes.size()) 199 | 200 | # np.save("embeddings.npy", embeddings.detach().cpu().numpy()) 201 | # np.save("classes.npy", classes.detach().cpu().numpy()) 202 | return top1.avg, top5.avg, losses.avg 203 | -------------------------------------------------------------------------------- /eval/meta_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import scipy 5 | from scipy.stats import t 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from sklearn import metrics 10 | from sklearn.svm import SVC 11 | from sklearn.linear_model import LogisticRegression 12 | from sklearn.neighbors import KNeighborsClassifier 13 | from sklearn.ensemble import RandomForestClassifier 14 | 15 | import torch 16 | import torch.nn as nn 17 | import sys, os 18 | from collections import Counter 19 | 20 | 21 | sys.path.append(os.path.abspath('..')) 22 | 23 | from util import accuracy 24 | 25 | 26 | def mean_confidence_interval(data, confidence=0.95): 27 | a = 100.0 * np.array(data) 28 | n = len(a) 29 | m, se = np.mean(a), scipy.stats.sem(a) 30 
| h = se * t._ppf((1+confidence)/2., n-1) 31 | return m, h 32 | 33 | 34 | def normalize(x): 35 | norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2) 36 | out = x.div(norm) 37 | return out 38 | 39 | 40 | def meta_test(net, testloader, use_logit=False, is_norm=True, classifier='LR'): 41 | net = net.eval() 42 | acc = [] 43 | 44 | with torch.no_grad(): 45 | with tqdm(testloader, total=len(testloader)) as pbar: 46 | for idx, data in enumerate(pbar): 47 | support_xs, support_ys, query_xs, query_ys = data 48 | 49 | support_xs = support_xs.cuda() 50 | query_xs = query_xs.cuda() 51 | batch_size, _, height, width, channel = support_xs.size() 52 | support_xs = support_xs.view(-1, height, width, channel) 53 | query_xs = query_xs.view(-1, height, width, channel) 54 | 55 | 56 | 57 | # batch_size = support_xs.size()[0] 58 | # x = support_xs 59 | # x_90 = x.transpose(2,3).flip(2) 60 | # x_180 = x.flip(2).flip(3) 61 | # x_270 = x.flip(2).transpose(2,3) 62 | # generated_data = torch.cat((x, x_90, x_180, x_270),0) 63 | # support_ys = support_ys.repeat(1,4) 64 | # support_xs = generated_data 65 | 66 | # print(support_xs.size()) 67 | # print(support_ys.size()) 68 | 69 | 70 | 71 | if use_logit: 72 | support_features = net(support_xs).view(support_xs.size(0), -1) 73 | query_features = net(query_xs).view(query_xs.size(0), -1) 74 | else: 75 | feat_support, _ = net(support_xs, is_feat=True) 76 | support_features = feat_support[-1].view(support_xs.size(0), -1) 77 | feat_query, _ = net(query_xs, is_feat=True) 78 | query_features = feat_query[-1].view(query_xs.size(0), -1) 79 | 80 | # feat_support, _ = net(support_xs) 81 | # support_features = feat_support.view(support_xs.size(0), -1) 82 | # feat_query, _ = net(query_xs) 83 | # query_features = feat_query.view(query_xs.size(0), -1) 84 | 85 | 86 | if is_norm: 87 | support_features = normalize(support_features) 88 | query_features = normalize(query_features) 89 | 90 | support_features = support_features.detach().cpu().numpy() 91 | query_features = 
query_features.detach().cpu().numpy() 92 | 93 | support_ys = support_ys.view(-1).numpy() 94 | query_ys = query_ys.view(-1).numpy() 95 | 96 | 97 | 98 | if classifier == 'LR': 99 | clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000, penalty='l2', 100 | multi_class='multinomial') 101 | clf.fit(support_features, support_ys) 102 | query_ys_pred = clf.predict(query_features) 103 | elif classifier == 'NN': 104 | query_ys_pred = NN(support_features, support_ys, query_features) 105 | elif classifier == 'Cosine': 106 | query_ys_pred = Cosine(support_features, support_ys, query_features) 107 | else: 108 | raise NotImplementedError('classifier not supported: {}'.format(classifier)) 109 | 110 | 111 | # bs = query_features.shape[0]//opt.n_aug_support_samples 112 | # a = np.reshape(query_ys_pred[:bs], (-1,1)) 113 | # c = query_ys[:bs] 114 | # for i in range(1,opt.n_aug_support_samples): 115 | # a = np.hstack([a, np.reshape(query_ys_pred[i*bs:(i+1)*bs], (-1,1))]) 116 | 117 | # d = [] 118 | # for i in range(a.shape[0]): 119 | # b = Counter(a[i,:]) 120 | # d.append(b.most_common(1)[0][0]) 121 | 122 | # # (values,counts) = np.unique(a,axis=1, return_counts=True) 123 | # # print(counts) 124 | # # ind=np.argmax(counts) 125 | # # print values[ind] # pr 126 | 127 | 128 | # # # a = np.argmax 129 | # # print(a.shape) 130 | # # print(c.shape) 131 | 132 | acc.append(metrics.accuracy_score(query_ys, query_ys_pred)) 133 | 134 | pbar.set_postfix({"FSL_Acc":'{0:.2f}'.format(metrics.accuracy_score(query_ys, query_ys_pred))}) 135 | 136 | return mean_confidence_interval(acc) 137 | 138 | 139 | 140 | 141 | def meta_test_tune(net, testloader, use_logit=False, is_norm=True, classifier='LR', lamda=0.2): 142 | net = net.eval() 143 | acc = [] 144 | 145 | with tqdm(testloader, total=len(testloader)) as pbar: 146 | for idx, data in enumerate(pbar): 147 | support_xs, support_ys, query_xs, query_ys, support_ts, query_ts = data 148 | 149 | support_xs = support_xs.cuda() 150 | support_ys = 
support_ys.cuda() 151 | query_ys = query_ys.cuda() 152 | query_xs = query_xs.cuda() 153 | batch_size, _, height, width, channel = support_xs.size() 154 | support_xs = support_xs.view(-1, height, width, channel) 155 | support_ys = support_ys.view(-1,1) 156 | query_ys = query_ys.view(-1) 157 | query_xs = query_xs.view(-1, height, width, channel) 158 | 159 | if use_logit: 160 | support_features = net(support_xs).view(support_xs.size(0), -1) 161 | query_features = net(query_xs).view(query_xs.size(0), -1) 162 | else: 163 | feat_support, _ = net(support_xs, is_feat=True) 164 | support_features = feat_support[-1].view(support_xs.size(0), -1) 165 | feat_query, _ = net(query_xs, is_feat=True) 166 | query_features = feat_query[-1].view(query_xs.size(0), -1) 167 | 168 | if is_norm: 169 | support_features = normalize(support_features) 170 | query_features = normalize(query_features) 171 | 172 | y_onehot = torch.FloatTensor(support_ys.size()[0], 5).cuda() # NOTE: hard-codes 5-way episodes 173 | 174 | # one-hot encode the support labels 175 | y_onehot.zero_() 176 | y_onehot.scatter_(1, support_ys, 1) 177 | 178 | 179 | X = support_features 180 | XTX = torch.matmul(torch.t(X),X) 181 | # ridge regression: B = (X^T X + lamda*I)^-1 X^T Y, with I sized to the feature dimension 182 | B = torch.matmul( (XTX + lamda*torch.eye(X.size(1)).cuda() ).inverse(), torch.matmul(torch.t(X), y_onehot.float()) ) 183 | # print(B.size()) 184 | m = nn.Sigmoid() 185 | Y_pred = m(torch.matmul(query_features, B)) 186 | 187 | 188 | # print(Y_pred, query_ys) 189 | # model = nn.Sequential(nn.Linear(64, 10),nn.LogSoftmax(dim=1)) 190 | # optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 191 | # criterion = nn.CrossEntropyLoss() 192 | 193 | # model.cuda() 194 | # criterion.cuda() 195 | # model.train() 196 | 197 | # for i in range(5): 198 | # output = model(support_features) 199 | # loss = criterion(output, support_ys) 200 | # optimizer.zero_grad() 201 | # loss.backward(retain_graph=True) # auto-grad 202 | # optimizer.step() # update weights 203 | 204 | # model.eval() 205 | # query_ys_pred = model(query_features) 206 | 207 | acc1, acc5 =
accuracy(Y_pred, query_ys, topk=(1, 1)) 208 | 209 | 210 | # support_features = support_features.detach().cpu().numpy() 211 | # query_features = query_features.detach().cpu().numpy() 212 | 213 | # support_ys = support_ys.view(-1).numpy() 214 | # query_ys = query_ys.view(-1).numpy() 215 | 216 | # if classifier == 'LR': 217 | # clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000, 218 | # multi_class='multinomial') 219 | # clf.fit(support_features, support_ys) 220 | # query_ys_pred = clf.predict(query_features) 221 | # elif classifier == 'NN': 222 | # query_ys_pred = NN(support_features, support_ys, query_features) 223 | # elif classifier == 'Cosine': 224 | # query_ys_pred = Cosine(support_features, support_ys, query_features) 225 | # else: 226 | # raise NotImplementedError('classifier not supported: {}'.format(classifier)) 227 | 228 | acc.append(acc1.item()/100.0) 229 | 230 | pbar.set_postfix({"FSL_Acc":'{0:.4f}'.format(np.mean(acc))}) 231 | 232 | 233 | return mean_confidence_interval(acc) 234 | 235 | 236 | 237 | def meta_test_ensamble(net, testloader, use_logit=True, is_norm=True, classifier='LR'): 238 | for n in net: 239 | n = n.eval() 240 | acc = [] 241 | 242 | with torch.no_grad(): 243 | with tqdm(testloader, total=len(testloader)) as pbar: 244 | for idx, data in enumerate(pbar): 245 | support_xs, support_ys, query_xs, query_ys = data 246 | 247 | support_xs = support_xs.cuda() 248 | query_xs = query_xs.cuda() 249 | batch_size, _, height, width, channel = support_xs.size() 250 | support_xs = support_xs.view(-1, height, width, channel) 251 | query_xs = query_xs.view(-1, height, width, channel) 252 | 253 | if use_logit: 254 | support_features = net[0](support_xs).view(support_xs.size(0), -1) 255 | query_features = net[0](query_xs).view(query_xs.size(0), -1) 256 | for n in net[1:]: 257 | support_features += n(support_xs).view(support_xs.size(0), -1) 258 | query_features += n(query_xs).view(query_xs.size(0), -1) 259 | else: 260 | feat_support, _ = 
net[0](support_xs, is_feat=True) 261 | support_features = feat_support[-1].view(support_xs.size(0), -1) 262 | feat_query, _ = net[0](query_xs, is_feat=True) 263 | query_features = feat_query[-1].view(query_xs.size(0), -1) 264 | # NOTE: feature mode uses only the first model of the ensemble 265 | if is_norm: 266 | support_features = normalize(support_features) 267 | query_features = normalize(query_features) 268 | 269 | support_features = support_features.detach().cpu().numpy() 270 | query_features = query_features.detach().cpu().numpy() 271 | 272 | support_ys = support_ys.view(-1).numpy() 273 | query_ys = query_ys.view(-1).numpy() 274 | 275 | if classifier == 'LR': 276 | clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000, 277 | multi_class='multinomial') 278 | clf.fit(support_features, support_ys) 279 | query_ys_pred = clf.predict(query_features) 280 | elif classifier == 'NN': 281 | query_ys_pred = NN(support_features, support_ys, query_features) 282 | elif classifier == 'Cosine': 283 | query_ys_pred = Cosine(support_features, support_ys, query_features) 284 | else: 285 | raise NotImplementedError('classifier not supported: {}'.format(classifier)) 286 | 287 | acc.append(metrics.accuracy_score(query_ys, query_ys_pred)) 288 | 289 | pbar.set_postfix({"FSL_Acc":'{0:.2f}'.format(metrics.accuracy_score(query_ys, query_ys_pred))}) 290 | 291 | return mean_confidence_interval(acc) 292 | 293 | 294 | def NN(support, support_ys, query): 295 | """nearest-neighbour classifier (squared Euclidean distance)""" 296 | support = np.expand_dims(support.transpose(), 0) 297 | query = np.expand_dims(query, 2) 298 | 299 | diff = np.multiply(query - support, query - support) 300 | distance = diff.sum(1) 301 | min_idx = np.argmin(distance, axis=1) 302 | pred = [support_ys[idx] for idx in min_idx] 303 | return pred 304 | 305 | 306 | def Cosine(support, support_ys, query): 307 | """Cosine classifier""" 308 | support_norm = np.linalg.norm(support, axis=1, keepdims=True) 309 | support = support / support_norm 310 | query_norm = np.linalg.norm(query, axis=1, keepdims=True) 311 |
query = query / query_norm 312 | 313 | cosine_distance = query @ support.transpose() 314 | max_idx = np.argmax(cosine_distance, axis=1) 315 | pred = [support_ys[idx] for idx in max_idx] 316 | return pred 317 | -------------------------------------------------------------------------------- /eval/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | def accuracy(output, target, topk=(1,)): 24 | """Computes the accuracy over the k top predictions for the specified values of k""" 25 | with torch.no_grad(): 26 | maxk = max(topk) 27 | batch_size = target.size(0) 28 | 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | 33 | res = [] 34 | for k in topk: 35 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 36 | res.append(correct_k.mul_(100.0 / batch_size)) 37 | return res 38 | -------------------------------------------------------------------------------- /eval_fewshot.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import socket 5 | import time 6 | import os 7 | import mkl 8 | 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | from torch.utils.data import DataLoader 13 | 14 | from models import model_pool 15 | from models.util import create_model 16 | 17 | from dataset.mini_imagenet import MetaImageNet 18 | from dataset.tiered_imagenet import MetaTieredImageNet 19 | from dataset.cifar 
import MetaCIFAR100 20 | from dataset.transform_cfg import transforms_test_options, transforms_list 21 | 22 | from eval.meta_eval import meta_test, meta_test_tune 23 | from eval.cls_eval import validate, embedding 24 | from dataloader import get_dataloaders 25 | 26 | mkl.set_num_threads(2) 27 | 28 | 29 | 30 | def parse_option(): 31 | 32 | parser = argparse.ArgumentParser('argument for training') 33 | 34 | # load pretrained model 35 | parser.add_argument('--model', type=str, default='resnet12', choices=model_pool) 36 | parser.add_argument('--model_path', type=str, default="", help='absolute path to .pth model') 37 | # parser.add_argument('--model_path', type=str, default="/raid/data/IncrementLearn/imagenet/neurips20/model/maml_miniimagenet_test_5shot_step_5_5ways_5shots/pretrain_maml_miniimagenet_test_5shot_step_5_5ways_5shots.pt", help='absolute path to .pth model') 38 | 39 | # dataset 40 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet', 41 | 'CIFAR-FS', 'FC100', "toy"]) 42 | parser.add_argument('--transform', type=str, default='A', choices=transforms_list) 43 | 44 | # specify data_root 45 | parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root') 46 | parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation') 47 | 48 | # meta setting 49 | parser.add_argument('--n_test_runs', type=int, default=600, metavar='N', 50 | help='Number of test runs') 51 | parser.add_argument('--n_ways', type=int, default=5, metavar='N', 52 | help='Number of classes for doing each classification run') 53 | parser.add_argument('--n_shots', type=int, default=1, metavar='N', 54 | help='Number of shots in test') 55 | parser.add_argument('--n_queries', type=int, default=15, metavar='N', 56 | help='Number of queries per class in test') 57 | parser.add_argument('--n_aug_support_samples', default=5, type=int, 58 |
help='The number of augmented samples for each meta test sample') 59 | parser.add_argument('--num_workers', type=int, default=3, metavar='N', 60 | help='Number of workers for dataloader') 61 | parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size', 62 | help='Size of test batch') 63 | 64 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 65 | 66 | opt = parser.parse_args() 67 | 68 | if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100': 69 | opt.transform = 'D' 70 | 71 | if 'trainval' in opt.model_path: 72 | opt.use_trainval = True 73 | else: 74 | opt.use_trainval = False 75 | 76 | # set the path according to the environment 77 | if not opt.data_root: 78 | opt.data_root = './data/{}'.format(opt.dataset) 79 | else: 80 | if(opt.dataset=="toy"): 81 | opt.data_root = '{}/{}'.format(opt.data_root, "CIFAR-FS") 82 | else: 83 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset) 84 | opt.data_aug = True 85 | 86 | return opt 87 | 88 | 89 | def main(): 90 | 91 | opt = parse_option() 92 | 93 | opt.n_test_runs = 600 94 | train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt) 95 | 96 | # load model 97 | model = create_model(opt.model, n_cls, opt.dataset) 98 | ckpt = torch.load(opt.model_path) 99 | model.load_state_dict(ckpt["model"]) 100 | 101 | if torch.cuda.is_available(): 102 | model = model.cuda() 103 | cudnn.benchmark = True 104 | 105 | start = time.time() 106 | test_acc, test_std = meta_test(model, meta_testloader) 107 | test_time = time.time() - start 108 | print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time)) 109 | 110 | 111 | start = time.time() 112 | test_acc_feat, test_std_feat = meta_test(model, meta_testloader, use_logit=False) 113 | test_time = time.time() - start 114 | print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc_feat, test_std_feat, test_time)) 115 | 116 | 117 | if __name__ ==
'__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .convnet import convnet4 2 | from .resnet import resnet12 3 | from .resnet_ssl import resnet12_ssl 4 | from .resnet_sd import resnet12_sd 5 | from .resnet_selfdist import multi_resnet12_kd 6 | from .resnet import seresnet12 7 | from .wresnet import wrn_28_10 8 | 9 | from .resnet_new import resnet50 10 | 11 | model_pool = [ 12 | 'convnet4', 13 | 'resnet12', 14 | 'resnet12_ssl', 15 | 'resnet12_kd', 16 | 'resnet12_sd', 17 | 'seresnet12', 18 | 'wrn_28_10', 19 | ] 20 | 21 | model_dict = { 22 | 'wrn_28_10': wrn_28_10, 23 | 'convnet4': convnet4, 24 | 'resnet12': resnet12, 25 | 'resnet12_ssl': resnet12_ssl, 26 | 'resnet12_kd': multi_resnet12_kd, 27 | 'resnet12_sd': resnet12_sd, 28 | 'seresnet12': seresnet12, 29 | 'resnet50': resnet50, 30 | } 31 | -------------------------------------------------------------------------------- /models/convnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ConvNet(nn.Module): 8 | 9 | def __init__(self, num_classes=-1): 10 | super(ConvNet, self).__init__() 11 | self.layer1 = nn.Sequential( 12 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 13 | nn.BatchNorm2d(64), 14 | nn.ReLU(), 15 | nn.MaxPool2d(2)) 16 | self.layer2 = nn.Sequential( 17 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 18 | nn.BatchNorm2d(64), 19 | nn.ReLU(), 20 | nn.MaxPool2d(2)) 21 | self.layer3 = nn.Sequential( 22 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(64), 24 | nn.ReLU(), 25 | nn.MaxPool2d(2)) 26 | self.layer4 = nn.Sequential( 27 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 28 | # nn.BatchNorm2d(64, momentum=1, affine=True, track_running_stats=False), 29 | nn.BatchNorm2d(64), 30 | 
nn.ReLU()) 31 | 32 | self.avgpool = nn.AdaptiveAvgPool2d(1) 33 | 34 | self.num_classes = num_classes 35 | if self.num_classes > 0: 36 | self.classifier = nn.Linear(64, self.num_classes) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 41 | elif isinstance(m, nn.BatchNorm2d): 42 | nn.init.constant_(m.weight, 1) 43 | nn.init.constant_(m.bias, 0) 44 | 45 | def forward(self, x, is_feat=False): 46 | out = self.layer1(x) 47 | f0 = out 48 | out = self.layer2(out) 49 | f1 = out 50 | out = self.layer3(out) 51 | f2 = out 52 | out = self.layer4(out) 53 | f3 = out 54 | out = self.avgpool(out) 55 | out = out.view(out.size(0), -1) 56 | feat = out 57 | 58 | if self.num_classes > 0: 59 | out = self.classifier(out) 60 | 61 | if is_feat: 62 | return [f0, f1, f2, f3, feat], out 63 | else: 64 | return out 65 | 66 | 67 | def convnet4(**kwargs): 68 | """Four layer ConvNet 69 | """ 70 | model = ConvNet(**kwargs) 71 | return model 72 | 73 | 74 | if __name__ == '__main__': 75 | model = convnet4(num_classes=64) 76 | data = torch.randn(2, 3, 84, 84) 77 | feat, logit = model(data, is_feat=True) 78 | print(feat[-1].shape) 79 | print(logit.shape) 80 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.distributions import Bernoulli 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class SELayer(nn.Module): 14 | def __init__(self, channel, reduction=16): 15 | super(SELayer, self).__init__() 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | self.fc = nn.Sequential( 18 | nn.Linear(channel, channel // reduction), 19 | 
nn.ReLU(inplace=True), 20 | nn.Linear(channel // reduction, channel), 21 | nn.Sigmoid() 22 | ) 23 | 24 | def forward(self, x): 25 | b, c, _, _ = x.size() 26 | y = self.avg_pool(x).view(b, c) 27 | y = self.fc(y).view(b, c, 1, 1) 28 | return x * y 29 | 30 | 31 | class DropBlock(nn.Module): 32 | def __init__(self, block_size): 33 | super(DropBlock, self).__init__() 34 | 35 | self.block_size = block_size 36 | #self.gamma = gamma 37 | #self.bernouli = Bernoulli(gamma) 38 | 39 | def forward(self, x, gamma): 40 | # shape: (bsize, channels, height, width) 41 | 42 | if self.training: 43 | batch_size, channels, height, width = x.shape 44 | 45 | bernoulli = Bernoulli(gamma) 46 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).to(x.device) 47 | block_mask = self._compute_block_mask(mask) 48 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 49 | count_ones = block_mask.sum() 50 | 51 | return block_mask * x * (countM / count_ones) 52 | else: 53 | return x 54 | 55 | def _compute_block_mask(self, mask): 56 | left_padding = int((self.block_size-1) / 2) 57 | right_padding = int(self.block_size / 2) 58 | 59 | batch_size, channels, height, width = mask.shape 60 | #print ("mask", mask[0][0]) 61 | non_zero_idxs = mask.nonzero() 62 | nr_blocks = non_zero_idxs.shape[0] 63 | 64 | offsets = torch.stack( 65 | [ 66 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 67 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 68 | ] 69 | ).t().to(mask.device) 70 | offsets = torch.cat((torch.zeros(self.block_size**2, 2, device=mask.device).long(), offsets.long()), 1) 71 | 72 | if nr_blocks > 0: 73 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 74 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 75 | offsets = offsets.long() 76 | 77 | block_idxs = non_zero_idxs + offsets 78 | #block_idxs += left_padding 79 | 
padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 80 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 81 | else: 82 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 83 | 84 | block_mask = 1 - padded_mask#[:height, :width] 85 | return block_mask 86 | 87 | 88 | class BasicBlock(nn.Module): 89 | expansion = 1 90 | 91 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, 92 | block_size=1, use_se=False): 93 | super(BasicBlock, self).__init__() 94 | self.conv1 = conv3x3(inplanes, planes) 95 | self.bn1 = nn.BatchNorm2d(planes) 96 | self.relu = nn.LeakyReLU(0.1) 97 | self.conv2 = conv3x3(planes, planes) 98 | self.bn2 = nn.BatchNorm2d(planes) 99 | self.conv3 = conv3x3(planes, planes) 100 | self.bn3 = nn.BatchNorm2d(planes) 101 | self.maxpool = nn.MaxPool2d(stride) 102 | self.downsample = downsample 103 | self.stride = stride 104 | self.drop_rate = drop_rate 105 | self.num_batches_tracked = 0 106 | self.drop_block = drop_block 107 | self.block_size = block_size 108 | self.DropBlock = DropBlock(block_size=self.block_size) 109 | self.use_se = use_se 110 | if self.use_se: 111 | self.se = SELayer(planes, 4) 112 | 113 | def forward(self, x): 114 | self.num_batches_tracked += 1 115 | 116 | residual = x 117 | 118 | out = self.conv1(x) 119 | out = self.bn1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.bn3(out) 128 | if self.use_se: 129 | out = self.se(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | out += residual 134 | out = self.relu(out) 135 | out = self.maxpool(out) 136 | 137 | if self.drop_rate > 0: 138 | if self.drop_block == True: 139 | feat_size = out.size()[2] 140 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - 
self.drop_rate) 141 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 142 | out = self.DropBlock(out, gamma=gamma) 143 | else: 144 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 145 | 146 | return out 147 | 148 | 149 | class ResNet(nn.Module): 150 | 151 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0, 152 | dropblock_size=5, num_classes=-1, use_se=False): 153 | super(ResNet, self).__init__() 154 | 155 | self.inplanes = 3 156 | self.use_se = use_se 157 | self.layer1 = self._make_layer(block, n_blocks[0], 64, 158 | stride=2, drop_rate=drop_rate) 159 | self.layer2 = self._make_layer(block, n_blocks[1], 160, 160 | stride=2, drop_rate=drop_rate) 161 | self.layer3 = self._make_layer(block, n_blocks[2], 320, 162 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 163 | self.layer4 = self._make_layer(block, n_blocks[3], 640, 164 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 165 | if avg_pool: 166 | # self.avgpool = nn.AvgPool2d(5, stride=1) 167 | self.avgpool = nn.AdaptiveAvgPool2d(1) 168 | self.keep_prob = keep_prob 169 | self.keep_avg_pool = avg_pool 170 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 171 | self.drop_rate = drop_rate 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 176 | elif isinstance(m, nn.BatchNorm2d): 177 | nn.init.constant_(m.weight, 1) 178 | nn.init.constant_(m.bias, 0) 179 | 180 | self.num_classes = num_classes 181 | if self.num_classes > 0: 182 | self.classifier = nn.Linear(640, self.num_classes) 183 | 184 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 185 | downsample = None 186 | if stride != 1 or self.inplanes != planes * block.expansion: 187 | downsample = nn.Sequential( 188 | 
nn.Conv2d(self.inplanes, planes * block.expansion, 189 | kernel_size=1, stride=1, bias=False), 190 | nn.BatchNorm2d(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | if n_block == 1: 195 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se) 196 | else: 197 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se) 198 | layers.append(layer) 199 | self.inplanes = planes * block.expansion 200 | 201 | for i in range(1, n_block): 202 | if i == n_block - 1: 203 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block, 204 | block_size=block_size, use_se=self.use_se) 205 | else: 206 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se) 207 | layers.append(layer) 208 | 209 | return nn.Sequential(*layers) 210 | 211 | def forward(self, x, is_feat=False): 212 | x = self.layer1(x) 213 | f0 = x 214 | x = self.layer2(x) 215 | f1 = x 216 | x = self.layer3(x) 217 | f2 = x 218 | x = self.layer4(x) 219 | f3 = x 220 | if self.keep_avg_pool: 221 | x = self.avgpool(x) 222 | x = x.view(x.size(0), -1) 223 | feat = x 224 | 225 | if self.num_classes > 0: 226 | x = self.classifier(x) 227 | 228 | if is_feat: 229 | return [f0, f1, f2, f3, feat], x 230 | else: 231 | return x 232 | 233 | 234 | def resnet12(keep_prob=1.0, avg_pool=False, **kwargs): 235 | """Constructs a ResNet-12 model. 236 | """ 237 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 238 | return model 239 | 240 | 241 | def resnet18(keep_prob=1.0, avg_pool=False, **kwargs): 242 | """Constructs a ResNet-18 model. 243 | """ 244 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 245 | return model 246 | 247 | 248 | def resnet24(keep_prob=1.0, avg_pool=False, **kwargs): 249 | """Constructs a ResNet-24 model.
250 | """ 251 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 252 | return model 253 | 254 | 255 | def resnet50(keep_prob=1.0, avg_pool=False, **kwargs): 256 | """Constructs a ResNet-50 model. 257 | indeed, only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers 258 | """ 259 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 260 | return model 261 | 262 | 263 | def resnet101(keep_prob=1.0, avg_pool=False, **kwargs): 264 | """Constructs a ResNet-101 model. 265 | indeed, only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers 266 | """ 267 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 268 | return model 269 | 270 | 271 | def seresnet12(keep_prob=1.0, avg_pool=False, **kwargs): 272 | """Constructs an SE-ResNet-12 model. 273 | """ 274 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 275 | return model 276 | 277 | 278 | def seresnet18(keep_prob=1.0, avg_pool=False, **kwargs): 279 | """Constructs an SE-ResNet-18 model. 280 | """ 281 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 282 | return model 283 | 284 | 285 | def seresnet24(keep_prob=1.0, avg_pool=False, **kwargs): 286 | """Constructs an SE-ResNet-24 model. 287 | """ 288 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 289 | return model 290 | 291 | 292 | def seresnet50(keep_prob=1.0, avg_pool=False, **kwargs): 293 | """Constructs an SE-ResNet-50 model. 294 | indeed, only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers 295 | """ 296 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 297 | return model 298 | 299 | 300 | def seresnet101(keep_prob=1.0, avg_pool=False, **kwargs): 301 | """Constructs an SE-ResNet-101 model.
302 | indeed, only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers 303 | """ 304 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 305 | return model 306 | 307 | 308 | if __name__ == '__main__': 309 | 310 | import argparse 311 | 312 | parser = argparse.ArgumentParser('argument for training') 313 | parser.add_argument('--model', type=str, choices=['resnet12', 'resnet18', 'resnet24', 'resnet50', 'resnet101', 314 | 'seresnet12', 'seresnet18', 'seresnet24', 'seresnet50', 315 | 'seresnet101']) 316 | args = parser.parse_args() 317 | 318 | model_dict = { 319 | 'resnet12': resnet12, 320 | 'resnet18': resnet18, 321 | 'resnet24': resnet24, 322 | 'resnet50': resnet50, 323 | 'resnet101': resnet101, 324 | 'seresnet12': seresnet12, 325 | 'seresnet18': seresnet18, 326 | 'seresnet24': seresnet24, 327 | 'seresnet50': seresnet50, 328 | 'seresnet101': seresnet101, 329 | } 330 | 331 | model = model_dict[args.model](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=64) 332 | data = torch.randn(2, 3, 84, 84) 333 | model = model.cuda() 334 | data = data.cuda() 335 | feat, logit = model(data, is_feat=True) 336 | print(feat[-1].shape) 337 | print(logit.shape) 338 | -------------------------------------------------------------------------------- /models/resnet_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ['ResNet', 'resnet50'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def 
conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class Normalize(nn.Module): 24 | 25 | def __init__(self, power=2): 26 | super(Normalize, self).__init__() 27 | self.power = power 28 | 29 | def forward(self, x): 30 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 31 | out = x.div(norm) 32 | return out 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None): 39 | super(BasicBlock, self).__init__() 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None): 71 | super(Bottleneck, self).__init__() 72 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(planes) 74 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 75 | padding=1, bias=False) 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(planes * 4) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = 
self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | 108 | def __init__(self, block, layers, in_channel=3, width=1, num_classes=64): 109 | self.inplanes = 64 110 | super(ResNet, self).__init__() 111 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, 112 | bias=False) 113 | self.bn1 = nn.BatchNorm2d(64) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | self.base = int(64 * width) 117 | 118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 119 | self.layer1 = self._make_layer(block, self.base, layers[0]) 120 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 121 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 122 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 123 | self.avgpool = nn.AvgPool2d(3, stride=1) 124 | self.classifier = nn.Linear(self.base * 8 * block.expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x, is_feat=False): 152 | x = self.conv1(x) 153 | x = self.bn1(x) 154 | x = self.relu(x) 155 | x = self.maxpool(x) 156 | x = self.layer1(x) 157 | x = self.layer2(x) 158 | x = self.layer3(x) 159 | x = self.layer4(x) 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | 163 | if is_feat: 164 | return [x], x 165 | 166 | x = self.classifier(x) 167 | return x 168 | 169 | 170 | def resnet50(pretrained=False, **kwargs): 171 | """Constructs a ResNet-50 model. 
172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 178 | return model 179 | 180 | 181 | if __name__ == '__main__': 182 | model = resnet50(num_classes=200) 183 | 184 | data = torch.randn(2, 3, 84, 84) 185 | -------------------------------------------------------------------------------- /models/resnet_sd.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch 4 | 5 | __all__ = ['ResNet_StoDepth_lineardecay', 'resnet18_StoDepth_lineardecay', 'resnet34_StoDepth_lineardecay', 'resnet50_StoDepth_lineardecay', 'resnet101_StoDepth_lineardecay', 6 | 'resnet152_StoDepth_lineardecay'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class StoDepth_BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, prob, multFlag, inplanes, planes, stride=1, downsample=None): 33 | super(StoDepth_BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = 
nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | self.prob = prob 42 | self.m = torch.distributions.bernoulli.Bernoulli(torch.Tensor([self.prob])) 43 | self.multFlag = multFlag 44 | 45 | def forward(self, x): 46 | 47 | identity = x.clone() 48 | 49 | if self.training: 50 | if torch.equal(self.m.sample(),torch.ones(1)): 51 | 52 | self.conv1.weight.requires_grad = True 53 | self.conv2.weight.requires_grad = True 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | identity = self.downsample(x) 63 | 64 | out += identity 65 | else: 66 | # Resnet does not use bias terms 67 | self.conv1.weight.requires_grad = False 68 | self.conv2.weight.requires_grad = False 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out = identity 74 | else: 75 | 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | if self.multFlag: 87 | out = self.prob*out + identity 88 | else: 89 | out = out + identity 90 | 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class StoDepth_Bottleneck(nn.Module): 97 | expansion = 4 98 | 99 | def __init__(self, prob, multFlag, inplanes, planes, stride=1, downsample=None): 100 | super(StoDepth_Bottleneck, self).__init__() 101 | self.conv1 = conv1x1(inplanes, planes) 102 | self.bn1 = nn.BatchNorm2d(planes) 103 | self.conv2 = conv3x3(planes, planes, stride) 104 | self.bn2 = nn.BatchNorm2d(planes) 105 | self.conv3 = conv1x1(planes, planes * self.expansion) 106 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.downsample = 
downsample 109 | self.stride = stride 110 | self.prob = prob 111 | self.m = torch.distributions.bernoulli.Bernoulli(torch.Tensor([self.prob])) 112 | self.multFlag = multFlag 113 | 114 | def forward(self, x): 115 | 116 | identity = x.clone() 117 | 118 | if self.training: 119 | if torch.equal(self.m.sample(),torch.ones(1)): 120 | self.conv1.weight.requires_grad = True 121 | self.conv2.weight.requires_grad = True 122 | self.conv3.weight.requires_grad = True 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv3(out) 133 | out = self.bn3(out) 134 | 135 | if self.downsample is not None: 136 | identity = self.downsample(x) 137 | 138 | out += identity 139 | else: 140 | # Resnet does not use bias terms 141 | self.conv1.weight.requires_grad = False 142 | self.conv2.weight.requires_grad = False 143 | self.conv3.weight.requires_grad = False 144 | 145 | if self.downsample is not None: 146 | identity = self.downsample(x) 147 | 148 | out = identity 149 | else: 150 | out = self.conv1(x) 151 | out = self.bn1(out) 152 | out = self.relu(out) 153 | 154 | out = self.conv2(out) 155 | out = self.bn2(out) 156 | out = self.relu(out) 157 | 158 | out = self.conv3(out) 159 | out = self.bn3(out) 160 | 161 | if self.downsample is not None: 162 | identity = self.downsample(x) 163 | 164 | if self.multFlag: 165 | out = self.prob*out + identity 166 | else: 167 | out = out + identity 168 | 169 | out = self.relu(out) 170 | 171 | return out 172 | 173 | 174 | class ResNet_StoDepth_lineardecay(nn.Module): 175 | 176 | def __init__(self, block, prob_0_L, multFlag, layers, num_classes=1000, zero_init_residual=False): 177 | super(ResNet_StoDepth_lineardecay, self).__init__() 178 | self.inplanes = 64 179 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 180 | bias=False) 181 | self.bn1 = nn.BatchNorm2d(64) 182 | self.relu = nn.ReLU(inplace=True) 
183 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 184 | 185 | self.multFlag = multFlag 186 | self.prob_now = prob_0_L[0] 187 | self.prob_delta = prob_0_L[0]-prob_0_L[1] 188 | self.prob_step = self.prob_delta/(sum(layers)-1) 189 | 190 | self.layer1 = self._make_layer(block, 64, layers[0]) 191 | self.layer2 = self._make_layer(block, 160, layers[1], stride=2) 192 | self.layer3 = self._make_layer(block, 320, layers[2], stride=2) 193 | self.layer4 = self._make_layer(block, 640, layers[3], stride=2) 194 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 195 | self.fc = nn.Linear(640 * block.expansion, 64) 196 | 197 | for m in self.modules(): 198 | if isinstance(m, nn.Conv2d): 199 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 200 | elif isinstance(m, nn.BatchNorm2d): 201 | nn.init.constant_(m.weight, 1) 202 | nn.init.constant_(m.bias, 0) 203 | 204 | # Zero-initialize the last BN in each residual branch, 205 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
206 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 207 | if zero_init_residual: 208 | for m in self.modules(): 209 | if isinstance(m, StoDepth_Bottleneck): 210 | nn.init.constant_(m.bn3.weight, 0) 211 | elif isinstance(m, StoDepth_BasicBlock): 212 | nn.init.constant_(m.bn2.weight, 0) 213 | 214 | def _make_layer(self, block, planes, blocks, stride=1): 215 | downsample = None 216 | if stride != 1 or self.inplanes != planes * block.expansion: 217 | downsample = nn.Sequential( 218 | conv1x1(self.inplanes, planes * block.expansion, stride), 219 | nn.BatchNorm2d(planes * block.expansion), 220 | ) 221 | 222 | layers = [] 223 | layers.append(block(self.prob_now, self.multFlag, self.inplanes, planes, stride, downsample)) 224 | self.prob_now = self.prob_now - self.prob_step 225 | self.inplanes = planes * block.expansion 226 | for _ in range(1, blocks): 227 | layers.append(block(self.prob_now, self.multFlag, self.inplanes, planes)) 228 | self.prob_now = self.prob_now - self.prob_step 229 | 230 | return nn.Sequential(*layers) 231 | 232 | def forward(self, x, is_feat=False): 233 | x = self.conv1(x) 234 | x = self.bn1(x) 235 | x = self.relu(x) 236 | x = self.maxpool(x) 237 | 238 | x = self.layer1(x) 239 | x = self.layer2(x) 240 | x = self.layer3(x) 241 | x = self.layer4(x) 242 | 243 | x = self.avgpool(x) 244 | x = x.view(x.size(0), -1) 245 | x = self.fc(x) 246 | 247 | return x 248 | 249 | 250 | 251 | def resnet12_sd(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs): 252 | """Constructs a ResNet_StoDepth_lineardecay-12 model.
253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | """ 256 | model = ResNet_StoDepth_lineardecay(StoDepth_BasicBlock, prob_0_L, multFlag, [1, 1, 1, 1], **kwargs) 257 | return model 258 | 259 | 260 | def resnet18_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs): 261 | """Constructs a ResNet_StoDepth_lineardecay-18 model. 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | """ 265 | model = ResNet_StoDepth_lineardecay(StoDepth_BasicBlock, prob_0_L, multFlag, [2, 2, 2, 2], **kwargs) 266 | if pretrained: 267 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 268 | return model 269 | 270 | 271 | def resnet34_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs): 272 | """Constructs a ResNet_StoDepth_lineardecay-34 model. 273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | """ 276 | model = ResNet_StoDepth_lineardecay(StoDepth_BasicBlock, prob_0_L, multFlag, [3, 4, 6, 3], **kwargs) 277 | if pretrained: 278 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 279 | return model 280 | 281 | 282 | def resnet50_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs): 283 | """Constructs a ResNet_StoDepth_lineardecay-50 model. 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | """ 287 | model = ResNet_StoDepth_lineardecay(StoDepth_Bottleneck, prob_0_L, multFlag, [3, 4, 6, 3], **kwargs) 288 | if pretrained: 289 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 290 | return model 291 | 292 | 293 | def resnet101_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs): 294 | """Constructs a ResNet_StoDepth_lineardecay-101 model. 
295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | """ 298 | model = ResNet_StoDepth_lineardecay(StoDepth_Bottleneck, prob_0_L, multFlag, [3, 4, 23, 3], **kwargs) 299 | if pretrained: 300 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 301 | return model 302 | 303 | 304 | def resnet152_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs): 305 | """Constructs a ResNet_StoDepth_lineardecay-152 model. 306 | Args: 307 | pretrained (bool): If True, returns a model pre-trained on ImageNet 308 | """ 309 | model = ResNet_StoDepth_lineardecay(StoDepth_Bottleneck, prob_0_L, multFlag, [3, 8, 36, 3], **kwargs) 310 | if pretrained: 311 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 312 | return model -------------------------------------------------------------------------------- /models/resnet_selfdist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(in_planes, out_planes, stride=1): 5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 6 | stride=stride, padding=1, bias=False) 7 | 8 | def conv1x1(in_planes, planes, stride=1): 9 | return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 10 | 11 | def branchBottleNeck(channel_in, channel_out, kernel_size): 12 | middle_channel = channel_out//4 13 | return nn.Sequential( 14 | nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1), 15 | nn.BatchNorm2d(middle_channel), 16 | nn.ReLU(), 17 | 18 | nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size), 19 | nn.BatchNorm2d(middle_channel), 20 | nn.ReLU(), 21 | 22 | nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1), 23 | nn.BatchNorm2d(channel_out), 24 | nn.ReLU(), 25 | ) 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | 
super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | output = self.conv1(x) 42 | output = self.bn1(output) 43 | output = self.relu(output) 44 | 45 | output = self.conv2(output) 46 | output = self.bn2(output) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | output += residual 52 | output = self.relu(output) 53 | return output 54 | 55 | class BottleneckBlock(nn.Module): 56 | expansion = 4 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(BottleneckBlock, self).__init__() 59 | self.conv1 = conv1x1(inplanes, planes) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | 63 | self.conv2 = conv3x3(planes, planes, stride) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | 66 | self.conv3 = conv1x1(planes, planes*self.expansion) 67 | self.bn3 = nn.BatchNorm2d(planes*self.expansion) 68 | 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | output = self.conv1(x) 76 | output = self.bn1(output) 77 | output = self.relu(output) 78 | 79 | output = self.conv2(output) 80 | output = self.bn2(output) 81 | output = self.relu(output) 82 | 83 | output = self.conv3(output) 84 | output = self.bn3(output) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | output += residual 90 | output = self.relu(output) 91 | 92 | return output 93 | 94 | class Multi_ResNet(nn.Module): 95 | """Resnet model 96 | 97 | Args: 98 | block (class): block type, BasicBlock or BottleneckBlock 99 | layers (int list): layer num in each block 100 | num_classes (int): class num 101 | """ 102 | 103 | def __init__(self, 
block, layers, num_classes=1000): 104 | super(Multi_ResNet, self).__init__() 105 | self.inplanes = 64 106 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 107 | self.bn1 = nn.BatchNorm2d(self.inplanes) 108 | self.relu = nn.ReLU(inplace=True) 109 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | 116 | self.downsample1_1 = nn.Sequential( 117 | conv1x1(64 * block.expansion, 512 * block.expansion, stride=8), 118 | nn.BatchNorm2d(512 * block.expansion), 119 | ) 120 | self.bottleneck1_1 = branchBottleNeck(64 * block.expansion, 512 * block.expansion, kernel_size=8) 121 | self.avgpool1 = nn.AdaptiveAvgPool2d((1,1)) 122 | self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes) 123 | 124 | 125 | self.downsample2_1 = nn.Sequential( 126 | conv1x1(128 * block.expansion, 512 * block.expansion, stride=4), 127 | nn.BatchNorm2d(512 * block.expansion), 128 | ) 129 | self.bottleneck2_1 = branchBottleNeck(128 * block.expansion, 512 * block.expansion, kernel_size=4) 130 | self.avgpool2 = nn.AdaptiveAvgPool2d((1,1)) 131 | self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes) 132 | 133 | 134 | self.downsample3_1 = nn.Sequential( 135 | conv1x1(256 * block.expansion, 512 * block.expansion, stride=2), 136 | nn.BatchNorm2d(512 * block.expansion), 137 | ) 138 | self.bottleneck3_1 = branchBottleNeck(256 * block.expansion, 512 * block.expansion, kernel_size=2) 139 | self.avgpool3 = nn.AdaptiveAvgPool2d((1,1)) 140 | self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes) 141 | 142 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 143 | self.fc = nn.Linear(512 * block.expansion, num_classes) 144 | 145 | for m in self.modules(): 146 | if 
isinstance(m, nn.Conv2d): 147 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 148 | elif isinstance(m, nn.BatchNorm2d): 149 | nn.init.constant_(m.weight, 1) 150 | nn.init.constant_(m.bias, 0) 151 | 152 | def _make_layer(self, block, planes, layers, stride=1): 153 | """A block with 'layers' layers 154 | 155 | Args: 156 | block (class): block type 157 | planes (int): output channels = planes * expansion 158 | layers (int): layer num in the block 159 | stride (int): the first layer stride in the block 160 | """ 161 | downsample = None 162 | if stride !=1 or self.inplanes != planes * block.expansion: 163 | downsample = nn.Sequential( 164 | conv1x1(self.inplanes, planes * block.expansion, stride), 165 | nn.BatchNorm2d(planes * block.expansion), 166 | ) 167 | layer = [] 168 | layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample)) 169 | self.inplanes = planes * block.expansion 170 | for i in range(1, layers): 171 | layer.append(block(self.inplanes, planes)) 172 | 173 | return nn.Sequential(*layer) 174 | 175 | def forward(self, x, is_feat=False, is_dist=False): 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | # x = self.maxpool(x) 180 | 181 | x = self.layer1(x) 182 | middle_output1 = self.bottleneck1_1(x) 183 | middle_output1 = self.avgpool1(middle_output1) 184 | middle1_fea = middle_output1 185 | middle_output1 = torch.flatten(middle_output1, 1) 186 | middle_output1 = self.middle_fc1(middle_output1) 187 | 188 | x = self.layer2(x) 189 | middle_output2 = self.bottleneck2_1(x) 190 | middle_output2 = self.avgpool2(middle_output2) 191 | middle2_fea = middle_output2 192 | middle_output2 = torch.flatten(middle_output2, 1) 193 | middle_output2 = self.middle_fc2(middle_output2) 194 | 195 | x = self.layer3(x) 196 | middle_output3 = self.bottleneck3_1(x) 197 | middle_output3 = self.avgpool3(middle_output3) 198 | middle3_fea = middle_output3 199 | middle_output3 = torch.flatten(middle_output3, 1) 200 | 
middle_output3 = self.middle_fc3(middle_output3) 201 | 202 | x = self.layer4(x) 203 | x = self.avgpool(x) 204 | final_fea = x 205 | x = torch.flatten(x, 1) 206 | x = self.fc(x) 207 | 208 | 209 | if is_dist: 210 | return x, middle_output1, middle_output2, middle_output3, final_fea, middle1_fea, middle2_fea, middle3_fea 211 | 212 | if is_feat: 213 | return [final_fea, final_fea, final_fea, final_fea, final_fea], x 214 | else: 215 | return x 216 | 217 | 218 | def multi_resnet50_kd(num_classes=1000): 219 | return Multi_ResNet(BottleneckBlock, [3,4,6,3], num_classes=num_classes) 220 | 221 | def multi_resnet18_kd(num_classes=1000): 222 | return Multi_ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes) 223 | 224 | 225 | def multi_resnet12_kd(num_classes=64): 226 | return Multi_ResNet(BasicBlock, [1,1,1,1], num_classes=num_classes) 227 | 228 | -------------------------------------------------------------------------------- /models/resnet_ssl.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.distributions import Bernoulli 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class SELayer(nn.Module): 14 | def __init__(self, channel, reduction=16): 15 | super(SELayer, self).__init__() 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | self.fc = nn.Sequential( 18 | nn.Linear(channel, channel // reduction), 19 | nn.ReLU(inplace=True), 20 | nn.Linear(channel // reduction, channel), 21 | nn.Sigmoid() 22 | ) 23 | 24 | def forward(self, x): 25 | b, c, _, _ = x.size() 26 | y = self.avg_pool(x).view(b, c) 27 | y = self.fc(y).view(b, c, 1, 1) 28 | return x * y 29 | 30 | 31 | class DropBlock(nn.Module): 32 | def __init__(self, block_size): 33 | super(DropBlock, self).__init__() 34 | 35 | 
self.block_size = block_size 36 | #self.gamma = gamma 37 | #self.bernouli = Bernoulli(gamma) 38 | 39 | def forward(self, x, gamma): 40 | # shape: (bsize, channels, height, width) 41 | 42 | if self.training: 43 | batch_size, channels, height, width = x.shape 44 | 45 | bernoulli = Bernoulli(gamma) 46 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).to(x.device) 47 | block_mask = self._compute_block_mask(mask) 48 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 49 | count_ones = block_mask.sum() 50 | 51 | return block_mask * x * (countM / count_ones) 52 | else: 53 | return x 54 | 55 | def _compute_block_mask(self, mask): 56 | left_padding = int((self.block_size-1) / 2) 57 | right_padding = int(self.block_size / 2) 58 | 59 | batch_size, channels, height, width = mask.shape 60 | #print ("mask", mask[0][0]) 61 | non_zero_idxs = mask.nonzero() 62 | nr_blocks = non_zero_idxs.shape[0] 63 | 64 | offsets = torch.stack( 65 | [ 66 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 67 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 68 | ] 69 | ).t().to(mask.device) 70 | offsets = torch.cat((torch.zeros(self.block_size**2, 2, dtype=torch.long, device=mask.device), offsets.long()), 1) 71 | 72 | if nr_blocks > 0: 73 | non_zero_idxs = non_zero_idxs.repeat_interleave(self.block_size ** 2, dim=0) # pair every sampled seed with all block_size**2 offsets 74 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 75 | offsets = offsets.long() 76 | 77 | block_idxs = non_zero_idxs + offsets 78 | #block_idxs += left_padding 79 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 80 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
81 | else: 82 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 83 | 84 | block_mask = 1 - padded_mask#[:height, :width] 85 | return block_mask 86 | 87 | 88 | class BasicBlock(nn.Module): 89 | expansion = 1 90 | 91 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, 92 | block_size=1, use_se=False): 93 | super(BasicBlock, self).__init__() 94 | self.conv1 = conv3x3(inplanes, planes) 95 | self.bn1 = nn.BatchNorm2d(planes) 96 | self.relu = nn.LeakyReLU(0.1) 97 | self.conv2 = conv3x3(planes, planes) 98 | self.bn2 = nn.BatchNorm2d(planes) 99 | self.conv3 = conv3x3(planes, planes) 100 | self.bn3 = nn.BatchNorm2d(planes) 101 | self.maxpool = nn.MaxPool2d(stride) 102 | self.downsample = downsample 103 | self.stride = stride 104 | self.drop_rate = drop_rate 105 | self.num_batches_tracked = 0 106 | self.drop_block = drop_block 107 | self.block_size = block_size 108 | self.DropBlock = DropBlock(block_size=self.block_size) 109 | self.use_se = use_se 110 | if self.use_se: 111 | self.se = SELayer(planes, 4) 112 | 113 | def forward(self, x): 114 | self.num_batches_tracked += 1 115 | 116 | residual = x 117 | 118 | out = self.conv1(x) 119 | out = self.bn1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.bn3(out) 128 | if self.use_se: 129 | out = self.se(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | out += residual 134 | out = self.relu(out) 135 | out = self.maxpool(out) 136 | 137 | if self.drop_rate > 0: 138 | if self.drop_block == True: 139 | feat_size = out.size()[2] 140 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 141 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 142 | out = self.DropBlock(out, gamma=gamma) 143 | else: 
144 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 145 | 146 | return out 147 | 148 | 149 | class ResNet(nn.Module): 150 | 151 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0, 152 | dropblock_size=5, num_classes=-1, use_se=False): 153 | super(ResNet, self).__init__() 154 | 155 | self.inplanes = 3 156 | self.use_se = use_se 157 | self.layer1 = self._make_layer(block, n_blocks[0], 64, 158 | stride=2, drop_rate=drop_rate) 159 | self.layer2 = self._make_layer(block, n_blocks[1], 160, 160 | stride=2, drop_rate=drop_rate) 161 | self.layer3 = self._make_layer(block, n_blocks[2], 320, 162 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 163 | self.layer4 = self._make_layer(block, n_blocks[3], 640, 164 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 165 | if avg_pool: 166 | # self.avgpool = nn.AvgPool2d(5, stride=1) 167 | self.avgpool = nn.AdaptiveAvgPool2d(1) 168 | self.keep_prob = keep_prob 169 | self.keep_avg_pool = avg_pool 170 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 171 | self.drop_rate = drop_rate 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 176 | elif isinstance(m, nn.BatchNorm2d): 177 | nn.init.constant_(m.weight, 1) 178 | nn.init.constant_(m.bias, 0) 179 | 180 | self.num_classes = num_classes 181 | if self.num_classes > 0: 182 | self.classifier = nn.Linear(640, self.num_classes) 183 | self.rot_classifier = nn.Linear(self.num_classes, 4) 184 | # self.rot_classifier1 = nn.Linear(self.num_classes, 32) 185 | # self.rot_classifier2 = nn.Linear(32, 16) 186 | # self.rot_classifier3 = nn.Linear(16, 4) 187 | 188 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 189 | downsample = None 190 | if stride != 1 or self.inplanes != planes * block.expansion: 191 | 
downsample = nn.Sequential( 192 | nn.Conv2d(self.inplanes, planes * block.expansion, 193 | kernel_size=1, stride=1, bias=False), 194 | nn.BatchNorm2d(planes * block.expansion), 195 | ) 196 | 197 | layers = [] 198 | if n_block == 1: 199 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se) 200 | else: 201 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se) # keyword arg: passed positionally, use_se would land in the drop_block slot 202 | layers.append(layer) 203 | self.inplanes = planes * block.expansion 204 | 205 | for i in range(1, n_block): 206 | if i == n_block - 1: 207 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block, 208 | block_size=block_size, use_se=self.use_se) 209 | else: 210 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se) 211 | layers.append(layer) 212 | 213 | return nn.Sequential(*layers) 214 | 215 | def forward(self, x, is_feat=False, rot=False): 216 | x = self.layer1(x) 217 | f0 = x 218 | x = self.layer2(x) 219 | f1 = x 220 | x = self.layer3(x) 221 | f2 = x 222 | x = self.layer4(x) 223 | f3 = x 224 | if self.keep_avg_pool: 225 | x = self.avgpool(x) 226 | x = x.view(x.size(0), -1) 227 | feat = x 228 | 229 | xx = self.classifier(x) 230 | 231 | if(rot): 232 | # xy1 = self.rot_classifier1(xx) 233 | # xy2 = self.rot_classifier2(xy1) 234 | xy = self.rot_classifier(xx) 235 | return [f0, f1, f2, f3, feat], (xx, xy) 236 | 237 | if is_feat: 238 | return [f0, f1, f2, f3, feat], xx 239 | else: 240 | return xx 241 | 242 | 243 | def resnet12_ssl(keep_prob=1.0, avg_pool=False, **kwargs): 244 | """Constructs a ResNet-12 model. 245 | """ 246 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 247 | return model 248 | 249 | 250 | def resnet18(keep_prob=1.0, avg_pool=False, **kwargs): 251 | """Constructs a ResNet-18 model.
252 | """ 253 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 254 | return model 255 | 256 | 257 | def resnet24(keep_prob=1.0, avg_pool=False, **kwargs): 258 | """Constructs a ResNet-24 model. 259 | """ 260 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 261 | return model 262 | 263 | 264 | def resnet50(keep_prob=1.0, avg_pool=False, **kwargs): 265 | """Constructs a ResNet-50 model. 266 | Note: this actually has only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers. 267 | """ 268 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 269 | return model 270 | 271 | 272 | def resnet101(keep_prob=1.0, avg_pool=False, **kwargs): 273 | """Constructs a ResNet-101 model. 274 | Note: this actually has only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers. 275 | """ 276 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 277 | return model 278 | 279 | 280 | def seresnet12(keep_prob=1.0, avg_pool=False, **kwargs): 281 | """Constructs an SE-ResNet-12 model. 282 | """ 283 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 284 | return model 285 | 286 | 287 | def seresnet18(keep_prob=1.0, avg_pool=False, **kwargs): 288 | """Constructs an SE-ResNet-18 model. 289 | """ 290 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 291 | return model 292 | 293 | 294 | def seresnet24(keep_prob=1.0, avg_pool=False, **kwargs): 295 | """Constructs an SE-ResNet-24 model. 296 | """ 297 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 298 | return model 299 | 300 | 301 | def seresnet50(keep_prob=1.0, avg_pool=False, **kwargs): 302 | """Constructs an SE-ResNet-50 model.
303 | Note: this actually has only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers. 304 | """ 305 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 306 | return model 307 | 308 | 309 | def seresnet101(keep_prob=1.0, avg_pool=False, **kwargs): 310 | """Constructs an SE-ResNet-101 model. 311 | Note: this actually has only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers. 312 | """ 313 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs) 314 | return model 315 | 316 | 317 | if __name__ == '__main__': 318 | 319 | import argparse 320 | 321 | parser = argparse.ArgumentParser('argument for training') 322 | parser.add_argument('--model', type=str, choices=['resnet12', 'resnet18', 'resnet24', 'resnet50', 'resnet101', 323 | 'seresnet12', 'seresnet18', 'seresnet24', 'seresnet50', 324 | 'seresnet101']) 325 | args = parser.parse_args() 326 | 327 | model_dict = { 328 | 'resnet12': resnet12_ssl, # this module defines resnet12_ssl, not resnet12 329 | 'resnet18': resnet18, 330 | 'resnet24': resnet24, 331 | 'resnet50': resnet50, 332 | 'resnet101': resnet101, 333 | 'seresnet12': seresnet12, 334 | 'seresnet18': seresnet18, 335 | 'seresnet24': seresnet24, 336 | 'seresnet50': seresnet50, 337 | 'seresnet101': seresnet101, 338 | } 339 | 340 | model = model_dict[args.model](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=64) 341 | data = torch.randn(2, 3, 84, 84) 342 | model = model.cuda() 343 | data = data.cuda() 344 | feat, logit = model(data, is_feat=True) 345 | print(feat[-1].shape) 346 | print(logit.shape) 347 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | 4 | from .
import model_dict 5 | 6 | 7 | def create_model(name, n_cls, dataset='miniImageNet', dropout=0.1): 8 | """create model by name""" 9 | if dataset == 'miniImageNet' or dataset == 'tieredImageNet': 10 | if name.endswith('v2') or name.endswith('v3'): 11 | model = model_dict[name](num_classes=n_cls) 12 | elif(name.endswith("kd")): 13 | model = model_dict[name](num_classes=n_cls) 14 | elif(name.endswith("ssl")): 15 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=n_cls) 16 | elif name.startswith('resnet50'): 17 | print('use imagenet-style resnet50') 18 | model = model_dict[name](num_classes=n_cls) 19 | elif name.startswith('resnet') or name.startswith('seresnet'): 20 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=n_cls) 21 | elif name.startswith('wrn'): 22 | model = model_dict[name](num_classes=n_cls) 23 | elif name.startswith('convnet'): 24 | model = model_dict[name](num_classes=n_cls) 25 | else: 26 | raise NotImplementedError('model {} not supported in dataset {}'.format(name, dataset)) 27 | elif dataset == 'CIFAR-FS' or dataset == 'FC100' or dataset=="toy": 28 | 29 | print("***********", name) 30 | if(name.endswith("kd")): 31 | model = model_dict[name](num_classes=n_cls) 32 | elif(name.endswith("sd")): 33 | model = model_dict[name]() 34 | elif(name.endswith("ssl")): 35 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=2, num_classes=n_cls) 36 | elif name.startswith('resnet') or name.startswith('seresnet'): 37 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=2, num_classes=n_cls) 38 | elif name.startswith('convnet'): 39 | model = model_dict[name](num_classes=n_cls) 40 | else: 41 | raise NotImplementedError('model {} not supported in dataset {}'.format(name, dataset)) 42 | else: 43 | raise NotImplementedError('dataset not supported: {}'.format(dataset)) 44 | 45 | return model 46 | 47 | 48 | def get_teacher_name(model_path): 49 | """parse to get
teacher model name""" 50 | segments = model_path.split('/')[-2].split('_') 51 | if ':' in segments[0]: 52 | return segments[0].split(':')[-1] 53 | else: 54 | if segments[0] != 'wrn': 55 | return segments[0] 56 | else: 57 | return segments[0] + '_' + segments[1] + '_' + segments[2] 58 | -------------------------------------------------------------------------------- /models/wresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 13 | 14 | 15 | def conv_init(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Conv') != -1: 18 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 19 | init.constant_(m.bias, 0) 20 | elif classname.find('BatchNorm') != -1: 21 | init.constant_(m.weight, 1) 22 | init.constant_(m.bias, 0) 23 | 24 | 25 | class wide_basic(nn.Module): 26 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 27 | super(wide_basic, self).__init__() 28 | self.bn1 = nn.BatchNorm2d(in_planes) 29 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 30 | self.dropout = nn.Dropout(p=dropout_rate) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride != 1 or in_planes != planes: 36 | self.shortcut = nn.Sequential( 37 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 38 | ) 39 | 40 | def forward(self, x): 41 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 42 | out = self.conv2(F.relu(self.bn2(out))) 43 | out += self.shortcut(x) 44 | 45 | return out 46 | 47 | 48 | class
Wide_ResNet(nn.Module): 49 | def __init__(self, depth, widen_factor, dropout_rate, num_classes=-1): 50 | super(Wide_ResNet, self).__init__() 51 | self.in_planes = 16 52 | 53 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 54 | n = (depth-4) // 6 55 | k = widen_factor 56 | 57 | print('| Wide-Resnet %dx%d' %(depth, k)) 58 | nStages = [16, 16*k, 32*k, 64*k] 59 | 60 | self.conv1 = conv3x3(3,nStages[0]) 61 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 62 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 63 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 64 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 65 | 66 | self.num_classes = num_classes 67 | if self.num_classes > 0: 68 | self.classifier = nn.Linear(64*k, self.num_classes) 69 | 70 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 71 | strides = [stride] + [1]*(num_blocks-1) 72 | layers = [] 73 | 74 | for stride in strides: 75 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 76 | self.in_planes = planes 77 | 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x, is_feat=False): 81 | out = self.conv1(x) 82 | out = self.layer1(out) 83 | out = self.layer2(out) 84 | out = self.layer3(out) 85 | out = F.relu(self.bn1(out)) 86 | out = F.adaptive_avg_pool2d(out, 1) 87 | out = out.view(out.size(0), -1) 88 | feat = out 89 | if self.num_classes > 0: 90 | out = self.classifier(out) 91 | 92 | if is_feat: 93 | return [feat], out 94 | else: 95 | return out 96 | 97 | 98 | def wrn_28_10(dropout_rate=0.3, num_classes=-1): 99 | return Wide_ResNet(28, 10, dropout_rate, num_classes) 100 | 101 | 102 | if __name__ == '__main__': 103 | net=Wide_ResNet(28, 10, 0.3) 104 | y = net(Variable(torch.randn(1,3,32,32))) 105 | 106 | print(y.size()) 107 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | matplotlib==3.2.1 2 | mkl==2019.0 3 | numpy==1.18.4 4 | Pillow==7.1.2 5 | scikit_learn==0.23.1 6 | scipy==1.4.1 7 | torch==1.5.0 8 | torchvision==0.6.0 9 | tqdm==4.46.0 10 | wandb==0.8.36 11 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | ###################################################################################################################################################### 2 | ###################################################################################################################################################### 3 | ####################################################### CIFAR-FS ########################################################## 4 | ###################################################################################################################################################### 5 | ###################################################################################################################################################### 6 | 7 | # # # # self supervision 8 | # python3 train_supervised_ssl.py \ 9 | # --tags cifarfs,may30 \ 10 | # --model resnet12_ssl \ 11 | # --model_path save/backup \ 12 | # --dataset CIFAR-FS \ 13 | # --data_root ../../Datasets/CIFAR_FS/ \ 14 | # --n_aug_support_samples 5 \ 15 | # --n_ways 5 \ 16 | # --n_shots 1 \ 17 | # --epochs 65 \ 18 | # --lr_decay_epochs 60 \ 19 | # --gamma 2.0 & 20 | 21 | 22 | 23 | # for i in {0..0}; do 24 | # python3 train_distillation.py \ 25 | # --tags cifarfs,gen1,may30 \ 26 | # --model_s resnet12_ssl \ 27 | # --model_t resnet12_ssl \ 28 | # --path_t save/backup/resnet12_ssl_CIFAR-FS_lr_0.05_decay_0.0005_trans_D_trial_1/model_firm-sun-394.pth \ 29 | # --model_path save/backup \ 30 | # --dataset CIFAR-FS \ 31 | # --data_root ../../Datasets/CIFAR_FS/ \ 32 | # 
--n_aug_support_samples 5 \ 33 | # --n_ways 5 \ 34 | # --n_shots 1 \ 35 | # --epochs 65 \ 36 | # --lr_decay_epochs 60 \ 37 | # --gamma 0.1 & 38 | # sleep 1m 39 | # done 40 | 41 | 42 | 43 | 44 | 45 | 46 | # # # evaluation 47 | # CUDA_VISIBLE_DEVICES=0 python3 eval_fewshot.py \ 48 | # --model resnet12_ssl \ 49 | # --model_path save/backup2/resnet12_ssl_toy_lr_0.05_decay_0.0005_trans_A_trial_1/model_upbeat-dew-17.pth \ 50 | # --dataset toy \ 51 | # --data_root ../../Datasets/CIFAR_FS/ \ 52 | # --n_aug_support_samples 5 \ 53 | # --n_ways 5 \ 54 | # --n_shots 1 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | ###################################################################################################################################################### 66 | ###################################################################################################################################################### 67 | ####################################################### FC100 ########################################################## 68 | ###################################################################################################################################################### 69 | ###################################################################################################################################################### 70 | 71 | 72 | # # # # GEN0 73 | # CUDA_VISIBLE_DEVICES=1 python3 train_supervised_ssl.py \ 74 | # --model resnet12_ssl \ 75 | # --model_path save/backup \ 76 | # --tags fc100 \ 77 | # --dataset FC100 \ 78 | # --data_root ../Datasets/neurips2020/ \ 79 | # --n_aug_support_samples 5 \ 80 | # --n_ways 5 \ 81 | # --n_shots 1 \ 82 | # --epochs 65 \ 83 | # --lr_decay_epochs 60 \ 84 | # --gamma 2 & 85 | 86 | 87 | # # # # GEN1 88 | # for i in {0..0}; do 89 | # python3 train_distillation5.py \ 90 | # --model_s resnet12_ssl \ 91 | # --model_t resnet12_ssl \ 92 | # --path_t 
save/backup/resnet12_ssl_FC100_lr_0.05_decay_0.0005_trans_D_trial_1/model_effortless-wood-315.pth \ 93 | # --tags fc100 \ 94 | # --model_path save/neurips2020 \ 95 | # --dataset FC100 \ 96 | # --data_root ../Datasets/FC100/ \ 97 | # --n_aug_support_samples 5 \ 98 | # --n_ways 5 \ 99 | # --n_shots 1 \ 100 | # --batch_size 64 \ 101 | # --epochs 8 \ 102 | # --lr_decay_epochs 3 \ 103 | # --gamma 0.2 \ 104 | # # sleep 4m 105 | # done 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | ###################################################################################################################################################### 124 | ###################################################################################################################################################### 125 | ####################################################### miniImagenet ########################################################## 126 | ###################################################################################################################################################### 127 | ###################################################################################################################################################### 128 | 129 | 130 | 131 | 132 | # # # # GEN0 133 | # python3 train_supervised_ssl.py \ 134 | # --model resnet12_ssl \ 135 | # --model_path save/neurips2020 \ 136 | # --tags miniimagenet,gen0 \ 137 | # --dataset miniImageNet \ 138 | # --data_root ../Datasets/MiniImagenet/ \ 139 | # --n_aug_support_samples 5 \ 140 | # --n_ways 5 \ 141 | # --n_shots 1 \ 142 | # --epochs 65 \ 143 | # --lr_decay_epochs 60 \ 144 | # --gamma 2.0 & 145 | 146 | 147 | # params=( 0.025 0.05 0.075 0.1 0.15 0.2 0.25 0.3 0.4 0.5 ) 148 | 149 | # # # # GEN1 150 | # for i in {0..2}; do 151 | # python3 train_distillation.py \ 152 | # --model_s resnet12_ssl \ 153 | # --model_t resnet12_ssl \ 154 | # --path_t 
save/neurips2020/resnet12_ssl_miniImageNet_lr_0.05_decay_0.0005_trans_A_trial_1/model_swift-lake-4.pth \ 155 | # --tags miniimagenet,gen1,beta \ 156 | # --model_path save/neurips2020 \ 157 | # --dataset miniImageNet \ 158 | # --data_root ../Datasets/MiniImagenet/ \ 159 | # --n_aug_support_samples 5 \ 160 | # --n_ways 5 \ 161 | # --n_shots 1 \ 162 | # --batch_size 64 \ 163 | # --epochs 8 \ 164 | # --lr_decay_epochs 5 \ 165 | # --gamma ${params[i]} & 166 | # sleep 30m 167 | # done 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | ###################################################################################################################################################### 185 | ###################################################################################################################################################### 186 | ####################################################### tieredImageNet ########################################################## 187 | ###################################################################################################################################################### 188 | ###################################################################################################################################################### 189 | 190 | 191 | 192 | 193 | # # # # GEN0 194 | # python3 train_supervised_ssl.py \ 195 | # --model resnet12_ssl \ 196 | # --model_path save/backup \ 197 | # --tags tieredimageNet \ 198 | # --dataset tieredImageNet \ 199 | # --data_root ../Datasets/TieredImagenet/ \ 200 | # --n_aug_support_samples 5 \ 201 | # --n_ways 5 \ 202 | # --n_shots 1 \ 203 | # --epochs 60 \ 204 | # --lr_decay_epochs 30,40,50 \ 205 | # --gamma 2 206 | 207 | 208 | 209 | 210 | # # # # GEN1 211 | # for i in {0..6}; do 212 | # python3 train_distillation5.py \ 213 | # --model_s resnet12_ssl \ 214 | # --model_t resnet12_ssl \ 215 | # --path_t 
save/backup/resnet12_ssl_FC100_lr_0.05_decay_0.0005_trans_D_trial_1/model_effortless-wood-315.pth \ 216 | # --tags fc100 \ 217 | # --model_path save/backup \ 218 | # --dataset FC100 \ 219 | # --data_root ../../Datasets/FC100/ \ 220 | # --n_aug_support_samples 5 \ 221 | # --n_ways 5 \ 222 | # --n_shots 1 \ 223 | # --batch_size 64 \ 224 | # --epochs 8 \ 225 | # --lr_decay_epochs 3 \ 226 | # --gamma ${params[i]} \ 227 | # # sleep 4m 228 | # done 229 | 230 | -------------------------------------------------------------------------------- /train_distillation.py: -------------------------------------------------------------------------------- 1 | """ 2 | the general training framework 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import os 8 | import argparse 9 | import socket 10 | import time 11 | import sys 12 | from tqdm import tqdm 13 | import mkl 14 | 15 | import torch 16 | import torch.optim as optim 17 | import torch.nn as nn 18 | import torch.backends.cudnn as cudnn 19 | from torch.utils.data import DataLoader 20 | import torch.nn.functional as F 21 | 22 | from models import model_pool 23 | from models.util import create_model, get_teacher_name 24 | 25 | from distill.util import Embed 26 | from distill.criterion import DistillKL, NCELoss, Attention, HintLoss 27 | 28 | from dataset.mini_imagenet import ImageNet, MetaImageNet 29 | from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet 30 | from dataset.cifar import CIFAR100, MetaCIFAR100 31 | from dataset.transform_cfg import transforms_options, transforms_list 32 | 33 | from util import adjust_learning_rate, accuracy, AverageMeter 34 | from eval.meta_eval import meta_test, meta_test_tune 35 | from eval.cls_eval import validate 36 | 37 | from models.resnet import resnet12 38 | import numpy as np 39 | from util import Logger 40 | import wandb 41 | from dataloader import get_dataloaders 42 | import copy 43 | 44 | def get_freer_gpu(): 45 | os.system('nvidia-smi -q -d Memory |grep -A4 
GPU|grep Free >tmp') 46 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 47 | return np.argmax(memory_available) 48 | 49 | os.environ["CUDA_VISIBLE_DEVICES"]=str(get_freer_gpu()) 50 | # os.environ['OPENBLAS_NUM_THREADS'] = '4' 51 | mkl.set_num_threads(2) 52 | 53 | 54 | class Wrapper(nn.Module): 55 | 56 | def __init__(self, model, args): 57 | super(Wrapper, self).__init__() 58 | 59 | self.model = model 60 | self.feat = torch.nn.Sequential(*list(self.model.children())[:-2]) 61 | 62 | self.last = torch.nn.Linear(list(self.model.children())[-2].in_features, 64) 63 | 64 | def forward(self, images): 65 | feat = self.feat(images) 66 | feat = feat.view(images.size(0), -1) 67 | out = self.last(feat) 68 | 69 | return feat, out 70 | 71 | 72 | 73 | def parse_option(): 74 | 75 | parser = argparse.ArgumentParser('argument for training') 76 | 77 | parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency') 78 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 79 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 80 | parser.add_argument('--save_freq', type=int, default=10, help='save frequency') 81 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 82 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use') 83 | parser.add_argument('--epochs', type=int, default=100, help='number of training epochs') 84 | 85 | # optimization 86 | parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate') 87 | parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list') 88 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 89 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 90 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 91 | 
92 | # dataset and model 93 | parser.add_argument('--model_s', type=str, default='resnet12', choices=model_pool) 94 | parser.add_argument('--model_t', type=str, default='resnet12', choices=model_pool) 95 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet', 96 | 'CIFAR-FS', 'FC100']) 97 | parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation') 98 | parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning') 99 | parser.add_argument('--tags', type=str, default="gen1, ssl", help='add tags for the experiment') 100 | parser.add_argument('--transform', type=str, default='A', choices=transforms_list) 101 | 102 | # path to teacher model 103 | parser.add_argument('--path_t', type=str, default="", help='teacher model snapshot') 104 | 105 | # distillation 106 | parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'contrast', 'hint', 'attention']) 107 | parser.add_argument('--trial', type=str, default='1', help='trial id') 108 | 109 | parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for the rotation-consistency loss') 110 | parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for KD') 111 | parser.add_argument('-b', '--beta', type=float, default=0, help='weight balance for other losses') 112 | 113 | # KL distillation 114 | parser.add_argument('--kd_T', type=float, default=2, help='temperature for KD distillation') 115 | # NCE distillation 116 | parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension') 117 | parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE') 118 | parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax') 119 | parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates') 120 | 121 | # cosine annealing 122 |
parser.add_argument('--cosine', action='store_true', help='using cosine annealing') 123 | 124 | # specify folder 125 | parser.add_argument('--model_path', type=str, default='save/', help='path to save model') 126 | parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard') 127 | parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root') 128 | 129 | # setting for meta-learning 130 | parser.add_argument('--n_test_runs', type=int, default=600, metavar='N', 131 | help='Number of test runs') 132 | parser.add_argument('--n_ways', type=int, default=5, metavar='N', 133 | help='Number of classes for doing each classification run') 134 | parser.add_argument('--n_shots', type=int, default=1, metavar='N', 135 | help='Number of shots in test') 136 | parser.add_argument('--n_queries', type=int, default=15, metavar='N', 137 | help='Number of query in test') 138 | parser.add_argument('--n_aug_support_samples', default=5, type=int, 139 | help='The number of augmented samples for each meta test sample') 140 | parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size', 141 | help='Size of test batch)') 142 | 143 | opt = parser.parse_args() 144 | 145 | if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100': 146 | opt.transform = 'D' 147 | 148 | if 'trainval' in opt.path_t: 149 | opt.use_trainval = True 150 | else: 151 | opt.use_trainval = False 152 | 153 | if opt.use_trainval: 154 | opt.trial = opt.trial + '_trainval' 155 | 156 | # set the path according to the environment 157 | if not opt.model_path: 158 | opt.model_path = './models_distilled' 159 | if not opt.tb_path: 160 | opt.tb_path = './tensorboard' 161 | if not opt.data_root: 162 | opt.data_root = './data/{}'.format(opt.dataset) 163 | else: 164 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset) 165 | opt.data_aug = True 166 | 167 | tags = opt.tags.split(',') 168 | opt.tags = list([]) 
169 | for it in tags: 170 | opt.tags.append(it) 171 | 172 | iterations = opt.lr_decay_epochs.split(',') 173 | opt.lr_decay_epochs = list([]) 174 | for it in iterations: 175 | opt.lr_decay_epochs.append(int(it)) 176 | 177 | opt.model_name = 'S:{}_T:{}_{}_{}_r:{}_a:{}_b:{}_trans_{}'.format(opt.model_s, opt.model_t, opt.dataset, 178 | opt.distill, opt.gamma, opt.alpha, opt.beta, 179 | opt.transform) 180 | 181 | if opt.cosine: 182 | opt.model_name = '{}_cosine'.format(opt.model_name) 183 | 184 | opt.model_name = '{}_{}'.format(opt.model_name, opt.trial) 185 | 186 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 187 | if not os.path.isdir(opt.tb_folder): 188 | os.makedirs(opt.tb_folder) 189 | 190 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 191 | if not os.path.isdir(opt.save_folder): 192 | os.makedirs(opt.save_folder) 193 | 194 | #extras 195 | opt.fresh_start = True 196 | 197 | 198 | return opt 199 | 200 | 201 | 202 | 203 | 204 | def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'): 205 | """load the teacher model""" 206 | print('==> loading teacher model') 207 | print(model_name) 208 | model = create_model(model_name, n_cls, dataset) 209 | model.load_state_dict(torch.load(model_path)['model']) 210 | print('==> done') 211 | return model 212 | 213 | 214 | def main(): 215 | best_acc = 0 216 | 217 | opt = parse_option() 218 | wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags) 219 | wandb.config.update(opt) 220 | wandb.save('*.py') 221 | wandb.run.save() 222 | 223 | 224 | # dataloader 225 | train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt) 226 | 227 | # model 228 | model_t = [] 229 | if("," in opt.path_t): 230 | for path in opt.path_t.split(","): 231 | model_t.append(load_teacher(path, opt.model_t, n_cls, opt.dataset)) 232 | else: 233 | model_t.append(load_teacher(opt.path_t, opt.model_t, n_cls, opt.dataset)) 234 | 235 | # model_s = create_model(opt.model_s, n_cls, 
opt.dataset, dropout=0.4) 236 | # model_s = Wrapper(model_, opt) 237 | model_s = copy.deepcopy(model_t[0]) 238 | 239 | criterion_cls = nn.CrossEntropyLoss() 240 | criterion_div = DistillKL(opt.kd_T) 241 | criterion_kd = DistillKL(opt.kd_T) 242 | 243 | optimizer = optim.SGD(model_s.parameters(), 244 | lr=opt.learning_rate, 245 | momentum=opt.momentum, 246 | weight_decay=opt.weight_decay) 247 | # set cosine annealing scheduler (otherwise scheduler.step() below is undefined with --cosine) 248 | if opt.cosine: 249 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 250 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, eta_min, -1) 251 | if torch.cuda.is_available(): 252 | for m in model_t: 253 | m.cuda() 254 | model_s.cuda() 255 | criterion_cls = criterion_cls.cuda() 256 | criterion_div = criterion_div.cuda() 257 | criterion_kd = criterion_kd.cuda() 258 | cudnn.benchmark = True 259 | 260 | 261 | meta_test_acc = 0 262 | meta_test_std = 0 263 | # routine: supervised model distillation 264 | for epoch in range(1, opt.epochs + 1): 265 | 266 | if opt.cosine: 267 | scheduler.step() 268 | else: 269 | adjust_learning_rate(epoch, opt, optimizer) 270 | print("==> training...") 271 | 272 | time1 = time.time() 273 | train_acc, train_loss = train(epoch, train_loader, model_s, model_t , criterion_cls, criterion_div, criterion_kd, optimizer, opt) 274 | time2 = time.time() 275 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 276 | 277 | val_acc = 0 278 | val_loss = 0 279 | meta_val_acc = 0 280 | meta_val_std = 0 281 | # val_acc, val_acc_top5, val_loss = validate(val_loader, model_s, criterion_cls, opt) 282 | 283 | 284 | # #evaluate 285 | # start = time.time() 286 | # meta_val_acc, meta_val_std = meta_test(model_s, meta_valloader) 287 | # test_time = time.time() - start 288 | # print('Meta Val Acc: {:.4f}, Meta Val std: {:.4f}, Time: {:.1f}'.format(meta_val_acc, meta_val_std, test_time)) 289 | 290 | #evaluate 291 | 292 | start = time.time() 293 | meta_test_acc, meta_test_std = meta_test(model_s, meta_testloader, use_logit=False) 294 | test_time = time.time() - start 295 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.format(meta_test_acc, meta_test_std,
test_time)) 296 | 297 | 298 | # regular saving 299 | if epoch % opt.save_freq == 0 or epoch==opt.epochs: 300 | print('==> Saving...') 301 | state = { 302 | 'epoch': epoch, 303 | 'model': model_s.state_dict(), 304 | } 305 | save_file = os.path.join(opt.save_folder, 'model_'+str(wandb.run.name)+'.pth') 306 | torch.save(state, save_file) 307 | 308 | #wandb saving 309 | torch.save(state, os.path.join(wandb.run.dir, "model.pth")) 310 | 311 | wandb.log({'epoch': epoch, 312 | 'Train Acc': train_acc, 313 | 'Train Loss':train_loss, 314 | 'Val Acc': val_acc, 315 | 'Val Loss':val_loss, 316 | 'Meta Test Acc': meta_test_acc, 317 | 'Meta Test std': meta_test_std, 318 | 'Meta Val Acc': meta_val_acc, 319 | 'Meta Val std': meta_val_std 320 | }) 321 | 322 | #final report 323 | generate_final_report(model_s, opt, wandb) 324 | 325 | #remove output.txt log file 326 | output_log_file = os.path.join(wandb.run.dir, "output.log") 327 | if os.path.isfile(output_log_file): 328 | os.remove(output_log_file) 329 | else: ## Show an error ## 330 | print("Error: %s file not found" % output_log_file) 331 | 332 | 333 | 334 | 335 | def train(epoch, train_loader, model_s, model_t , criterion_cls, criterion_div, criterion_kd, optimizer, opt): 336 | """One epoch training""" 337 | model_s.train() 338 | for m in model_t: 339 | m.eval() 340 | 341 | batch_time = AverageMeter() 342 | data_time = AverageMeter() 343 | losses = AverageMeter() 344 | top1 = AverageMeter() 345 | top5 = AverageMeter() 346 | 347 | end = time.time() 348 | 349 | with tqdm(train_loader, total=len(train_loader)) as pbar: 350 | for idx, data in enumerate(pbar): 351 | 352 | inputs, targets, _ = data 353 | data_time.update(time.time() - end) 354 | 355 | inputs = inputs.float() 356 | if torch.cuda.is_available(): 357 | inputs = inputs.cuda() 358 | targets = targets.cuda() 359 | 360 | batch_size = inputs.size()[0] 361 | x = inputs 362 | 363 | x_90 = x.transpose(2,3).flip(2) 364 | # x_180 = x.flip(2).flip(3) 365 | # x_270 = 
x.flip(2).transpose(2,3) 366 | # inputs_aug = torch.cat((x_90, x_180, x_270),0) 367 | x_180 = x.flip(2).flip(3) 368 | x_270 = x.flip(2).transpose(2,3) 369 | # sampled_inputs = inputs_aug[torch.randperm(3*batch_size)[:batch_size]] 370 | inputs_all = torch.cat((x, x_180, x_90, x_270),0) 371 | 372 | # ===================forward===================== 373 | 374 | with torch.no_grad(): 375 | (_,_,_,_, feat_t), (logit_t, rot_t) = model_t[0](inputs_all[:batch_size], rot=True) 376 | 377 | (_,_,_,_, feat_s_all), (logit_s_all, rot_s_all) = model_s(inputs_all[:4*batch_size], rot=True) 378 | 379 | loss_div = criterion_div(logit_s_all[:batch_size], logit_t[:batch_size]) 380 | 381 | d_90 = logit_s_all[batch_size:2*batch_size] - logit_s_all[:batch_size] 382 | loss_a = torch.mean(torch.sqrt(torch.sum((d_90)**2, dim=1))) 383 | # d_180 = logit_s_all[2*batch_size:3*batch_size] - logit_s_all[:batch_size] 384 | # loss_a += torch.mean(torch.sqrt(torch.sum((d_180)**2, dim=1))) 385 | # d_270 = logit_s_all[3*batch_size:4*batch_size] - logit_s_all[:batch_size] 386 | # loss_a += torch.mean(torch.sqrt(torch.sum((d_270)**2, dim=1))) 387 | 388 | 389 | if(torch.isnan(loss_a).any()): 390 | break 391 | else: 392 | loss = loss_div + opt.gamma*loss_a / 3 393 | 394 | 395 | acc1, acc5 = accuracy(logit_s_all[:batch_size], targets, topk=(1, 5)) 396 | losses.update(loss.item(), inputs.size(0)) 397 | top1.update(acc1[0], inputs.size(0)) 398 | top5.update(acc5[0], inputs.size(0)) 399 | 400 | # ===================backward===================== 401 | optimizer.zero_grad() 402 | loss.backward() 403 | optimizer.step() 404 | 405 | # ===================meters===================== 406 | batch_time.update(time.time() - end) 407 | end = time.time() 408 | 409 | pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()), 410 | "Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy(),2), 411 | "Loss" :'{0:.2f}'.format(losses.avg,2), 412 | }) 413 | 414 | 415 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 416 | .format(top1=top1, top5=top5)) 417 | 418 | return top1.avg,
losses.avg 419 | 420 | 421 | def generate_final_report(model, opt, wandb): 422 | 423 | 424 | opt.n_shots = 1 425 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt) 426 | 427 | #validate 428 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True) 429 | 430 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False) 431 | 432 | #evaluate 433 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True) 434 | 435 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False) 436 | 437 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std)) 438 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat)) 439 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std)) 440 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat)) 441 | 442 | 443 | wandb.log({'Final Meta Test Acc @1': meta_test_acc, 444 | 'Final Meta Test std @1': meta_test_std, 445 | 'Final Meta Test Acc (feat) @1': meta_test_acc_feat, 446 | 'Final Meta Test std (feat) @1': meta_test_std_feat, 447 | 'Final Meta Val Acc @1': meta_val_acc, 448 | 'Final Meta Val std @1': meta_val_std, 449 | 'Final Meta Val Acc (feat) @1': meta_val_acc_feat, 450 | 'Final Meta Val std (feat) @1': meta_val_std_feat 451 | }) 452 | 453 | 454 | opt.n_shots = 5 455 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt) 456 | 457 | #validate 458 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True) 459 | 460 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False) 461 | 462 | #evaluate 463 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True) 464 | 465 | meta_test_acc_feat, meta_test_std_feat = 
meta_test(model, meta_testloader, use_logit=False) 466 | 467 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std)) 468 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat)) 469 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std)) 470 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat)) 471 | 472 | wandb.log({'Final Meta Test Acc @5': meta_test_acc, 473 | 'Final Meta Test std @5': meta_test_std, 474 | 'Final Meta Test Acc (feat) @5': meta_test_acc_feat, 475 | 'Final Meta Test std (feat) @5': meta_test_std_feat, 476 | 'Final Meta Val Acc @5': meta_val_acc, 477 | 'Final Meta Val std @5': meta_val_std, 478 | 'Final Meta Val Acc (feat) @5': meta_val_acc_feat, 479 | 'Final Meta Val std (feat) @5': meta_val_std_feat 480 | }) 481 | 482 | 483 | if __name__ == '__main__': 484 | main() 485 | -------------------------------------------------------------------------------- /train_selfsupervison.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import socket 6 | import time 7 | import sys 8 | from tqdm import tqdm 9 | import mkl 10 | 11 | # import tensorboard_logger as tb_logger 12 | import torch 13 | import torch.optim as optim 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | from torch.utils.data import DataLoader 17 | import torch.nn.functional as F 18 | from torch.autograd import Variable 19 | 20 | from models import model_pool 21 | from models.util import create_model 22 | 23 | from dataset.mini_imagenet import ImageNet, MetaImageNet 24 | from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet 25 | from dataset.cifar import CIFAR100, MetaCIFAR100 26 | from dataset.transform_cfg import transforms_options, 
transforms_test_options, transforms_list 27 | 28 | from util import adjust_learning_rate, accuracy, AverageMeter 29 | from eval.meta_eval import meta_test, meta_test_tune 30 | from eval.cls_eval import validate 31 | 32 | from models.resnet import resnet12 33 | import numpy as np 34 | from util import Logger 35 | import wandb 36 | from dataloader import get_dataloaders 37 | 38 | def get_freer_gpu(): 39 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 40 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 41 | return np.argmax(memory_available) 42 | 43 | os.environ["CUDA_VISIBLE_DEVICES"]=str(get_freer_gpu()) 44 | mkl.set_num_threads(2) 45 | 46 | 47 | def parse_option(): 48 | 49 | parser = argparse.ArgumentParser('argument for training') 50 | 51 | parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency') 52 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 53 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 54 | parser.add_argument('--save_freq', type=int, default=10, help='save frequency') 55 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 56 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use') 57 | parser.add_argument('--epochs', type=int, default=100, help='number of training epochs') 58 | 59 | # optimization 60 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate') 61 | parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list') 62 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 63 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 64 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 65 | parser.add_argument('--adam', action='store_true', help='use adam 
optimizer') 66 | parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation') 67 | parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning') 68 | parser.add_argument('--tags', type=str, default="gen0, ssl", help='add tags for the experiment') 69 | 70 | 71 | # dataset 72 | parser.add_argument('--model', type=str, default='resnet12', choices=model_pool) 73 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet', 74 | 'CIFAR-FS', 'FC100']) 75 | parser.add_argument('--transform', type=str, default='A', choices=transforms_list) 76 | parser.add_argument('--use_trainval', type=bool, help='use trainval set') 77 | 78 | # cosine annealing 79 | parser.add_argument('--cosine', action='store_true', help='using cosine annealing') 80 | 81 | # specify folder 82 | parser.add_argument('--model_path', type=str, default='save/', help='path to save model') 83 | parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard') 84 | parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root') 85 | 86 | # meta setting 87 | parser.add_argument('--n_test_runs', type=int, default=600, metavar='N', 88 | help='Number of test runs') 89 | parser.add_argument('--n_ways', type=int, default=5, metavar='N', 90 | help='Number of classes for doing each classification run') 91 | parser.add_argument('--n_shots', type=int, default=1, metavar='N', 92 | help='Number of shots in test') 93 | parser.add_argument('--n_queries', type=int, default=15, metavar='N', 94 | help='Number of query in test') 95 | parser.add_argument('--n_aug_support_samples', default=5, type=int, 96 | help='The number of augmented samples for each meta test sample') 97 | parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size', 98 | help='Size of test batch)') 99 | 100 | 
parser.add_argument('-t', '--trial', type=str, default='1', help='the experiment id') 101 | 102 | 103 | 104 | #hyper parameters 105 | parser.add_argument('--gamma', type=float, default=2, help='loss coefficient for ssl loss') 106 | 107 | opt = parser.parse_args() 108 | 109 | if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100': 110 | opt.transform = 'D' 111 | 112 | if opt.use_trainval: 113 | opt.trial = opt.trial + '_trainval' 114 | 115 | # set the path according to the environment 116 | if not opt.model_path: 117 | opt.model_path = './models_pretrained' 118 | if not opt.tb_path: 119 | opt.tb_path = './tensorboard' 120 | if not opt.data_root: 121 | opt.data_root = './data/{}'.format(opt.dataset) 122 | else: 123 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset) 124 | opt.data_aug = True 125 | 126 | iterations = opt.lr_decay_epochs.split(',') 127 | opt.lr_decay_epochs = list([]) 128 | for it in iterations: 129 | opt.lr_decay_epochs.append(int(it)) 130 | 131 | tags = opt.tags.split(',') 132 | opt.tags = list([]) 133 | for it in tags: 134 | opt.tags.append(it) 135 | 136 | opt.model_name = '{}_{}_lr_{}_decay_{}_trans_{}'.format(opt.model, opt.dataset, opt.learning_rate, 137 | opt.weight_decay, opt.transform) 138 | 139 | if opt.cosine: 140 | opt.model_name = '{}_cosine'.format(opt.model_name) 141 | 142 | if opt.adam: 143 | opt.model_name = '{}_useAdam'.format(opt.model_name) 144 | 145 | opt.model_name = '{}_trial_{}'.format(opt.model_name, opt.trial) 146 | 147 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 148 | if not os.path.isdir(opt.tb_folder): 149 | os.makedirs(opt.tb_folder) 150 | 151 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 152 | if not os.path.isdir(opt.save_folder): 153 | os.makedirs(opt.save_folder) 154 | 155 | opt.n_gpu = torch.cuda.device_count() 156 | 157 | 158 | #extras 159 | opt.fresh_start = True 160 | 161 | 162 | return opt 163 | 164 | 165 | def main(): 166 | 167 | opt = parse_option() 168 |
wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags) 169 | wandb.config.update(opt) 170 | wandb.save('*.py') 171 | wandb.run.save() 172 | 173 | 174 | train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt) 175 | 176 | # model 177 | model = create_model(opt.model, n_cls, opt.dataset) 178 | wandb.watch(model) 179 | 180 | # optimizer 181 | if opt.adam: 182 | print("Adam") 183 | optimizer = torch.optim.Adam(model.parameters(), 184 | lr=opt.learning_rate, 185 | weight_decay=0.0005) 186 | else: 187 | print("SGD") 188 | optimizer = optim.SGD(model.parameters(), 189 | lr=opt.learning_rate, 190 | momentum=opt.momentum, 191 | weight_decay=opt.weight_decay) 192 | 193 | 194 | 195 | criterion = nn.CrossEntropyLoss() 196 | 197 | if torch.cuda.is_available(): 198 | if opt.n_gpu > 1: 199 | model = nn.DataParallel(model) 200 | model = model.cuda() 201 | criterion = criterion.cuda() 202 | cudnn.benchmark = True 203 | 204 | # set cosine annealing scheduler 205 | if opt.cosine: 206 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 207 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, eta_min, -1) 208 | 209 | # routine: supervised pre-training 210 | for epoch in range(1, opt.epochs + 1): 211 | if opt.cosine: 212 | scheduler.step() 213 | else: 214 | adjust_learning_rate(epoch, opt, optimizer) 215 | print("==> training...") 216 | 217 | 218 | time1 = time.time() 219 | train_acc, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt) 220 | time2 = time.time() 221 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 222 | 223 | 224 | val_acc, val_acc_top5, val_loss = 0,0,0 #validate(val_loader, model, criterion, opt) 225 | 226 | 227 | #validate 228 | start = time.time() 229 | meta_val_acc, meta_val_std = 0,0#meta_test(model, meta_valloader) 230 | test_time = time.time() - start 231 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}, Time: {:.1f}'.format(meta_val_acc, 
meta_val_std, test_time)) 232 | 233 | #evaluate 234 | start = time.time() 235 | meta_test_acc, meta_test_std = 0,0#meta_test(model, meta_testloader) 236 | test_time = time.time() - start 237 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.format(meta_test_acc, meta_test_std, test_time)) 238 | 239 | 240 | # regular saving 241 | if epoch % opt.save_freq == 0 or epoch==opt.epochs: 242 | print('==> Saving...') 243 | state = { 244 | 'epoch': epoch, 245 | 'optimizer': optimizer.state_dict(), 246 | 'model': model.state_dict(), 247 | } 248 | save_file = os.path.join(opt.save_folder, 'model_'+str(wandb.run.name)+'.pth') 249 | torch.save(state, save_file) 250 | 251 | #wandb saving 252 | torch.save(state, os.path.join(wandb.run.dir, "model.pth")) 253 | 254 | ## onnx saving 255 | #dummy_input = torch.autograd.Variable(torch.randn(1, 3, 32, 32)).cuda() 256 | #torch.onnx.export(model, dummy_input, os.path.join(wandb.run.dir, "model.onnx")) 257 | 258 | wandb.log({'epoch': epoch, 259 | 'Train Acc': train_acc, 260 | 'Train Loss':train_loss, 261 | 'Val Acc': val_acc, 262 | 'Val Loss':val_loss, 263 | 'Meta Test Acc': meta_test_acc, 264 | 'Meta Test std': meta_test_std, 265 | 'Meta Val Acc': meta_val_acc, 266 | 'Meta Val std': meta_val_std 267 | }) 268 | 269 | #final report 270 | generate_final_report(model, opt, wandb) 271 | 272 | #remove output.txt log file 273 | output_log_file = os.path.join(wandb.run.dir, "output.log") 274 | if os.path.isfile(output_log_file): 275 | os.remove(output_log_file) 276 | else: ## Show an error ## 277 | print("Error: %s file not found" % output_log_file) 278 | 279 | 280 | 281 | def train(epoch, train_loader, model, criterion, optimizer, opt): 282 | """One epoch training""" 283 | model.train() 284 | 285 | batch_time = AverageMeter() 286 | data_time = AverageMeter() 287 | losses = AverageMeter() 288 | top1 = AverageMeter() 289 | top5 = AverageMeter() 290 | 291 | end = time.time() 292 | with tqdm(train_loader, total=len(train_loader)) 
as pbar: 293 | for idx, (input, target, _) in enumerate(pbar): 294 | data_time.update(time.time() - end) 295 | 296 | input = input.float() 297 | if torch.cuda.is_available(): 298 | input = input.cuda() 299 | target = target.cuda() 300 | 301 | 302 | batch_size = input.size()[0] 303 | x = input 304 | x_90 = x.transpose(2,3).flip(2) 305 | x_180 = x.flip(2).flip(3) 306 | x_270 = x.flip(2).transpose(2,3) 307 | generated_data = torch.cat((x, x_90, x_180, x_270),0) 308 | train_targets = target.repeat(4) 309 | 310 | rot_labels = torch.zeros(4*batch_size).cuda().long() 311 | for i in range(4*batch_size): 312 | if i < batch_size: 313 | rot_labels[i] = 0 314 | elif i < 2*batch_size: 315 | rot_labels[i] = 1 316 | elif i < 3*batch_size: 317 | rot_labels[i] = 2 318 | else: 319 | rot_labels[i] = 3 320 | 321 | # ===================forward===================== 322 | 323 | (_,_,_,_, feat), (train_logit, rot_logits) = model(generated_data, rot=True) 324 | 325 | rot_labels = F.one_hot(rot_labels.to(torch.int64), 4).float() 326 | loss_ss = torch.sum(F.binary_cross_entropy_with_logits(input = rot_logits, target = rot_labels)) 327 | loss_ce = criterion(train_logit, train_targets) 328 | 329 | loss = opt.gamma * loss_ss + loss_ce 330 | 331 | acc1, acc5 = accuracy(train_logit, train_targets, topk=(1, 5)) 332 | losses.update(loss.item(), input.size(0)) 333 | top1.update(acc1[0], input.size(0)) 334 | top5.update(acc5[0], input.size(0)) 335 | 336 | # ===================backward===================== 337 | optimizer.zero_grad() 338 | loss.backward() 339 | optimizer.step() 340 | 341 | # ===================meters===================== 342 | batch_time.update(time.time() - end) 343 | end = time.time() 344 | 345 | 346 | pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()), 347 | "Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy(),2), 348 | "Loss" :'{0:.2f}'.format(losses.avg,2), 349 | }) 350 | 351 | print('Train_Acc@1 {top1.avg:.3f} Train_Acc@5 {top5.avg:.3f}' 352 | .format(top1=top1, 
top5=top5)) 353 | 354 | return top1.avg, losses.avg 355 | 356 | 357 | 358 | def generate_final_report(model, opt, wandb): 359 | 360 | 361 | opt.n_shots = 1 362 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt) 363 | 364 | #validate 365 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True) 366 | 367 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False) 368 | 369 | #evaluate 370 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True) 371 | 372 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False) 373 | 374 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std)) 375 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat)) 376 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std)) 377 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat)) 378 | 379 | 380 | wandb.log({'Final Meta Test Acc @1': meta_test_acc, 381 | 'Final Meta Test std @1': meta_test_std, 382 | 'Final Meta Test Acc (feat) @1': meta_test_acc_feat, 383 | 'Final Meta Test std (feat) @1': meta_test_std_feat, 384 | 'Final Meta Val Acc @1': meta_val_acc, 385 | 'Final Meta Val std @1': meta_val_std, 386 | 'Final Meta Val Acc (feat) @1': meta_val_acc_feat, 387 | 'Final Meta Val std (feat) @1': meta_val_std_feat 388 | }) 389 | 390 | 391 | opt.n_shots = 5 392 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt) 393 | 394 | #validate 395 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True) 396 | 397 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False) 398 | 399 | #evaluate 400 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True) 401 | 402 | 
meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False) 403 | 404 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std)) 405 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat)) 406 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std)) 407 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat)) 408 | 409 | wandb.log({'Final Meta Test Acc @5': meta_test_acc, 410 | 'Final Meta Test std @5': meta_test_std, 411 | 'Final Meta Test Acc (feat) @5': meta_test_acc_feat, 412 | 'Final Meta Test std (feat) @5': meta_test_std_feat, 413 | 'Final Meta Val Acc @5': meta_val_acc, 414 | 'Final Meta Val std @5': meta_val_std, 415 | 'Final Meta Val Acc (feat) @5': meta_val_acc_feat, 416 | 'Final Meta Val std (feat) @5': meta_val_std_feat 417 | }) 418 | 419 | 420 | if __name__ == '__main__': 421 | main() 422 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import os 7 | import sys 8 | from dataloader import get_dataloaders 9 | 10 | 11 | class LabelSmoothing(nn.Module): 12 | """ 13 | NLL loss with label smoothing. 14 | """ 15 | def __init__(self, smoothing=0.0): 16 | """ 17 | Constructor for the LabelSmoothing module. 
18 | :param smoothing: label smoothing factor 19 | """ 20 | super(LabelSmoothing, self).__init__() 21 | self.confidence = 1.0 - smoothing 22 | self.smoothing = smoothing 23 | 24 | def forward(self, x, target): 25 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 26 | 27 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 28 | nll_loss = nll_loss.squeeze(1) 29 | smooth_loss = -logprobs.mean(dim=-1) 30 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 31 | return loss.mean() 32 | 33 | 34 | class BCEWithLogitsLoss(nn.Module): 35 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None, num_classes=64): 36 | super(BCEWithLogitsLoss, self).__init__() 37 | self.num_classes = num_classes 38 | self.criterion = nn.BCEWithLogitsLoss(weight=weight, 39 | size_average=size_average, 40 | reduce=reduce, 41 | reduction=reduction, 42 | pos_weight=pos_weight) 43 | def forward(self, input, target): 44 | target_onehot = torch.nn.functional.one_hot(target, num_classes=self.num_classes).float() # BCE targets must be float 45 | return self.criterion(input, target_onehot) 46 | 47 | 48 | class AverageMeter(object): 49 | """Computes and stores the average and current value""" 50 | def __init__(self): 51 | self.reset() 52 | 53 | def reset(self): 54 | self.val = 0 55 | self.avg = 0 56 | self.sum = 0 57 | self.count = 0 58 | 59 | def update(self, val, n=1): 60 | self.val = val 61 | self.sum += val * n 62 | self.count += n 63 | self.avg = self.sum / self.count 64 | 65 | 66 | def adjust_learning_rate(epoch, opt, optimizer): 67 | """Decays the initial LR by lr_decay_rate once for every epoch in lr_decay_epochs that the current epoch has passed""" 68 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 69 | if steps > 0: 70 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps) 71 | for param_group in optimizer.param_groups: 72 | param_group['lr'] = new_lr 73 | 74 | 75 | def accuracy(output, target, topk=(1,)): 76 | """Computes the accuracy over the k top predictions for the specified values of
k""" 77 | with torch.no_grad(): 78 | maxk = max(topk) 79 | batch_size = target.size(0) 80 | 81 | _, pred = output.topk(maxk, 1, True, True) 82 | pred = pred.t() 83 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 84 | 85 | res = [] 86 | for k in topk: 87 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: correct[:k] of a transposed tensor is non-contiguous 88 | res.append(correct_k.mul_(100.0 / batch_size)) 89 | return res 90 | 91 | 92 | 93 | 94 | class Logger(object): 95 | '''Save training process to log file with simple plot function.''' 96 | def __init__(self, fpath, title=None, resume=False): 97 | self.file = None 98 | self.resume = resume 99 | self.title = '' if title is None else title 100 | if fpath is not None: 101 | if resume: 102 | self.file = open(fpath, 'r') 103 | name = self.file.readline() 104 | self.names = name.rstrip().split('\t') 105 | self.numbers = {} 106 | for _, name in enumerate(self.names): 107 | self.numbers[name] = [] 108 | 109 | for numbers in self.file: 110 | numbers = numbers.rstrip().split('\t') 111 | for i in range(0, len(numbers)): 112 | self.numbers[self.names[i]].append(numbers[i]) 113 | self.file.close() 114 | self.file = open(fpath, 'a') 115 | else: 116 | self.file = open(fpath, 'w') 117 | 118 | def set_names(self, names): 119 | if self.resume: 120 | pass 121 | # initialize numbers as empty list 122 | self.numbers = {} 123 | self.names = names 124 | for _, name in enumerate(self.names): 125 | self.file.write(name) 126 | self.file.write('\t') 127 | self.numbers[name] = [] 128 | self.file.write('\n') 129 | self.file.flush() 130 | 131 | 132 | def append(self, numbers): 133 | assert len(self.names) == len(numbers), 'Numbers do not match names' 134 | for index, num in enumerate(numbers): 135 | self.file.write("{0:.6f}".format(num)) 136 | self.file.write('\t') 137 | self.numbers[self.names[index]].append(num) 138 | self.file.write('\n') 139 | self.file.flush() 140 | 141 | def plot(self, names=None): 142 | names = self.names if names is None else names 143 | numbers = 
self.numbers 144 | for _, name in enumerate(names): 145 | x = np.arange(len(numbers[name])) 146 | plt.plot(x, np.asarray(numbers[name])) 147 | plt.legend([self.title + '(' + name + ')' for name in names]) 148 | plt.grid(True) 149 | 150 | 151 | def close(self): 152 | if self.file is not None: 153 | self.file.close() 154 | 155 | 156 | 157 | 158 | 159 | 160 | def generate_final_report(model, opt, wandb): 161 | from eval.meta_eval import meta_test 162 | 163 | opt.n_shots = 1 164 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt) 165 | 166 | #validate 167 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader) 168 | 169 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False) 170 | 171 | #evaluate 172 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader) 173 | 174 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False) 175 | 176 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std)) 177 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat)) 178 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std)) 179 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat)) 180 | 181 | 182 | wandb.log({'Final Meta Test Acc @1': meta_test_acc, 183 | 'Final Meta Test std @1': meta_test_std, 184 | 'Final Meta Test Acc (feat) @1': meta_test_acc_feat, 185 | 'Final Meta Test std (feat) @1': meta_test_std_feat, 186 | 'Final Meta Val Acc @1': meta_val_acc, 187 | 'Final Meta Val std @1': meta_val_std, 188 | 'Final Meta Val Acc (feat) @1': meta_val_acc_feat, 189 | 'Final Meta Val std (feat) @1': meta_val_std_feat 190 | }) 191 | 192 | 193 | opt.n_shots = 5 194 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt) 195 | 196 | #validate 197 | 
meta_val_acc, meta_val_std = meta_test(model, meta_valloader) 198 | 199 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False) 200 | 201 | #evaluate 202 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader) 203 | 204 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False) 205 | 206 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std)) 207 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat)) 208 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std)) 209 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat)) 210 | 211 | wandb.log({'Final Meta Test Acc @5': meta_test_acc, 212 | 'Final Meta Test std @5': meta_test_std, 213 | 'Final Meta Test Acc (feat) @5': meta_test_acc_feat, 214 | 'Final Meta Test std (feat) @5': meta_test_std_feat, 215 | 'Final Meta Val Acc @5': meta_val_acc, 216 | 'Final Meta Val std @5': meta_val_std, 217 | 'Final Meta Val Acc (feat) @5': meta_val_acc_feat, 218 | 'Final Meta Val std (feat) @5': meta_val_std_feat 219 | }) -------------------------------------------------------------------------------- /utils/figs/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/main.png -------------------------------------------------------------------------------- /utils/figs/results1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/results1.png -------------------------------------------------------------------------------- /utils/figs/results2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/results2.png -------------------------------------------------------------------------------- /utils/figs/training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/training.png --------------------------------------------------------------------------------
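`LabelSmoothing.forward` mixes the negative log-likelihood of the true class (weight `1 - smoothing`) with the mean negative log-probability over all K classes (weight `smoothing`). Because the second term is the cross-entropy against a uniform distribution, this is the same as training against the target distribution (1 - ε)·onehot(y) + ε/K. The following is a minimal plain-Python sketch of that computation, useful for checking the formula without PyTorch; the helper names `log_softmax_row` and `label_smoothing_loss` are hypothetical and not part of the repository.

```python
import math
from typing import List, Sequence


def log_softmax_row(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def label_smoothing_loss(batch_logits: Sequence[Sequence[float]],
                         targets: Sequence[int],
                         smoothing: float = 0.1) -> float:
    """Plain-Python mirror of LabelSmoothing.forward (hypothetical helper)."""
    confidence = 1.0 - smoothing
    total = 0.0
    for logits, y in zip(batch_logits, targets):
        logprobs = log_softmax_row(logits)
        nll = -logprobs[y]                        # -log p(y | x), the usual cross-entropy term
        smooth = -sum(logprobs) / len(logprobs)   # mean of -log p(k | x) over all K classes
        total += confidence * nll + smoothing * smooth
    return total / len(targets)                   # batch mean, like loss.mean()


# With uniform logits every class has p = 1/K, so both terms equal log K.
print(label_smoothing_loss([[0.0, 0.0, 0.0]], [1], smoothing=0.3))  # approximately math.log(3)
```

Setting `smoothing=0.0` recovers plain cross-entropy, which is a quick sanity check when tuning the factor.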
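The `BCEWithLogitsLoss` wrapper turns integer labels into one-hot vectors and scores each of the `num_classes` logits as an independent binary problem. With `reduction='mean'`, PyTorch averages over all N × C elements rather than over the N samples, so the loss scale shrinks as `num_classes` grows. The sketch below reproduces that computation in plain Python using the numerically stable form max(x, 0) - x·t + log(1 + exp(-|x|)); `bce_with_logits_onehot` is a hypothetical name used only for illustration.

```python
import math
from typing import Sequence


def bce_with_logits_onehot(batch_logits: Sequence[Sequence[float]],
                           targets: Sequence[int],
                           num_classes: int) -> float:
    """Mean binary cross-entropy against one-hot targets (hypothetical helper).

    Mirrors BCEWithLogitsLoss(reduction='mean') applied to one-hot targets:
    the sum runs over every (sample, class) pair and is divided by N * C.
    """
    total = 0.0
    for logits, y in zip(batch_logits, targets):
        assert len(logits) == num_classes
        for k, x in enumerate(logits):
            t = 1.0 if k == y else 0.0  # one-hot target entry, as a float
            # Stable form of log(1 + e^x) - t * x.
            total += max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
    return total / (len(targets) * num_classes)


# All-zero logits give sigmoid(0) = 0.5 for every entry, so the loss is log 2
# regardless of the labels.
print(bce_with_logits_onehot([[0.0] * 4, [0.0] * 4], [1, 3], num_classes=4))
```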
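`adjust_learning_rate` counts how many entries of `opt.lr_decay_epochs` are strictly smaller than the current epoch and multiplies `opt.learning_rate` by `opt.lr_decay_rate` that many times. Because the comparison is strict, a decay listed at epoch e first takes effect at epoch e + 1. A dependency-free sketch of the same schedule follows; the function name and the numbers in the example (base LR 0.05, decays at epochs 60 and 80, rate 0.1) are illustrative assumptions, not the repository's defaults.

```python
from typing import Sequence


def step_decay_lr(epoch: int, base_lr: float,
                  decay_epochs: Sequence[int], decay_rate: float) -> float:
    """Learning rate under the same rule as adjust_learning_rate (hypothetical helper)."""
    steps = sum(1 for e in decay_epochs if epoch > e)  # strict '>' as in the original
    return base_lr * (decay_rate ** steps)


for epoch in (1, 60, 61, 81):
    print(epoch, step_decay_lr(epoch, base_lr=0.05, decay_epochs=[60, 80], decay_rate=0.1))
```

The original only writes to `optimizer.param_groups` once `steps > 0`, so before the first decay the optimizer keeps whatever LR it was built with; the sketch returns `base_lr` in that case, which matches as long as the optimizer was created with `opt.learning_rate`.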
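`accuracy` takes the `maxk` highest-scoring classes per sample with `topk`, compares them with the target, and for each k counts a hit whenever the target appears among the first k predictions, reporting the result as a percentage of the batch. A plain-Python sketch of the same counting follows; `topk_accuracy` is a hypothetical name, and it breaks ties toward the lower class index (Python's sort is stable), which may differ from how `torch.topk` orders equal scores.

```python
from typing import List, Sequence, Tuple


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int],
                  topk: Tuple[int, ...] = (1,)) -> List[float]:
    """Percentage of samples whose target is among the k highest scores, for each k."""
    maxk = max(topk)
    hits = {k: 0 for k in topk}
    for row, y in zip(scores, targets):
        # Class indices ordered by descending score; the stable sort keeps the lower index first on ties.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:maxk]
        for k in topk:
            if y in ranked[:k]:
                hits[k] += 1
    return [100.0 * hits[k] / len(targets) for k in topk]


scores = [[0.1, 0.7, 0.2],   # predicted order: 1, 2, 0
          [0.5, 0.3, 0.2]]   # predicted order: 0, 1, 2
print(topk_accuracy(scores, targets=[2, 0], topk=(1, 2)))  # [50.0, 100.0]
```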
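`generate_final_report` repeats the same block for 1-shot and 5-shot episodes. Each block rebuilds the loaders with `opt.n_shots` set, runs `meta_test` on the validation and test episodes both with logits and with features (`use_logit=False`), prints the results, and logs eight metrics to `wandb`. One way to remove the duplication is to loop over shot counts, splits and modes. The sketch below is a hypothetical refactor, not the repository's code. It takes `get_dataloaders` and `meta_test` as arguments so it can be exercised with stubs, and it returns a dictionary with the same metric keys, leaving printing and `wandb.log` to the caller.

```python
from types import SimpleNamespace
from typing import Callable, Dict, Iterable


def collect_final_metrics(model, opt, get_dataloaders: Callable, meta_test: Callable,
                          shots: Iterable[int] = (1, 5)) -> Dict[str, float]:
    """Hypothetical loop-based equivalent of the duplicated blocks in generate_final_report."""
    metrics: Dict[str, float] = {}
    for n in shots:
        opt.n_shots = n  # mutates opt, exactly like the original
        # Same 5-tuple unpacking as the original; the train/val loaders are not needed here.
        _, _, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
        for split, loader in (('Val', meta_valloader), ('Test', meta_testloader)):
            # '' -> meta_test's default (logits); ' (feat)' -> use_logit=False, as in the original.
            for suffix, kwargs in (('', {}), (' (feat)', {'use_logit': False})):
                acc, std = meta_test(model, loader, **kwargs)
                metrics['Final Meta {} Acc{} @{}'.format(split, suffix, n)] = acc
                metrics['Final Meta {} std{} @{}'.format(split, suffix, n)] = std
    return metrics


# Stub usage: fake loaders and a fake meta_test, only to show which keys are produced.
def fake_loaders(opt):
    return None, None, 'test-{}'.format(opt.n_shots), 'val-{}'.format(opt.n_shots), None


def fake_meta_test(model, loader, use_logit=True):
    return (0.5 if use_logit else 0.6), 0.01


metrics = collect_final_metrics(None, SimpleNamespace(n_shots=0), fake_loaders, fake_meta_test)
print(len(metrics))  # 16 entries: 2 shots x 2 splits x 2 modes x (acc, std)
```

In the real script, `wandb.log(metrics)` would then log all sixteen keys in one call instead of two calls of eight.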