├── .gitignore
├── README.md
├── dataloader.py
├── dataset
│   ├── cifar.py
│   ├── dataset_selfsupervision.py
│   ├── mini_imagenet.py
│   ├── tiered_imagenet.py
│   └── transform_cfg.py
├── distill
│   ├── NCEAverage.py
│   ├── NCECriterion.py
│   ├── __init__.py
│   ├── alias_multinomial.py
│   ├── criterion.py
│   └── util.py
├── eval
│   ├── __init__.py
│   ├── cls_eval.py
│   ├── meta_eval.py
│   └── util.py
├── eval_fewshot.py
├── models
│   ├── __init__.py
│   ├── convnet.py
│   ├── resnet.py
│   ├── resnet_new.py
│   ├── resnet_sd.py
│   ├── resnet_selfdist.py
│   ├── resnet_ssl.py
│   ├── util.py
│   └── wresnet.py
├── requirements.txt
├── run.sh
├── train_distillation.py
├── train_selfsupervison.py
├── util.py
└── utils
    └── figs
        ├── main.png
        ├── results1.png
        ├── results2.png
        └── training.png
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | output*/
4 | ckpts/
5 | *.pth
6 | *.t7
7 | *.jpg
8 | tmp*.py
9 | # run*.sh
10 | *.pdf
11 | *.npy
12 | models_distilled/
13 | tensorboard/
14 | models_pretrained/
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 |
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 |
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 |
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SKD : Self-supervised Knowledge Distillation for Few-shot Learning
2 | Official implementation of "SKD : Self-supervised Knowledge Distillation for Few-shot Learning". [(paper link)](https://arxiv.org/abs/2006.09785). The paper reports state-of-the-art results on four popular few-shot learning benchmarks.
3 |
4 | The real world contains an overwhelmingly large number of object classes, and learning all of them at once is impossible. Few-shot learning is a promising learning paradigm due to its ability to quickly learn out-of-distribution classes from only a few samples. Recent works show that simply learning a good feature embedding can outperform more sophisticated meta-learning and metric-learning algorithms. In this paper, we propose a simple approach to improve the representation capacity of deep neural networks for few-shot learning tasks. We follow a two-stage learning process: first, we train a neural network to maximize the entropy of the feature embedding, thus creating an optimal output manifold, using self-supervision as an auxiliary loss. In the second stage, we minimize the entropy of the feature embedding by bringing self-supervised twins together, while constraining the manifold with student-teacher distillation. Our experiments show that, even in the first stage, auxiliary self-supervision can outperform current state-of-the-art methods, with further gains achieved by our second-stage distillation process.
5 |
6 | This repository provides the official implementation of SKD. It is implemented in PyTorch and includes code for running the few-shot learning experiments on the **CIFAR-FS**, **FC100**, **miniImageNet** and **tieredImageNet** datasets.
7 |
8 | ![](figs/main.png)
9 | (a) SKD has two-stage learning. In Gen-0, self-supervision is used to estimate the true prediction manifold, equivariant to input transformations. Specifically, we enforce the model to predict the amount of input rotation using only the output logits. In Gen-1, we force the original sample outputs to be the same as in Gen-0 (dotted lines), while reducing their distance to their augmented versions to enhance discriminability.
10 |
11 | ![](figs/training.png)
12 | (b) SKD training pipeline.
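The rotation pretext task described above, predicting the amount of input rotation, can be sketched as follows. This is a minimal NumPy illustration mirroring the label generation in `dataset/dataset_selfsupervision.py`, not the repository's exact PyTorch code; `rotate_chw` and `rotation_pretext_sample` are hypothetical helper names.

```python
import numpy as np

def rotate_chw(x, label):
    """Rotate a (C, H, W) array by label * 90 degrees counter-clockwise."""
    # np.rot90 rotates in the (H, W) plane when axes=(1, 2)
    return np.rot90(x, k=label, axes=(1, 2)).copy()

def rotation_pretext_sample(image, rng):
    """Return the (original, rotated) pair and the 4-way rotation label."""
    label = int(rng.integers(4))  # 0deg, 90deg, 180deg, or 270deg
    return (image, rotate_chw(image, label)), label

rng = np.random.default_rng(0)
img = rng.random((3, 32, 32))
(orig, rot), label = rotation_pretext_sample(img, rng)
```

During Gen-0 training, the rotation label is supervised with a cross-entropy loss computed from the output logits, alongside the standard classification loss.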
13 |
14 | ## Dependencies
15 | This code requires the following:
16 | * matplotlib==3.2.1
17 | * mkl==2019.0
18 | * numpy==1.18.4
19 | * Pillow==7.1.2
20 | * scikit_learn==0.23.1
21 | * scipy==1.4.1
22 | * torch==1.5.0
23 | * torchvision==0.6.0
24 | * tqdm==4.46.0
25 | * wandb==0.8.36
26 |
27 | Run `pip3 install -r requirements.txt` to install all the dependencies.
28 |
29 | ## Download Data
30 | The data used here was preprocessed by the [MetaOptNet](https://github.com/kjunelee/MetaOptNet) repository. Renamed versions of the files, provided by [RFS](https://github.com/WangYueFt/rfs), are available at the link below.
31 |
32 | [[DropBox Data Packages Link]](https://www.dropbox.com/sh/6yd1ygtyc3yd981/AABVeEqzC08YQv4UZk7lNHvya?dl=0)
33 |
34 | ## Training
35 |
36 | ### Generation Zero
37 | To perform the Generation Zero experiment, run:
38 |
39 | `python3 train_supervised_ssl.py --tags cifarfs,may30 --model resnet12_ssl --model_path save/backup --dataset CIFAR-FS --data_root ../../Datasets/CIFAR_FS/ --n_aug_support_samples 5 --n_ways 5 --n_shots 1 --epochs 65 --lr_decay_epochs 60 --gamma 2.0`
40 |
41 | WandB creates a unique name for each run and saves the model checkpoint accordingly. Use this name as the teacher in the next experiment.
42 |
43 |
44 | ### Generation One
45 | To perform the Generation One experiment, run:
46 |
47 | `python3 train_distillation.py --tags cifarfs,gen1,may30 --model_s resnet12_ssl --model_t resnet12_ssl --path_t save/backup/resnet12_ssl_CIFAR-FS_lr_0.05_decay_0.0005_trans_D_trial_1/model_firm-sun-1.pth --model_path save/backup --dataset CIFAR-FS --data_root ../../Datasets/CIFAR_FS/ --n_aug_support_samples 5 --n_ways 5 --n_shots 1 --epochs 65 --lr_decay_epochs 60 --gamma 0.1`
48 |
49 |
50 | ### Evaluation
51 |
52 | `python3 eval_fewshot.py --model resnet12_ssl --model_path save/backup2/resnet12_ssl_toy_lr_0.05_decay_0.0005_trans_A_trial_1/model_firm-sun-1.pth --dataset toy --data_root ../../Datasets/CIFAR_FS/ --n_aug_support_samples 5 --n_ways 5 --n_shots 1`
53 |
54 |
55 | ## Results
56 |
57 | We perform extensive experiments on four datasets in a few-shot learning setting, leading to significant improvements over state-of-the-art methods.
58 |
59 | ![](figs/results1.png)
60 | (c) SKD performance on miniImageNet and tieredImageNet.
61 |
62 |
63 | ![](figs/results2.png)
64 | (d) SKD performance on CIFAR-FS and FC100 datasets.
65 |
66 |
67 | ## Acknowledgements
68 | Thanks to [RFS](https://github.com/WangYueFt/rfs) for the preliminary implementations.
69 |
70 | ## Contact
71 | Jathushan Rajasegaran - jathushan.rajasegaran@inceptioniai.org or brjathu@gmail.com
72 |
73 | To ask questions or report issues, please open an issue on the [issues tracker](https://github.com/brjathu/SKD/issues).
74 |
75 | Discussions, suggestions and questions are welcome!
76 |
77 |
78 | ## Citation
79 | ```
80 | @article{rajasegaran2020self,
81 | title={Self-supervised Knowledge Distillation for Few-shot Learning},
82 | author={Rajasegaran, Jathushan and Khan, Salman and Hayat, Munawar and Khan, Fahad Shahbaz and Shah, Mubarak},
83 | journal={arXiv preprint arXiv:2006.09785},
84 | year = {2020}
85 | }
86 | ```
87 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import socket
6 | import time
7 | import sys
8 | from tqdm import tqdm
9 |
10 | import torch
11 | import torch.optim as optim
12 | import torch.nn as nn
13 | import torch.backends.cudnn as cudnn
14 | from torch.utils.data import DataLoader
15 | import torch.nn.functional as F
16 |
17 | from dataset.mini_imagenet import ImageNet, MetaImageNet
18 | from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet
19 | from dataset.cifar import CIFAR100, MetaCIFAR100, CIFAR100_toy
20 | from dataset.transform_cfg import transforms_options, transforms_test_options, transforms_list
21 |
22 | from dataset.dataset_selfsupervision import SSDatasetWrapper
23 |
24 | import numpy as np
25 |
26 |
27 | def get_dataloaders(opt):
28 | # dataloader
29 | train_partition = 'trainval' if opt.use_trainval else 'train'
30 |
31 | if opt.dataset == 'toy':
32 |
33 | train_trans, test_trans = transforms_options['D']
34 |
35 | train_loader = DataLoader(CIFAR100_toy(args=opt, partition=train_partition, transform=train_trans),
36 | batch_size=opt.batch_size, shuffle=True, drop_last=True,
37 | num_workers=opt.num_workers)
38 | val_loader = DataLoader(CIFAR100_toy(args=opt, partition='train', transform=test_trans),
39 | batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
40 | num_workers=opt.num_workers // 2)
41 |
42 | # train_trans, test_trans = transforms_test_options[opt.transform]
43 |
44 | # meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
45 | # train_transform=train_trans,
46 | # test_transform=test_trans),
47 | # batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
48 | # num_workers=opt.num_workers)
49 | # meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
50 | # train_transform=train_trans,
51 | # test_transform=test_trans),
52 | # batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
53 | # num_workers=opt.num_workers)
54 | n_cls = 5
55 |
56 |         return train_loader, val_loader, 5, 5, n_cls  # placeholders fill the meta-loader slots in the toy setting
57 |
58 |
59 | if opt.dataset == 'miniImageNet':
60 |
61 | train_trans, test_trans = transforms_options[opt.transform]
62 | train_loader = DataLoader(ImageNet(args=opt, partition=train_partition, transform=train_trans),
63 | batch_size=opt.batch_size, shuffle=True, drop_last=True,
64 | num_workers=opt.num_workers)
65 | val_loader = DataLoader(ImageNet(args=opt, partition='val', transform=test_trans),
66 | batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
67 | num_workers=opt.num_workers // 2)
68 |
69 | train_trans, test_trans = transforms_test_options[opt.transform]
70 | meta_testloader = DataLoader(MetaImageNet(args=opt, partition='test',
71 | train_transform=train_trans,
72 | test_transform=test_trans),
73 | batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
74 | num_workers=opt.num_workers)
75 | meta_valloader = DataLoader(MetaImageNet(args=opt, partition='val',
76 | train_transform=train_trans,
77 | test_transform=test_trans),
78 | batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
79 | num_workers=opt.num_workers)
80 |
81 | if opt.use_trainval:
82 | n_cls = 80
83 | else:
84 | n_cls = 64
85 | elif opt.dataset == 'tieredImageNet':
86 | train_trans, test_trans = transforms_options[opt.transform]
87 | train_loader = DataLoader(TieredImageNet(args=opt, partition=train_partition, transform=train_trans),
88 | batch_size=opt.batch_size, shuffle=True, drop_last=True,
89 | num_workers=opt.num_workers)
90 | val_loader = DataLoader(TieredImageNet(args=opt, partition='train_phase_val', transform=test_trans),
91 | batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
92 | num_workers=opt.num_workers // 2)
93 |
94 | train_trans, test_trans = transforms_test_options[opt.transform]
95 | meta_testloader = DataLoader(MetaTieredImageNet(args=opt, partition='test',
96 | train_transform=train_trans,
97 | test_transform=test_trans),
98 | batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
99 | num_workers=opt.num_workers)
100 | meta_valloader = DataLoader(MetaTieredImageNet(args=opt, partition='val',
101 | train_transform=train_trans,
102 | test_transform=test_trans),
103 | batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
104 | num_workers=opt.num_workers)
105 | if opt.use_trainval:
106 | n_cls = 448
107 | else:
108 | n_cls = 351
109 | elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
110 | train_trans, test_trans = transforms_options['D']
111 |
112 | train_loader = DataLoader(CIFAR100(args=opt, partition=train_partition, transform=train_trans),
113 | batch_size=opt.batch_size, shuffle=True, drop_last=True,
114 | num_workers=opt.num_workers)
115 | val_loader = DataLoader(CIFAR100(args=opt, partition='train', transform=test_trans),
116 | batch_size=opt.batch_size // 2, shuffle=False, drop_last=False,
117 | num_workers=opt.num_workers // 2)
118 |
119 | train_trans, test_trans = transforms_test_options[opt.transform]
120 |
121 |
122 | # ns = [opt.n_shots].copy()
123 | # opt.n_ways = 32
124 | # opt.n_shots = 5
125 | # opt.n_aug_support_samples = 2
126 | meta_trainloader = DataLoader(MetaCIFAR100(args=opt, partition='train',
127 | train_transform=train_trans,
128 | test_transform=test_trans),
129 | batch_size=1, shuffle=True, drop_last=False,
130 | num_workers=opt.num_workers)
131 |
132 | # opt.n_ways = 5
133 | # opt.n_shots = ns[0]
134 | # print(opt.n_shots)
135 | # opt.n_aug_support_samples = 5
136 | meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
137 | train_transform=train_trans,
138 | test_transform=test_trans),
139 | batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
140 | num_workers=opt.num_workers)
141 | meta_valloader = DataLoader(MetaCIFAR100(args=opt, partition='val',
142 | train_transform=train_trans,
143 | test_transform=test_trans),
144 | batch_size=opt.test_batch_size, shuffle=False, drop_last=False,
145 | num_workers=opt.num_workers)
146 | if opt.use_trainval:
147 | n_cls = 80
148 | else:
149 | if opt.dataset == 'CIFAR-FS':
150 | n_cls = 64
151 | elif opt.dataset == 'FC100':
152 | n_cls = 60
153 | else:
154 | raise NotImplementedError('dataset not supported: {}'.format(opt.dataset))
155 | # return train_loader, val_loader, meta_trainloader, meta_testloader, meta_valloader, n_cls
156 | else:
157 | raise NotImplementedError(opt.dataset)
158 |
159 | return train_loader, val_loader, meta_testloader, meta_valloader, n_cls
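`get_dataloaders` takes an argparse-style options object rather than individual arguments. A hypothetical minimal namespace covering the attributes the `'CIFAR-FS'` branch reads (values here are illustrative, not the paper's settings):

```python
from types import SimpleNamespace

# Attributes read by get_dataloaders and the datasets it constructs
# for the 'CIFAR-FS' branch; values are illustrative only.
opt = SimpleNamespace(
    dataset='CIFAR-FS',         # one of: toy, miniImageNet, tieredImageNet, CIFAR-FS, FC100
    data_root='data/CIFAR_FS',  # hypothetical path to the pickled dataset
    use_trainval=False,         # if True, trains on train+val and changes n_cls
    transform='A',              # key into transforms_test_options for the meta loaders
    data_aug=True,              # enables the default training augmentation
    batch_size=64,
    test_batch_size=1,
    num_workers=4,
    simclr=False,               # CIFAR100.__getitem__ checks this flag
    # episode settings consumed by MetaCIFAR100
    n_ways=5, n_shots=1, n_queries=15,
    n_test_runs=600, n_aug_support_samples=5,
)
```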
--------------------------------------------------------------------------------
/dataset/cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import pickle
5 | from PIL import Image
6 | import numpy as np
7 |
8 | import torch
9 | import torchvision.transforms as transforms
10 | from torch.utils.data import Dataset
11 |
12 |
13 | class CIFAR100(Dataset):
14 | """support FC100 and CIFAR-FS"""
15 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
16 | transform=None):
17 | super(Dataset, self).__init__()
18 | self.data_root = args.data_root
19 | self.partition = partition
20 | self.data_aug = args.data_aug
21 | self.mean = [0.5071, 0.4867, 0.4408]
22 | self.std = [0.2675, 0.2565, 0.2761]
23 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
24 | self.pretrain = pretrain
25 | self.simclr = args.simclr
26 |
27 |
28 | if transform is None:
29 | if self.partition == 'train' and self.data_aug:
30 | self.transform = transforms.Compose([
31 | lambda x: Image.fromarray(x),
32 | transforms.RandomCrop(32, padding=4),
33 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
34 | transforms.RandomHorizontalFlip(),
35 | lambda x: np.asarray(x),
36 | transforms.ToTensor(),
37 | self.normalize
38 | ])
39 | else:
40 | self.transform = transforms.Compose([
41 | lambda x: Image.fromarray(x),
42 | transforms.ToTensor(),
43 | self.normalize
44 | ])
45 | else:
46 | self.transform = transform
47 |
48 | if self.pretrain:
49 | self.file_pattern = '%s.pickle'
50 | else:
51 | self.file_pattern = '%s.pickle'
52 | self.data = {}
53 |
54 | with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
55 | data = pickle.load(f, encoding='latin1')
56 | self.imgs = data['data']
57 | labels = data['labels']
58 | # adjust sparse labels to labels from 0 to n.
59 | cur_class = 0
60 | label2label = {}
61 | for idx, label in enumerate(labels):
62 | if label not in label2label:
63 | label2label[label] = cur_class
64 | cur_class += 1
65 | new_labels = []
66 | for idx, label in enumerate(labels):
67 | new_labels.append(label2label[label])
68 | self.labels = new_labels
69 |
70 |
71 | # pre-process for contrastive sampling
72 | self.k = k
73 | self.is_sample = is_sample
74 | if self.is_sample:
75 | self.labels = np.asarray(self.labels)
76 | self.labels = self.labels - np.min(self.labels)
77 | num_classes = np.max(self.labels) + 1
78 |
79 | self.cls_positive = [[] for _ in range(num_classes)]
80 | for i in range(len(self.imgs)):
81 | self.cls_positive[self.labels[i]].append(i)
82 |
83 | self.cls_negative = [[] for _ in range(num_classes)]
84 | for i in range(num_classes):
85 | for j in range(num_classes):
86 | if j == i:
87 | continue
88 | self.cls_negative[i].extend(self.cls_positive[j])
89 |
90 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
91 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
92 | self.cls_positive = np.asarray(self.cls_positive)
93 | self.cls_negative = np.asarray(self.cls_negative)
94 |
95 | def __getitem__(self, item):
96 | img = np.asarray(self.imgs[item]).astype('uint8')
97 | target = self.labels[item] - min(self.labels)
98 |
99 | if(self.simclr):
100 | img1 = self.transform(img)
101 | img2 = self.transform(img)
102 | return (img1, img2), target, item
103 |
104 | img = self.transform(img)
105 | if not self.is_sample:
106 | return img, target, item
107 | else:
108 | pos_idx = item
109 | replace = True if self.k > len(self.cls_negative[target]) else False
110 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
111 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
112 | return img, target, item, sample_idx
113 |
114 | def __len__(self):
115 | return len(self.labels)
116 |
117 |
118 |
119 |
120 | class CIFAR100_toy(Dataset):
121 | """support FC100 and CIFAR-FS"""
122 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
123 | transform=None):
124 | super(Dataset, self).__init__()
125 | self.data_root = args.data_root
126 | self.partition = partition
127 | self.data_aug = args.data_aug
128 | self.mean = [0.5071, 0.4867, 0.4408]
129 | self.std = [0.2675, 0.2565, 0.2761]
130 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
131 | self.pretrain = pretrain
132 | self.simclr = args.simclr
133 |
134 |
135 | if transform is None:
136 | if self.partition == 'train' and self.data_aug:
137 | self.transform = transforms.Compose([
138 | lambda x: Image.fromarray(x),
139 | transforms.RandomCrop(32, padding=4),
140 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
141 | transforms.RandomHorizontalFlip(),
142 | lambda x: np.asarray(x),
143 | transforms.ToTensor(),
144 | self.normalize
145 | ])
146 | else:
147 | self.transform = transforms.Compose([
148 | lambda x: Image.fromarray(x),
149 | transforms.ToTensor(),
150 | self.normalize
151 | ])
152 | else:
153 | self.transform = transform
154 |
155 | if self.pretrain:
156 | self.file_pattern = '%s.pickle'
157 | else:
158 | self.file_pattern = '%s.pickle'
159 | self.data = {}
160 |
161 | with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
162 | data = pickle.load(f, encoding='latin1')
163 | self.imgs = data['data']
164 | labels = data['labels']
165 | # adjust sparse labels to labels from 0 to n.
166 | cur_class = 0
167 | label2label = {}
168 | for idx, label in enumerate(labels):
169 | if label not in label2label:
170 | label2label[label] = cur_class
171 | cur_class += 1
172 | new_labels = []
173 | for idx, label in enumerate(labels):
174 | new_labels.append(label2label[label])
175 | self.labels = new_labels
176 |
177 | self.labels = np.array(self.labels)
178 | self.imgs = np.array(self.imgs)
179 | print(self.labels.shape)
180 | print(self.imgs.shape)
181 |
182 | loc = np.where(self.labels<5)[0]
183 | self.labels = self.labels[loc]
184 | self.imgs = self.imgs[loc]
185 |
186 |
187 | self.k = k
188 | self.is_sample = is_sample
189 |
190 | def __getitem__(self, item):
191 | img = np.asarray(self.imgs[item]).astype('uint8')
192 | target = self.labels[item] - min(self.labels)
193 |
194 | if(self.simclr):
195 | img1 = self.transform(img)
196 | img2 = self.transform(img)
197 | return (img1, img2), target, item
198 |
199 | img = self.transform(img)
200 | if not self.is_sample:
201 | return img, target, item
202 | else:
203 | pos_idx = item
204 | replace = True if self.k > len(self.cls_negative[target]) else False
205 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
206 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
207 | return img, target, item, sample_idx
208 |
209 | def __len__(self):
210 | return len(self.labels)
211 |
212 |
213 |
214 |
215 |
216 | class MetaCIFAR100(CIFAR100):
217 |
218 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True):
219 | super(MetaCIFAR100, self).__init__(args, partition, False)
220 | self.fix_seed = fix_seed
221 | self.n_ways = args.n_ways
222 | self.n_shots = args.n_shots
223 | self.n_queries = args.n_queries
224 | self.classes = list(self.data.keys())
225 | self.n_test_runs = args.n_test_runs
226 | self.n_aug_support_samples = args.n_aug_support_samples
227 | if train_transform is None:
228 | self.train_transform = transforms.Compose([
229 | lambda x: Image.fromarray(x),
230 | transforms.RandomCrop(32, padding=4),
231 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
232 | transforms.RandomHorizontalFlip(),
233 | lambda x: np.asarray(x),
234 | transforms.ToTensor(),
235 | self.normalize
236 | ])
237 | else:
238 | self.train_transform = train_transform
239 |
240 | if test_transform is None:
241 | self.test_transform = transforms.Compose([
242 | lambda x: Image.fromarray(x),
243 | transforms.ToTensor(),
244 | self.normalize
245 | ])
246 | else:
247 | self.test_transform = test_transform
248 |
249 | self.data = {}
250 | for idx in range(self.imgs.shape[0]):
251 | if self.labels[idx] not in self.data:
252 | self.data[self.labels[idx]] = []
253 | self.data[self.labels[idx]].append(self.imgs[idx])
254 | self.classes = list(self.data.keys())
255 |
256 | def __getitem__(self, item):
257 | if self.fix_seed:
258 | np.random.seed(item)
259 | cls_sampled = np.random.choice(self.classes, self.n_ways, False)
260 |
261 | support_xs = []
262 | support_ys = []
263 | support_ts = []
264 | query_xs = []
265 | query_ys = []
266 | query_ts = []
267 | for idx, cls in enumerate(cls_sampled):
268 | imgs = np.asarray(self.data[cls]).astype('uint8')
269 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False)
270 | support_xs.append(imgs[support_xs_ids_sampled])
271 | support_ys.append([idx] * self.n_shots)
272 | support_ts.append([cls] * self.n_shots)
273 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled)
274 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False)
275 | query_xs.append(imgs[query_xs_ids])
276 | query_ys.append([idx] * query_xs_ids.shape[0])
277 | query_ts.append([cls] * query_xs_ids.shape[0])
278 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array(
279 | query_xs), np.array(query_ys)
280 | support_ts, query_ts = np.array(support_ts), np.array(query_ts)
281 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape
282 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel))
283 | query_ys = query_ys.reshape((num_ways * n_queries_per_way,))
284 | query_ts = query_ts.reshape((num_ways * n_queries_per_way,))
285 |
286 | support_xs = support_xs.reshape((-1, height, width, channel))
287 | if self.n_aug_support_samples > 1:
288 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1))
289 | support_ys = np.tile(support_ys.reshape((-1,)), (self.n_aug_support_samples))
290 | support_ts = np.tile(support_ts.reshape((-1,)), (self.n_aug_support_samples))
291 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0)
292 |
293 |
294 |
295 | query_xs = query_xs.reshape((-1, height, width, channel))
296 | if self.n_aug_support_samples > 1:
297 | query_xs = np.tile(query_xs, (self.n_aug_support_samples, 1, 1, 1))
298 | query_ys = np.tile(query_ys.reshape((-1,)), (self.n_aug_support_samples))
299 | query_ts = np.tile(query_ts.reshape((-1,)), (self.n_aug_support_samples))
300 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0)
301 |
302 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs)))
303 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs)))
304 |
305 | return support_xs, support_ys, query_xs, query_ys
306 |
307 | def __len__(self):
308 | return self.n_test_runs
309 |
310 |
311 | if __name__ == '__main__':
312 | args = lambda x: None
313 | args.n_ways = 5
314 | args.n_shots = 1
315 | args.n_queries = 12
316 | # args.data_root = 'data'
317 | args.data_root = '/home/yonglong/Downloads/FC100'
318 | args.data_aug = True
319 | args.n_test_runs = 5
320 | args.n_aug_support_samples = 1
321 | imagenet = CIFAR100(args, 'train')
322 | print(len(imagenet))
323 | print(imagenet.__getitem__(500)[0].shape)
324 |
325 | metaimagenet = MetaCIFAR100(args, 'train')
326 | print(len(metaimagenet))
327 | print(metaimagenet.__getitem__(500)[0].size())
328 | print(metaimagenet.__getitem__(500)[1].shape)
329 | print(metaimagenet.__getitem__(500)[2].size())
330 | print(metaimagenet.__getitem__(500)[3].shape)
--------------------------------------------------------------------------------
/dataset/dataset_selfsupervision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import numpy as np
4 |
5 |
6 | class SSDatasetWrapper(torch.utils.data.Dataset):
7 | def __init__(self, dset, opt):
8 | self.dset = dset
9 | self.opt = opt
10 |
11 | def __getitem__(self, index):
12 | image, target, item = self.dset[index]
13 |
14 | if(not(self.opt.ssl)):
15 | return image, target, item
16 | else:
17 | if(self.opt.ssl_rot):
18 | label = np.random.randint(4)
19 | if label == 1:
20 | image_rot = tensor_rot_90(image)
21 | elif label == 2:
22 | image_rot = tensor_rot_180(image)
23 | elif label == 3:
24 | image_rot = tensor_rot_270(image)
25 | else:
26 | image_rot = image
27 |
28 | return (image, image_rot), (target, label), item
29 |
30 | if(self.opt.ssl_quad):
31 | label = np.random.randint(4)
32 |
33 | horstr = image.size(1) // 2
34 | verstr = image.size(2) // 2
35 | horlab = label // 2
36 | verlab = label % 2
37 |
38 | image_quad = image[:, horlab*horstr:(horlab+1)*horstr, verlab*verstr:(verlab+1)*verstr,]
39 |                 return (image, image_quad), (target, label), item
40 |             return image, target, item  # fallback: opt.ssl is set but no SSL task was selected
41 | def __len__(self):
42 | return len(self.dset)
43 |
44 | # Assumes that tensor is (nchannels, height, width)
45 | def tensor_rot_90(x):
46 | return x.flip(2).transpose(1,2)
47 | def tensor_rot_90_digit(x):
48 | return x.transpose(1,2)
49 |
50 | def tensor_rot_180(x):
51 | return x.flip(2).flip(1)
52 | def tensor_rot_180_digit(x):
53 | return x.flip(2)
54 |
55 | def tensor_rot_270(x):
56 | return x.transpose(1,2).flip(2)
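As a sanity check, the flip/transpose compositions above match counter-clockwise `np.rot90` in the (H, W) plane. This is a hypothetical NumPy re-implementation, assuming the documented `(nchannels, height, width)` layout:

```python
import numpy as np

def tensor_rot_90_np(x):
    # mirror of tensor_rot_90: flip the width axis, then swap H and W
    return np.flip(x, 2).swapaxes(1, 2)

def tensor_rot_180_np(x):
    # mirror of tensor_rot_180: flip width, then flip height
    return np.flip(np.flip(x, 2), 1)

def tensor_rot_270_np(x):
    # mirror of tensor_rot_270: swap H and W, then flip the width axis
    return np.flip(x.swapaxes(1, 2), 2)

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
# each helper equals the corresponding counter-clockwise rotation
assert np.array_equal(tensor_rot_90_np(x), np.rot90(x, 1, axes=(1, 2)))
assert np.array_equal(tensor_rot_180_np(x), np.rot90(x, 2, axes=(1, 2)))
assert np.array_equal(tensor_rot_270_np(x), np.rot90(x, 3, axes=(1, 2)))
```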
--------------------------------------------------------------------------------
/dataset/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class ImageNet(Dataset):
11 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
12 | transform=None):
13 | super(Dataset, self).__init__()
14 | self.data_root = args.data_root
15 | self.partition = partition
16 | self.data_aug = args.data_aug
17 | self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
18 | self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
19 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
20 | self.pretrain = pretrain
21 |
22 | if transform is None:
23 | if self.partition == 'train' and self.data_aug:
24 | self.transform = transforms.Compose([
25 | lambda x: Image.fromarray(x),
26 | transforms.RandomCrop(84, padding=8),
27 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
28 | transforms.RandomHorizontalFlip(),
29 | lambda x: np.asarray(x),
30 | transforms.ToTensor(),
31 | self.normalize
32 | ])
33 | else:
34 | self.transform = transforms.Compose([
35 | lambda x: Image.fromarray(x),
36 | transforms.ToTensor(),
37 | self.normalize
38 | ])
39 | else:
40 | self.transform = transform
41 |
42 | if self.pretrain:
43 | self.file_pattern = 'miniImageNet_category_split_train_phase_%s.pickle'
44 | else:
45 | self.file_pattern = 'miniImageNet_category_split_%s.pickle'
46 | self.data = {}
47 | with open(os.path.join(self.data_root, self.file_pattern % partition), 'rb') as f:
48 | data = pickle.load(f, encoding='latin1')
49 | self.imgs = data['data']
50 | self.labels = data['labels']
51 |
52 |
53 |
54 |
55 |
56 | # pre-process for contrastive sampling
57 | self.k = k
58 | self.is_sample = is_sample
59 | if self.is_sample:
60 | self.labels = np.asarray(self.labels)
61 | self.labels = self.labels - np.min(self.labels)
62 | num_classes = np.max(self.labels) + 1
63 |
64 | self.cls_positive = [[] for _ in range(num_classes)]
65 | for i in range(len(self.imgs)):
66 | self.cls_positive[self.labels[i]].append(i)
67 |
68 | self.cls_negative = [[] for _ in range(num_classes)]
69 | for i in range(num_classes):
70 | for j in range(num_classes):
71 | if j == i:
72 | continue
73 | self.cls_negative[i].extend(self.cls_positive[j])
74 |
75 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
76 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
77 | self.cls_positive = np.asarray(self.cls_positive)
78 | self.cls_negative = np.asarray(self.cls_negative)
79 |
80 | def __getitem__(self, item):
81 | img = np.asarray(self.imgs[item]).astype('uint8')
82 | img = self.transform(img)
83 | target = self.labels[item] - min(self.labels)
84 |
85 | if not self.is_sample:
86 | return img, target, item
87 | else:
88 | pos_idx = item
89 | replace = True if self.k > len(self.cls_negative[target]) else False
90 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
91 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
92 | return img, target, item, sample_idx
93 |
94 | def __len__(self):
95 | return len(self.labels)
96 |
97 |
98 | class MetaImageNet(ImageNet):
99 |
100 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True):
101 | super(MetaImageNet, self).__init__(args, partition, False)
102 | self.fix_seed = fix_seed
103 | self.n_ways = args.n_ways
104 | self.n_shots = args.n_shots
105 | self.n_queries = args.n_queries
106 |         # self.classes is assigned below, once self.data has been rebuilt
107 | self.n_test_runs = args.n_test_runs
108 | self.n_aug_support_samples = args.n_aug_support_samples
109 | if train_transform is None:
110 | self.train_transform = transforms.Compose([
111 | lambda x: Image.fromarray(x),
112 | transforms.RandomCrop(84, padding=8),
113 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
114 | transforms.RandomHorizontalFlip(),
115 | lambda x: np.asarray(x),
116 | transforms.ToTensor(),
117 | self.normalize
118 | ])
119 | else:
120 | self.train_transform = train_transform
121 |
122 | if test_transform is None:
123 | self.test_transform = transforms.Compose([
124 | lambda x: Image.fromarray(x),
125 | transforms.ToTensor(),
126 | self.normalize
127 | ])
128 | else:
129 | self.test_transform = test_transform
130 |
131 | self.data = {}
132 | for idx in range(self.imgs.shape[0]):
133 | if self.labels[idx] not in self.data:
134 | self.data[self.labels[idx]] = []
135 | self.data[self.labels[idx]].append(self.imgs[idx])
136 | self.classes = list(self.data.keys())
137 |
138 | def __getitem__(self, item):
139 | if self.fix_seed:
140 | np.random.seed(item)
141 | cls_sampled = np.random.choice(self.classes, self.n_ways, False)
142 | support_xs = []
143 | support_ys = []
144 | query_xs = []
145 | query_ys = []
146 | for idx, cls in enumerate(cls_sampled):
147 | imgs = np.asarray(self.data[cls]).astype('uint8')
148 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False)
149 | support_xs.append(imgs[support_xs_ids_sampled])
150 | support_ys.append([idx] * self.n_shots)
151 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled)
152 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False)
153 | query_xs.append(imgs[query_xs_ids])
154 | query_ys.append([idx] * query_xs_ids.shape[0])
155 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array(
156 | query_xs), np.array(query_ys)
157 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape
158 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel))
159 | query_ys = query_ys.reshape((num_ways * n_queries_per_way, ))
160 |
161 | support_xs = support_xs.reshape((-1, height, width, channel))
162 | if self.n_aug_support_samples > 1:
163 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1))
164 | support_ys = np.tile(support_ys.reshape((-1, )), (self.n_aug_support_samples))
165 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0)
166 | query_xs = query_xs.reshape((-1, height, width, channel))
167 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0)
168 |
169 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs)))
170 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs)))
171 |
172 | return support_xs, support_ys, query_xs, query_ys
173 |
174 | def __len__(self):
175 | return self.n_test_runs
176 |
177 |
178 | if __name__ == '__main__':
179 |     args = lambda x: None  # lightweight stand-in for argparse args: a function object accepts arbitrary attributes
180 | args.n_ways = 5
181 | args.n_shots = 1
182 | args.n_queries = 12
183 | args.data_root = 'data'
184 | args.data_aug = True
185 | args.n_test_runs = 5
186 | args.n_aug_support_samples = 1
187 | imagenet = ImageNet(args, 'val')
188 | print(len(imagenet))
189 | print(imagenet.__getitem__(500)[0].shape)
190 |
191 | metaimagenet = MetaImageNet(args)
192 | print(len(metaimagenet))
193 | print(metaimagenet.__getitem__(500)[0].size())
194 | print(metaimagenet.__getitem__(500)[1].shape)
195 | print(metaimagenet.__getitem__(500)[2].size())
196 | print(metaimagenet.__getitem__(500)[3].shape)
197 |
--------------------------------------------------------------------------------
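The contrastive branch above precomputes, per class, its own sample indices (`cls_positive`) and every other class's indices (`cls_negative`); `__getitem__` then returns the anchor index stacked with `k` sampled negatives. A minimal NumPy sketch of that index bookkeeping (the helper names are hypothetical, not part of this repo):

```python
import numpy as np

def build_contrast_indices(labels):
    """Per-class positive/negative index lists, as in the dataset pre-processing."""
    labels = np.asarray(labels)
    labels = labels - labels.min()                 # shift labels so they start at 0
    num_classes = labels.max() + 1
    cls_positive = [np.where(labels == c)[0] for c in range(num_classes)]
    cls_negative = [np.where(labels != c)[0] for c in range(num_classes)]
    return cls_positive, cls_negative

def sample_contrast(item, target, cls_negative, k, rng):
    """Anchor index followed by k negative indices, mirroring __getitem__."""
    replace = k > len(cls_negative[target])        # reuse negatives only if we must
    neg_idx = rng.choice(cls_negative[target], k, replace=replace)
    return np.hstack(([item], neg_idx))

pos, neg = build_contrast_indices([0, 0, 1, 1, 2, 2])
idx = sample_contrast(0, 0, neg, k=3, rng=np.random.default_rng(0))
print(pos[0].tolist(), idx.shape)   # class-0 indices, then 1 anchor + 3 negatives
```

The `np.where`-based construction yields the same index sets as the nested loop in the file, just without the explicit per-class extension.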
/dataset/tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class TieredImageNet(Dataset):
11 | def __init__(self, args, partition='train', pretrain=True, is_sample=False, k=4096,
12 | transform=None):
13 | super(Dataset, self).__init__()
14 | self.data_root = args.data_root
15 | self.partition = partition
16 | self.data_aug = args.data_aug
17 | self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
18 | self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
19 |
20 | self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
21 | self.pretrain = pretrain
22 |
23 | if transform is None:
24 | if self.partition == 'train' and self.data_aug:
25 | self.transform = transforms.Compose([
26 | lambda x: Image.fromarray(x),
27 | transforms.RandomCrop(84, padding=8),
28 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
29 | transforms.RandomHorizontalFlip(),
30 | lambda x: np.asarray(x),
31 | transforms.ToTensor(),
32 | self.normalize
33 | ])
34 | else:
35 | self.transform = transforms.Compose([
36 | lambda x: Image.fromarray(x),
37 | transforms.ToTensor(),
38 | self.normalize
39 | ])
40 | else:
41 | self.transform = transform
42 |
43 |         # the same files are used for pretraining and for meta-testing,
44 |         # so the pretrain flag does not change the patterns here
45 |         self.image_file_pattern = '%s_images.npz'
46 |         self.label_file_pattern = '%s_labels.pkl'
49 |
50 | self.data = {}
51 |
52 |         # load the tieredImageNet images (npz) and labels (pkl)
53 | image_file = os.path.join(self.data_root, self.image_file_pattern % partition)
54 | self.imgs = np.load(image_file)['images']
55 | label_file = os.path.join(self.data_root, self.label_file_pattern % partition)
56 | self.labels = self._load_labels(label_file)['labels']
57 |
58 | # pre-process for contrastive sampling
59 | self.k = k
60 | self.is_sample = is_sample
61 | if self.is_sample:
62 | self.labels = np.asarray(self.labels)
63 | self.labels = self.labels - np.min(self.labels)
64 | num_classes = np.max(self.labels) + 1
65 |
66 | self.cls_positive = [[] for _ in range(num_classes)]
67 | for i in range(len(self.imgs)):
68 | self.cls_positive[self.labels[i]].append(i)
69 |
70 | self.cls_negative = [[] for _ in range(num_classes)]
71 | for i in range(num_classes):
72 | for j in range(num_classes):
73 | if j == i:
74 | continue
75 | self.cls_negative[i].extend(self.cls_positive[j])
76 |
77 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
78 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
79 | self.cls_positive = np.asarray(self.cls_positive)
80 | self.cls_negative = np.asarray(self.cls_negative)
81 |
82 | def __getitem__(self, item):
83 | img = np.asarray(self.imgs[item]).astype('uint8')
84 | img = self.transform(img)
85 | target = self.labels[item] - min(self.labels)
86 |
87 | if not self.is_sample:
88 | return img, target, item
89 | else:
90 | pos_idx = item
91 |             replace = self.k > len(self.cls_negative[target])  # reuse negatives only if the pool is too small
92 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
93 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
94 | return img, target, item, sample_idx
95 |
96 | def __len__(self):
97 | return len(self.labels)
98 |
99 | @staticmethod
100 | def _load_labels(file):
101 |         try:
102 |             with open(file, 'rb') as fo:
103 |                 data = pickle.load(fo)
104 |             return data
105 |         except UnicodeDecodeError:  # Python 2 pickle being read under Python 3
106 |             with open(file, 'rb') as f:
107 |                 u = pickle._Unpickler(f)
108 |                 u.encoding = 'latin1'
109 |                 data = u.load()
110 |             return data
111 |
112 |
113 | class MetaTieredImageNet(TieredImageNet):
114 |
115 | def __init__(self, args, partition='train', train_transform=None, test_transform=None, fix_seed=True):
116 | super(MetaTieredImageNet, self).__init__(args, partition, False)
117 | self.fix_seed = fix_seed
118 | self.n_ways = args.n_ways
119 | self.n_shots = args.n_shots
120 | self.n_queries = args.n_queries
121 |         # self.classes is assigned below, once self.data has been rebuilt
122 | self.n_test_runs = args.n_test_runs
123 | self.n_aug_support_samples = args.n_aug_support_samples
124 | if train_transform is None:
125 | self.train_transform = transforms.Compose([
126 | lambda x: Image.fromarray(x),
127 | transforms.RandomCrop(84, padding=8),
128 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
129 | transforms.RandomHorizontalFlip(),
130 | lambda x: np.asarray(x),
131 | transforms.ToTensor(),
132 | self.normalize
133 | ])
134 | else:
135 | self.train_transform = train_transform
136 |
137 | if test_transform is None:
138 | self.test_transform = transforms.Compose([
139 | lambda x: Image.fromarray(x),
140 | transforms.ToTensor(),
141 | self.normalize
142 | ])
143 | else:
144 | self.test_transform = test_transform
145 |
146 | self.data = {}
147 | for idx in range(self.imgs.shape[0]):
148 | if self.labels[idx] not in self.data:
149 | self.data[self.labels[idx]] = []
150 | self.data[self.labels[idx]].append(self.imgs[idx])
151 | self.classes = list(self.data.keys())
152 |
153 | def __getitem__(self, item):
154 | if self.fix_seed:
155 | np.random.seed(item)
156 | cls_sampled = np.random.choice(self.classes, self.n_ways, False)
157 | support_xs = []
158 | support_ys = []
159 | query_xs = []
160 | query_ys = []
161 | for idx, cls in enumerate(cls_sampled):
162 | imgs = np.asarray(self.data[cls]).astype('uint8')
163 | support_xs_ids_sampled = np.random.choice(range(imgs.shape[0]), self.n_shots, False)
164 | support_xs.append(imgs[support_xs_ids_sampled])
165 | support_ys.append([idx] * self.n_shots)
166 | query_xs_ids = np.setxor1d(np.arange(imgs.shape[0]), support_xs_ids_sampled)
167 | query_xs_ids = np.random.choice(query_xs_ids, self.n_queries, False)
168 | query_xs.append(imgs[query_xs_ids])
169 | query_ys.append([idx] * query_xs_ids.shape[0])
170 | support_xs, support_ys, query_xs, query_ys = np.array(support_xs), np.array(support_ys), np.array(
171 | query_xs), np.array(query_ys)
172 | num_ways, n_queries_per_way, height, width, channel = query_xs.shape
173 | query_xs = query_xs.reshape((num_ways * n_queries_per_way, height, width, channel))
174 | query_ys = query_ys.reshape((num_ways * n_queries_per_way,))
175 |
176 | support_xs = support_xs.reshape((-1, height, width, channel))
177 | if self.n_aug_support_samples > 1:
178 | support_xs = np.tile(support_xs, (self.n_aug_support_samples, 1, 1, 1))
179 | support_ys = np.tile(support_ys.reshape((-1,)), (self.n_aug_support_samples))
180 | support_xs = np.split(support_xs, support_xs.shape[0], axis=0)
181 | query_xs = query_xs.reshape((-1, height, width, channel))
182 | query_xs = np.split(query_xs, query_xs.shape[0], axis=0)
183 |
184 | support_xs = torch.stack(list(map(lambda x: self.train_transform(x.squeeze()), support_xs)))
185 | query_xs = torch.stack(list(map(lambda x: self.test_transform(x.squeeze()), query_xs)))
186 |
187 | return support_xs, support_ys, query_xs, query_ys
188 |
189 | def __len__(self):
190 | return self.n_test_runs
191 |
192 |
193 | if __name__ == '__main__':
194 |     args = lambda x: None  # lightweight stand-in for argparse args: a function object accepts arbitrary attributes
195 | args.n_ways = 5
196 | args.n_shots = 1
197 | args.n_queries = 12
198 | # args.data_root = 'data'
199 | args.data_root = '/home/yonglong/Data/tiered-imagenet-kwon'
200 | args.data_aug = True
201 | args.n_test_runs = 5
202 | args.n_aug_support_samples = 1
203 | imagenet = TieredImageNet(args, 'train')
204 | print(len(imagenet))
205 | print(imagenet.__getitem__(500)[0].shape)
206 |
207 | metaimagenet = MetaTieredImageNet(args)
208 | print(len(metaimagenet))
209 | print(metaimagenet.__getitem__(500)[0].size())
210 | print(metaimagenet.__getitem__(500)[1].shape)
211 | print(metaimagenet.__getitem__(500)[2].size())
212 | print(metaimagenet.__getitem__(500)[3].shape)
213 |
--------------------------------------------------------------------------------
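`MetaTieredImageNet.__getitem__` assembles one few-shot episode: sample `n_ways` classes, take `n_shots` support images per class, draw `n_queries` query images from the remaining images of that class, and relabel the sampled classes `0..n_ways-1`. An index-level sketch of the same sampling scheme (NumPy only; the function name is hypothetical):

```python
import numpy as np

def sample_episode(data, n_ways, n_shots, n_queries, seed=0):
    """Return support/query lists of (class, image_index, episode_label) triples."""
    rng = np.random.default_rng(seed)              # fixed seed -> reproducible episode
    cls_sampled = rng.choice(sorted(data.keys()), n_ways, replace=False)
    support, query = [], []
    for ep_label, cls in enumerate(cls_sampled):
        n = len(data[cls])
        support_ids = rng.choice(n, n_shots, replace=False)
        rest = np.setdiff1d(np.arange(n), support_ids)   # queries avoid support images
        query_ids = rng.choice(rest, n_queries, replace=False)
        support += [(int(cls), int(i), ep_label) for i in support_ids]
        query += [(int(cls), int(i), ep_label) for i in query_ids]
    return support, query

data = {c: list(range(20)) for c in range(10)}     # 10 classes, 20 images each
support, query = sample_episode(data, n_ways=5, n_shots=1, n_queries=3)
print(len(support), len(query))                    # → 5 15
```

The fixed per-episode seed mirrors `fix_seed=True` above, which seeds NumPy with the episode index so that evaluation episodes are identical across runs.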
/dataset/transform_cfg.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | from PIL import Image
5 | import torchvision.transforms as transforms
6 |
7 |
8 | mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
9 | std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
10 | normalize = transforms.Normalize(mean=mean, std=std)
11 |
12 |
13 | transform_A = [
14 | transforms.Compose([
15 | lambda x: Image.fromarray(x),
16 | transforms.RandomCrop(84, padding=8),
17 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
18 | transforms.RandomHorizontalFlip(),
19 | lambda x: np.asarray(x),
20 | transforms.ToTensor(),
21 | normalize
22 | ]),
23 |
24 | transforms.Compose([
25 | lambda x: Image.fromarray(x),
26 | transforms.ToTensor(),
27 | normalize
28 | ])
29 | ]
30 |
31 | transform_A_test = [
32 | transforms.Compose([
33 | lambda x: Image.fromarray(x),
34 | transforms.RandomCrop(84, padding=8),
35 | transforms.RandomHorizontalFlip(),
36 | lambda x: np.asarray(x),
37 | transforms.ToTensor(),
38 | normalize
39 | ]),
40 |
41 | transforms.Compose([
42 | lambda x: Image.fromarray(x),
43 | transforms.ToTensor(),
44 | normalize
45 | ])
46 | ]
47 |
48 |
49 |
50 |
51 |
52 |
53 | # CIFAR style transformation
54 | mean = [0.5071, 0.4867, 0.4408]
55 | std = [0.2675, 0.2565, 0.2761]
56 | normalize_cifar100 = transforms.Normalize(mean=mean, std=std)
57 | transform_D = [
58 | transforms.Compose([
59 | lambda x: Image.fromarray(x),
60 | transforms.RandomCrop(32, padding=4),
61 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
62 | # transforms.RandomRotation(10),
63 | transforms.RandomHorizontalFlip(),
64 | lambda x: np.asarray(x),
65 | transforms.ToTensor(),
66 | normalize_cifar100
67 | ]),
68 |
69 | transforms.Compose([
70 | lambda x: Image.fromarray(x),
71 | transforms.ToTensor(),
72 | normalize_cifar100
73 | ])
74 | ]
75 |
76 | transform_D_test = [
77 | transforms.Compose([
78 | lambda x: Image.fromarray(x),
79 | transforms.RandomCrop(32, padding=4),
80 | transforms.RandomHorizontalFlip(),
81 | lambda x: np.asarray(x),
82 | transforms.ToTensor(),
83 | normalize_cifar100
84 | ]),
85 |
86 | transforms.Compose([
87 | lambda x: Image.fromarray(x),
88 | transforms.ToTensor(),
89 | normalize_cifar100
90 | ])
91 | ]
92 |
93 |
94 | transforms_list = ['A', 'D']
95 |
96 |
97 | transforms_options = {
98 | 'A': transform_A,
99 | 'D': transform_D,
100 | }
101 |
102 | transforms_test_options = {
103 | 'A': transform_A_test,
104 | 'D': transform_D_test,
105 | }
106 |
--------------------------------------------------------------------------------
/distill/NCEAverage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch import nn
4 | from .alias_multinomial import AliasMethod
5 | import math
6 |
7 |
8 | class NCESoftmax(nn.Module):
9 |
10 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5):
11 | super(NCESoftmax, self).__init__()
12 | self.nLem = outputSize
13 | self.unigrams = torch.ones(self.nLem)
14 | self.multinomial = AliasMethod(self.unigrams)
15 | self.multinomial.cuda()
16 | self.K = K
17 |
18 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum]))
19 | stdv = 1. / math.sqrt(inputSize / 3)
20 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
21 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
22 |
23 | def forward(self, l, ab, y, idx=None):
24 | K = int(self.params[0].item())
25 | T = self.params[1].item()
26 | Z_l = self.params[2].item()
27 | Z_ab = self.params[3].item()
28 |
29 | momentum = self.params[4].item()
30 | batchSize = l.size(0)
31 | outputSize = self.memory_l.size(0)
32 | inputSize = self.memory_l.size(1)
33 |
34 | # original score computation
35 | if idx is None:
36 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
37 | idx.select(1, 0).copy_(y.data)
38 | # sample
39 | weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach()
40 | weight_l = weight_l.view(batchSize, K + 1, inputSize)
41 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))
42 |         # no exp here: NCESoftmax feeds raw logits to a softmax-based criterion
43 | out_ab = torch.div(out_ab, T)
44 | # sample
45 | weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach()
46 | weight_ab = weight_ab.view(batchSize, K + 1, inputSize)
47 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1))
48 |         # raw logits again, no exp
49 | out_l = torch.div(out_l, T)
50 |
51 |         # set Z if it has not been set yet
52 | if Z_l < 0:
53 |             # Z is fixed to 1: the softmax criterion normalizes the logits itself
54 | self.params[2] = 1
55 | Z_l = self.params[2].clone().detach().item()
56 | print("normalization constant Z_l is set to {:.1f}".format(Z_l))
57 | if Z_ab < 0:
58 |             # Z is fixed to 1 here as well
59 | self.params[3] = 1
60 | Z_ab = self.params[3].clone().detach().item()
61 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab))
62 |
63 |         # compute out_l, out_ab
64 |         # Z == 1, so no division by a normalization constant is needed
65 |         # (contrast with NCEAverage below, which divides by Z)
66 | out_l = out_l.contiguous()
67 | out_ab = out_ab.contiguous()
68 |
69 | # update memory
70 | with torch.no_grad():
71 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
72 | l_pos.mul_(momentum)
73 | l_pos.add_(torch.mul(l, 1 - momentum))
74 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
75 | updated_l = l_pos.div(l_norm)
76 | self.memory_l.index_copy_(0, y, updated_l)
77 |
78 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1))
79 | ab_pos.mul_(momentum)
80 | ab_pos.add_(torch.mul(ab, 1 - momentum))
81 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
82 | updated_ab = ab_pos.div(ab_norm)
83 | self.memory_ab.index_copy_(0, y, updated_ab)
84 |
85 | return out_l, out_ab
86 |
87 |
88 | class NCEAverage(nn.Module):
89 |
90 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5):
91 | super(NCEAverage, self).__init__()
92 | self.nLem = outputSize
93 | self.unigrams = torch.ones(self.nLem)
94 | self.multinomial = AliasMethod(self.unigrams)
95 | self.multinomial.cuda()
96 | self.K = K
97 |
98 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum]))
99 | stdv = 1. / math.sqrt(inputSize / 3)
100 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
101 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
102 |
103 | def forward(self, l, ab, y, idx=None):
104 | K = int(self.params[0].item())
105 | T = self.params[1].item()
106 | Z_l = self.params[2].item()
107 | Z_ab = self.params[3].item()
108 |
109 | momentum = self.params[4].item()
110 | batchSize = l.size(0)
111 | outputSize = self.memory_l.size(0)
112 | inputSize = self.memory_l.size(1)
113 |
114 | # original score computation
115 | if idx is None:
116 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
117 | idx.select(1, 0).copy_(y.data)
118 | # sample
119 | weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach()
120 | weight_l = weight_l.view(batchSize, K + 1, inputSize)
121 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))
122 | out_ab = torch.exp(torch.div(out_ab, T))
123 | # sample
124 | weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach()
125 | weight_ab = weight_ab.view(batchSize, K + 1, inputSize)
126 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1))
127 | out_l = torch.exp(torch.div(out_l, T))
128 |
129 |         # set Z if it has not been set yet
130 | if Z_l < 0:
131 | self.params[2] = out_l.mean() * outputSize
132 | Z_l = self.params[2].clone().detach().item()
133 | print("normalization constant Z_l is set to {:.1f}".format(Z_l))
134 | if Z_ab < 0:
135 | self.params[3] = out_ab.mean() * outputSize
136 | Z_ab = self.params[3].clone().detach().item()
137 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab))
138 |
139 | # compute out_l, out_ab
140 | out_l = torch.div(out_l, Z_l).contiguous()
141 | out_ab = torch.div(out_ab, Z_ab).contiguous()
142 |
143 | # update memory
144 | with torch.no_grad():
145 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
146 | l_pos.mul_(momentum)
147 | l_pos.add_(torch.mul(l, 1 - momentum))
148 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
149 | updated_l = l_pos.div(l_norm)
150 | self.memory_l.index_copy_(0, y, updated_l)
151 |
152 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1))
153 | ab_pos.mul_(momentum)
154 | ab_pos.add_(torch.mul(ab, 1 - momentum))
155 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
156 | updated_ab = ab_pos.div(ab_norm)
157 | self.memory_ab.index_copy_(0, y, updated_ab)
158 |
159 | return out_l, out_ab
160 |
161 |
162 | class NCEAverageWithZ(nn.Module):
163 |
164 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5, z=None):
165 | super(NCEAverageWithZ, self).__init__()
166 | self.nLem = outputSize
167 | self.unigrams = torch.ones(self.nLem)
168 | self.multinomial = AliasMethod(self.unigrams)
169 | self.multinomial.cuda()
170 | self.K = K
171 |
172 |         # a non-positive z means "estimate Z on the first forward pass"
173 |         # (see the Z_l / Z_ab < 0 checks in forward)
174 |         if z is None or z <= 0:
175 |             z = -1
176 |         self.register_buffer('params', torch.tensor([K, T, z, z, momentum]))
177 | stdv = 1. / math.sqrt(inputSize / 3)
178 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
179 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
180 |
181 | def forward(self, l, ab, y, idx=None):
182 | K = int(self.params[0].item())
183 | T = self.params[1].item()
184 | Z_l = self.params[2].item()
185 | Z_ab = self.params[3].item()
186 |
187 | momentum = self.params[4].item()
188 | batchSize = l.size(0)
189 | outputSize = self.memory_l.size(0)
190 | inputSize = self.memory_l.size(1)
191 |
192 | # original score computation
193 | if idx is None:
194 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
195 | idx.select(1, 0).copy_(y.data)
196 | # sample
197 | weight_l = torch.index_select(self.memory_l, 0, idx.view(-1)).detach()
198 | weight_l = weight_l.view(batchSize, K + 1, inputSize)
199 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))
200 | out_ab = torch.exp(torch.div(out_ab, T))
201 | # sample
202 | weight_ab = torch.index_select(self.memory_ab, 0, idx.view(-1)).detach()
203 | weight_ab = weight_ab.view(batchSize, K + 1, inputSize)
204 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1))
205 | out_l = torch.exp(torch.div(out_l, T))
206 |
207 |         # set Z if it has not been set yet
208 | if Z_l < 0:
209 | self.params[2] = out_l.mean() * outputSize
210 | Z_l = self.params[2].clone().detach().item()
211 | print("normalization constant Z_l is set to {:.1f}".format(Z_l))
212 | if Z_ab < 0:
213 | self.params[3] = out_ab.mean() * outputSize
214 | Z_ab = self.params[3].clone().detach().item()
215 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab))
216 |
217 | # compute out_l, out_ab
218 | out_l = torch.div(out_l, Z_l).contiguous()
219 | out_ab = torch.div(out_ab, Z_ab).contiguous()
220 |
221 | # update memory
222 | with torch.no_grad():
223 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
224 | l_pos.mul_(momentum)
225 | l_pos.add_(torch.mul(l, 1 - momentum))
226 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
227 | updated_l = l_pos.div(l_norm)
228 | self.memory_l.index_copy_(0, y, updated_l)
229 |
230 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1))
231 | ab_pos.mul_(momentum)
232 | ab_pos.add_(torch.mul(ab, 1 - momentum))
233 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
234 | updated_ab = ab_pos.div(ab_norm)
235 | self.memory_ab.index_copy_(0, y, updated_ab)
236 |
237 | return out_l, out_ab
238 |
239 |
240 | class NCEAverageFull(nn.Module):
241 |
242 | def __init__(self, inputSize, outputSize, T=0.07, momentum=0.5):
243 | super(NCEAverageFull, self).__init__()
244 | self.nLem = outputSize
245 |
246 | self.register_buffer('params', torch.tensor([T, -1, -1, momentum]))
247 | stdv = 1. / math.sqrt(inputSize / 3)
248 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
249 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
250 |
251 | def forward(self, l, ab, y):
252 | T = self.params[0].item()
253 | Z_l = self.params[1].item()
254 | Z_ab = self.params[2].item()
255 |
256 | momentum = self.params[3].item()
257 | batchSize = l.size(0)
258 | outputSize = self.memory_l.size(0)
259 | inputSize = self.memory_l.size(1)
260 |
261 |         # score computation: swap rows y and 0 so index 0 of the scores is the positive pair
262 | idx1 = y.unsqueeze(1).expand(-1, inputSize).unsqueeze(1).expand(-1, 1, -1)
263 | idx2 = torch.zeros(batchSize).long().cuda()
264 | idx2 = idx2.unsqueeze(1).expand(-1, inputSize).unsqueeze(1).expand(-1, 1, -1)
265 | # sample
266 | weight_l = self.memory_l.clone().detach().unsqueeze(0).expand(batchSize, outputSize, inputSize)
267 | weight_l_1 = weight_l.gather(dim=1, index=idx1)
268 | weight_l_2 = weight_l.gather(dim=1, index=idx2)
269 | weight_l.scatter_(1, idx1, weight_l_2)
270 | weight_l.scatter_(1, idx2, weight_l_1)
271 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))
272 | out_ab = torch.exp(torch.div(out_ab, T))
273 | # sample
274 | weight_ab = self.memory_ab.clone().detach().unsqueeze(0).expand(batchSize, outputSize, inputSize)
275 | weight_ab_1 = weight_ab.gather(dim=1, index=idx1)
276 | weight_ab_2 = weight_ab.gather(dim=1, index=idx2)
277 | weight_ab.scatter_(1, idx1, weight_ab_2)
278 | weight_ab.scatter_(1, idx2, weight_ab_1)
279 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1))
280 | out_l = torch.exp(torch.div(out_l, T))
281 |
282 |         # set Z if it has not been set yet
283 | if Z_l < 0:
284 | self.params[1] = out_l.mean() * outputSize
285 | Z_l = self.params[1].clone().detach().item()
286 | print("normalization constant Z_l is set to {:.1f}".format(Z_l))
287 | if Z_ab < 0:
288 | self.params[2] = out_ab.mean() * outputSize
289 | Z_ab = self.params[2].clone().detach().item()
290 | print("normalization constant Z_ab is set to {:.1f}".format(Z_ab))
291 |
292 | # compute out_l, out_ab
293 | out_l = torch.div(out_l, Z_l).contiguous()
294 | out_ab = torch.div(out_ab, Z_ab).contiguous()
295 |
296 | # update memory
297 | with torch.no_grad():
298 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
299 | l_pos.mul_(momentum)
300 | l_pos.add_(torch.mul(l, 1 - momentum))
301 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
302 | updated_l = l_pos.div(l_norm)
303 | self.memory_l.index_copy_(0, y, updated_l)
304 |
305 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1))
306 | ab_pos.mul_(momentum)
307 | ab_pos.add_(torch.mul(ab, 1 - momentum))
308 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
309 | updated_ab = ab_pos.div(ab_norm)
310 | self.memory_ab.index_copy_(0, y, updated_ab)
311 |
312 | return out_l, out_ab
313 |
314 |
315 | class NCEAverageFullSoftmax(nn.Module):
316 |
317 | def __init__(self, inputSize, outputSize, T=1, momentum=0.5):
318 | super(NCEAverageFullSoftmax, self).__init__()
319 | self.nLem = outputSize
320 |
321 | self.register_buffer('params', torch.tensor([T, momentum]))
322 | stdv = 1. / math.sqrt(inputSize / 3)
323 | self.register_buffer('memory_l', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
324 | self.register_buffer('memory_ab', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
325 |
326 | def forward(self, l, ab, y):
327 | T = self.params[0].item()
328 | momentum = self.params[1].item()
329 | batchSize = l.size(0)
330 | outputSize = self.memory_l.size(0)
331 | inputSize = self.memory_l.size(1)
332 |
333 | # score computation
334 | # weight_l = self.memory_l.unsqueeze(0).expand(batchSize, outputSize, inputSize).detach()
335 | weight_l = self.memory_l.clone().unsqueeze(0).expand(batchSize, outputSize, inputSize).detach()
336 | out_ab = torch.bmm(weight_l, ab.view(batchSize, inputSize, 1))
337 | out_ab = out_ab.div(T)
338 |         out_ab = out_ab.squeeze(-1).contiguous()  # squeeze only the trailing dim (plain squeeze() would also drop a batch dim of 1)
339 |
340 | # weight_ab = self.memory_ab.unsqueeze(0).expand(batchSize, outputSize, inputSize).detach()
341 | weight_ab = self.memory_ab.clone().unsqueeze(0).expand(batchSize, outputSize, inputSize).detach()
342 | out_l = torch.bmm(weight_ab, l.view(batchSize, inputSize, 1))
343 | out_l = out_l.div(T)
344 |         out_l = out_l.squeeze(-1).contiguous()  # squeeze only the trailing dim (plain squeeze() would also drop a batch dim of 1)
345 |
346 | # update memory
347 | with torch.no_grad():
348 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
349 | l_pos.mul_(momentum)
350 | l_pos.add_(torch.mul(l, 1 - momentum))
351 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
352 | updated_l = l_pos.div(l_norm)
353 | self.memory_l.index_copy_(0, y, updated_l)
354 |
355 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1))
356 | ab_pos.mul_(momentum)
357 | ab_pos.add_(torch.mul(ab, 1 - momentum))
358 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
359 | updated_ab = ab_pos.div(ab_norm)
360 | self.memory_ab.index_copy_(0, y, updated_ab)
361 |
362 | return out_l, out_ab
363 |
364 | def update_memory(self, l, ab, y):
365 | momentum = self.params[1].item()
366 | # update memory
367 | with torch.no_grad():
368 | l_pos = torch.index_select(self.memory_l, 0, y.view(-1))
369 | l_pos.mul_(momentum)
370 | l_pos.add_(torch.mul(l, 1 - momentum))
371 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
372 | updated_l = l_pos.div(l_norm)
373 | self.memory_l.index_copy_(0, y, updated_l)
374 |
375 | ab_pos = torch.index_select(self.memory_ab, 0, y.view(-1))
376 | ab_pos.mul_(momentum)
377 | ab_pos.add_(torch.mul(ab, 1 - momentum))
378 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
379 | updated_ab = ab_pos.div(ab_norm)
380 | self.memory_ab.index_copy_(0, y, updated_ab)
381 |
--------------------------------------------------------------------------------
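Every class in `NCEAverage.py` ends with the same memory-bank update: each touched row is moved toward the new embedding by an exponential moving average, then projected back onto the unit sphere. The arithmetic in isolation, as a NumPy sketch (the function name is an illustration, not repo API):

```python
import numpy as np

def update_memory(memory, feats, ids, momentum=0.5):
    """EMA-update rows `ids` of `memory` with `feats`, then L2-renormalize them."""
    new = momentum * memory[ids] + (1.0 - momentum) * feats   # moving average
    new /= np.linalg.norm(new, axis=1, keepdims=True)         # back onto unit sphere
    memory[ids] = new
    return memory

memory = np.eye(4, 3)                       # toy bank: 4 slots, 3-dim embeddings
feats = np.array([[0.0, 1.0, 0.0]])
update_memory(memory, feats, np.array([0]))
print(memory[0])                            # halfway between e1 and e2, unit norm
```

The renormalization matters: without it, repeated averaging shrinks the stored vectors and the dot-product scores drift downward over training.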
/distill/NCECriterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | eps = 1e-7
5 |
6 |
7 | class NCECriterion(nn.Module):
8 |
9 | def __init__(self, nLem):
10 | super(NCECriterion, self).__init__()
11 | self.nLem = nLem
12 |
13 | def forward(self, x):
14 | batchSize = x.size(0)
15 | K = x.size(1) - 1
16 | Pnt = 1 / float(self.nLem)
17 | Pns = 1 / float(self.nLem)
18 |
19 | # eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt)
20 | Pmt = x.select(1, 0)
21 | Pmt_div = Pmt.add(K * Pnt + eps)
22 | lnPmt = torch.div(Pmt, Pmt_div)
23 |
24 | # eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns)
25 | Pon_div = x.narrow(1, 1, K).add(K * Pns + eps)
26 | Pon = Pon_div.clone().fill_(K * Pns)
27 | lnPon = torch.div(Pon, Pon_div)
28 |
29 | # equation 6 in ref. A
30 | lnPmt.log_()
31 | lnPon.log_()
32 |
33 | lnPmtsum = lnPmt.sum(0)
34 | lnPonsum = lnPon.view(-1, 1).sum(0)
35 |
36 | loss = - (lnPmtsum + lnPonsum) / batchSize
37 |
38 | return loss
--------------------------------------------------------------------------------
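`NCECriterion` implements the two NCE posteriors: eq 5.1 for the positive in column 0 and eq 5.2 for the `K` noise columns, with a uniform noise distribution `Pn = 1/nLem`. The same arithmetic in NumPy, as a sanity check on a toy score matrix (hypothetical helper, assuming the scores are already normalized the way `NCEAverage` produces them):

```python
import numpy as np

EPS = 1e-7   # same epsilon as NCECriterion.py

def nce_loss(x, n_data):
    """x: (batch, K+1) normalized scores; column 0 is the positive pair."""
    K = x.shape[1] - 1
    pn = 1.0 / n_data                              # uniform noise distribution
    # eq 5.1: P(origin = model) for the positive
    log_pos = np.log(x[:, 0] / (x[:, 0] + K * pn + EPS))
    # eq 5.2: P(origin = noise) for each negative
    log_neg = np.log((K * pn) / (x[:, 1:] + K * pn + EPS))
    return -(log_pos.sum() + log_neg.sum()) / x.shape[0]

good = np.array([[0.99, 0.001, 0.001]])            # confident positive, weak noise
bad = np.array([[0.50, 0.300, 0.300]])
print(nce_loss(good, 100) < nce_loss(bad, 100))    # → True
```

Driving the positive score up and the noise scores down lowers both terms, which is exactly the gradient signal the memory-bank scores receive.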
/distill/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/distill/__init__.py
--------------------------------------------------------------------------------
/distill/alias_multinomial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AliasMethod(object):
5 | '''
6 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
7 | '''
8 | def __init__(self, probs):
9 |
10 |         # normalize in place so the probabilities sum to one
11 |         probs.div_(probs.sum())
12 | K = len(probs)
13 | self.prob = torch.zeros(K)
14 | self.alias = torch.LongTensor([0]*K)
15 |
16 | # Sort the data into the outcomes with probabilities
17 | # that are larger and smaller than 1/K.
18 | smaller = []
19 | larger = []
20 | for kk, prob in enumerate(probs):
21 | self.prob[kk] = K*prob
22 | if self.prob[kk] < 1.0:
23 | smaller.append(kk)
24 | else:
25 | larger.append(kk)
26 |
27 |         # Loop through and create little binary mixtures that
28 | # appropriately allocate the larger outcomes over the
29 | # overall uniform mixture.
30 | while len(smaller) > 0 and len(larger) > 0:
31 | small = smaller.pop()
32 | large = larger.pop()
33 |
34 | self.alias[small] = large
35 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
36 |
37 | if self.prob[large] < 1.0:
38 | smaller.append(large)
39 | else:
40 | larger.append(large)
41 |
42 | for last_one in smaller+larger:
43 | self.prob[last_one] = 1
44 |
45 | def cuda(self):
46 | self.prob = self.prob.cuda()
47 | self.alias = self.alias.cuda()
48 |
49 | def draw(self, N):
50 | '''
51 | Draw N samples from multinomial
52 | '''
53 | K = self.alias.size(0)
54 |
55 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
56 | prob = self.prob.index_select(0, kk)
57 | alias = self.alias.index_select(0, kk)
58 | # b is whether a random number is greater than q
59 | b = torch.bernoulli(prob)
60 | oq = kk.mul(b.long())
61 | oj = alias.mul((1-b).long())
62 |
63 | return oq + oj
64 |
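The table construction above can be exercised end to end. The compact sketch below mirrors `__init__` and `draw` (the function names `alias_setup`/`alias_draw` are illustrative, not from the repo) and checks that empirical sampling frequencies track the target distribution.

```python
import torch

def alias_setup(probs):
    """Mirror of AliasMethod.__init__: O(K) alias-table construction."""
    probs = probs / probs.sum()
    K = len(probs)
    prob = probs * K
    alias = torch.zeros(K, dtype=torch.long)
    smaller = [k for k in range(K) if prob[k] < 1.0]
    larger = [k for k in range(K) if prob[k] >= 1.0]
    while smaller and larger:
        small, large = smaller.pop(), larger.pop()
        alias[small] = large
        prob[large] = (prob[large] - 1.0) + prob[small]
        (smaller if prob[large] < 1.0 else larger).append(large)
    for k in smaller + larger:   # numerical leftovers always accept
        prob[k] = 1.0
    return prob, alias

def alias_draw(prob, alias, n):
    """Mirror of AliasMethod.draw: O(1) work per sample."""
    kk = torch.randint(0, prob.size(0), (n,))
    b = torch.bernoulli(prob[kk]).long()
    return kk * b + alias[kk] * (1 - b)

prob, alias = alias_setup(torch.tensor([0.1, 0.2, 0.3, 0.4]))
samples = alias_draw(prob, alias, 200000)
freq = torch.bincount(samples, minlength=4).float() / samples.numel()
# freq lands close to the target [0.1, 0.2, 0.3, 0.4]
```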
--------------------------------------------------------------------------------
/distill/criterion.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from scipy.stats import norm
9 |
10 | from .NCEAverage import NCEAverage
11 | from .NCEAverage import NCESoftmax
12 | from .NCECriterion import NCECriterion
13 |
14 |
15 | class DistillKL(nn.Module):
16 | """KL divergence for distillation"""
17 | def __init__(self, T):
18 | super(DistillKL, self).__init__()
19 | self.T = T
20 |
21 | def forward(self, y_s, y_t):
22 | p_s = F.log_softmax(y_s/self.T, dim=1)
23 | p_t = F.softmax(y_t/self.T, dim=1)
24 |         loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T**2) / y_s.shape[0]
25 | return loss
26 |
27 |
28 | class NCELoss(nn.Module):
29 | """NCE contrastive loss"""
30 | def __init__(self, opt, n_data):
31 | super(NCELoss, self).__init__()
32 | self.contrast = NCEAverage(opt.feat_dim, n_data, opt.nce_k, opt.nce_t, opt.nce_m)
33 | self.criterion_t = NCECriterion(n_data)
34 | self.criterion_s = NCECriterion(n_data)
35 |
36 | def forward(self, f_s, f_t, idx, contrast_idx=None):
37 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
38 | s_loss = self.criterion_s(out_s)
39 | t_loss = self.criterion_t(out_t)
40 | loss = s_loss + t_loss
41 | return loss
42 |
43 |
44 | class NCESoftmaxLoss(nn.Module):
45 | """info NCE style loss, softmax"""
46 | def __init__(self, opt, n_data):
47 | super(NCESoftmaxLoss, self).__init__()
48 | self.contrast = NCESoftmax(opt.feat_dim, n_data, opt.nce_k, opt.nce_t, opt.nce_m)
49 | self.criterion_t = nn.CrossEntropyLoss()
50 | self.criterion_s = nn.CrossEntropyLoss()
51 |
52 | def forward(self, f_s, f_t, idx, contrast_idx=None):
53 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
54 | bsz = f_s.shape[0]
55 |         label = torch.zeros(bsz).cuda().long()  # CrossEntropyLoss expects targets of shape (N,)
56 | s_loss = self.criterion_s(out_s, label)
57 | t_loss = self.criterion_t(out_t, label)
58 | loss = s_loss + t_loss
59 | return loss
60 |
61 |
62 | class Attention(nn.Module):
63 | """attention transfer loss"""
64 | def __init__(self, p=2):
65 | super(Attention, self).__init__()
66 | self.p = p
67 |
68 | def forward(self, g_s, g_t):
69 | return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
70 |
71 | def at_loss(self, f_s, f_t):
72 | s_H, t_H = f_s.shape[2], f_t.shape[2]
73 | if s_H > t_H:
74 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
75 | elif s_H < t_H:
76 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
77 | else:
78 | pass
79 | return (self.at(f_s) - self.at(f_t)).pow(2).mean()
80 |
81 | def at(self, f):
82 | return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
83 |
84 |
85 | class HintLoss(nn.Module):
86 | """regression loss from hints"""
87 | def __init__(self):
88 | super(HintLoss, self).__init__()
89 | self.crit = nn.MSELoss()
90 |
91 | def forward(self, f_s, f_t):
92 | loss = self.crit(f_s, f_t)
93 | return loss
94 |
95 |
96 | if __name__ == '__main__':
97 | pass
98 |
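`DistillKL` above reduces to a temperature-scaled KL divergence between softened distributions. A minimal sketch (using `reduction='sum'`, the modern spelling of the deprecated `size_average=False`): when student and teacher logits coincide, the loss vanishes.

```python
import torch
import torch.nn.functional as F

T = 4.0
y_s = torch.randn(8, 10)
y_t = y_s.clone()                      # identical teacher for the check
p_s = F.log_softmax(y_s / T, dim=1)
p_t = F.softmax(y_t / T, dim=1)
loss = F.kl_div(p_s, p_t, reduction='sum') * (T ** 2) / y_s.shape[0]
# KL(p || p) = 0, so the loss is ~0 up to float error
```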
--------------------------------------------------------------------------------
/distill/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class Embed(nn.Module):
7 | """Embedding module"""
8 | def __init__(self, dim_in=1024, dim_out=128):
9 | super(Embed, self).__init__()
10 | self.linear = nn.Linear(dim_in, dim_out)
11 | self.l2norm = Normalize(2)
12 |
13 | def forward(self, x):
14 | x = x.view(x.shape[0], -1)
15 | x = self.linear(x)
16 | x = self.l2norm(x)
17 | return x
18 |
19 |
20 | class LinearEmbed(nn.Module):
21 | """Linear Embedding"""
22 | def __init__(self, dim_in=1024, dim_out=128):
23 | super(LinearEmbed, self).__init__()
24 | self.linear = nn.Linear(dim_in, dim_out)
25 |
26 | def forward(self, x):
27 | x = x.view(x.shape[0], -1)
28 | x = self.linear(x)
29 | return x
30 |
31 |
32 | class MLPEmbed(nn.Module):
33 | """non-linear embed by MLP"""
34 | def __init__(self, dim_in=1024, dim_out=128):
35 | super(MLPEmbed, self).__init__()
36 | self.linear1 = nn.Linear(dim_in, 2 * dim_out)
37 | self.relu = nn.ReLU(inplace=True)
38 | self.linear2 = nn.Linear(2 * dim_out, dim_out)
39 | self.l2norm = Normalize(2)
40 |
41 | def forward(self, x):
42 | x = x.view(x.shape[0], -1)
43 | x = self.relu(self.linear1(x))
44 | x = self.l2norm(self.linear2(x))
45 | return x
46 |
47 |
48 | class Normalize(nn.Module):
49 | """normalization layer"""
50 | def __init__(self, power=2):
51 | super(Normalize, self).__init__()
52 | self.power = power
53 |
54 | def forward(self, x):
55 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
56 | out = x.div(norm)
57 | return out
58 |
59 |
60 | if __name__ == '__main__':
61 | pass
62 |
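`Embed` above is just a linear projection followed by `Normalize(2)`. The sketch below reproduces that pipeline inline and confirms every output row has unit L2 norm.

```python
import torch
import torch.nn as nn

linear = nn.Linear(1024, 128)          # Embed's projection
x = torch.randn(4, 1024)
z = linear(x.view(x.shape[0], -1))
z = z / z.pow(2).sum(1, keepdim=True).pow(1. / 2)   # Normalize(power=2)
norms = z.norm(dim=1)
# norms ≈ [1, 1, 1, 1]
```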
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/eval/__init__.py
--------------------------------------------------------------------------------
/eval/cls_eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import time
5 | from tqdm import tqdm
6 | from .util import AverageMeter, accuracy
7 | import numpy as np
8 |
9 | def validate(val_loader, model, criterion, opt):
10 | """One epoch validation"""
11 | batch_time = AverageMeter()
12 | losses = AverageMeter()
13 | top1 = AverageMeter()
14 | top5 = AverageMeter()
15 |
16 | # switch to evaluate mode
17 | model.eval()
18 |
19 | with torch.no_grad():
20 | with tqdm(val_loader, total=len(val_loader)) as pbar:
21 | end = time.time()
22 | for idx, (input, target, _) in enumerate(pbar):
23 |
24 | if(opt.simclr):
25 | input = input[0].float()
26 | else:
27 | input = input.float()
28 |
29 | if torch.cuda.is_available():
30 | input = input.cuda()
31 | target = target.cuda()
32 |
33 | # compute output
34 | output = model(input)
35 | loss = criterion(output, target)
36 |
37 | # measure accuracy and record loss
38 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
39 | losses.update(loss.item(), input.size(0))
40 | top1.update(acc1[0], input.size(0))
41 | top5.update(acc5[0], input.size(0))
42 |
43 | # measure elapsed time
44 | batch_time.update(time.time() - end)
45 | end = time.time()
46 |
47 |                 pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()),
48 |                                   "Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy()),
49 |                                   "Loss" :'{0:.2f}'.format(losses.avg),
50 |                                   })
51 | # if idx % opt.print_freq == 0:
52 | # print('Test: [{0}/{1}]\t'
53 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
54 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
55 | # 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
56 | # 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
57 | # idx, len(val_loader), batch_time=batch_time, loss=losses,
58 | # top1=top1, top5=top5))
59 |
60 | print('Val_Acc@1 {top1.avg:.3f} Val_Acc@5 {top5.avg:.3f}'
61 | .format(top1=top1, top5=top5))
62 |
63 | return top1.avg, top5.avg, losses.avg
64 |
65 |
66 |
67 |
68 | def embedding(val_loader, model, opt):
69 |     """Extract and save embeddings (with 4-way rotation) over the loader"""
70 | batch_time = AverageMeter()
71 | losses = AverageMeter()
72 | top1 = AverageMeter()
73 | top5 = AverageMeter()
74 |
75 | # switch to evaluate mode
76 | model.eval()
77 |
78 |
79 | with torch.no_grad():
80 | with tqdm(val_loader, total=len(val_loader)) as pbar:
81 | end = time.time()
82 | for idx, (input, target, _) in enumerate(pbar):
83 |
84 | if(opt.simclr):
85 | input = input[0].float()
86 | else:
87 | input = input.float()
88 |
89 | if torch.cuda.is_available():
90 | input = input.cuda()
91 | target = target.cuda()
92 |
93 | batch_size = input.size()[0]
94 | x = input
95 | x_90 = x.transpose(2,3).flip(2)
96 | x_180 = x.flip(2).flip(3)
97 | x_270 = x.flip(2).transpose(2,3)
98 | generated_data = torch.cat((x, x_90, x_180, x_270),0)
99 | train_targets = target.repeat(4)
100 |
101 | # compute output
102 | # output = model(input)
103 | (_,_,_,_, feat), (output, rot_logits) = model(generated_data, rot=True)
104 | # loss = criterion(output, target)
105 |
106 | # measure accuracy and record loss
107 | acc1, acc5 = accuracy(output[:batch_size], target, topk=(1, 5))
108 | # losses.update(loss.item(), input.size(0))
109 | top1.update(acc1[0], input.size(0))
110 | top5.update(acc5[0], input.size(0))
111 |
112 | if(idx==0):
113 | embeddings = output
114 | classes = train_targets
115 | else:
116 | embeddings = torch.cat((embeddings, output),0)
117 | classes = torch.cat((classes, train_targets),0)
118 |
119 |
120 | # measure elapsed time
121 | batch_time.update(time.time() - end)
122 | end = time.time()
123 |
124 |                 pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()),
125 |                                   "Acc@5":'{0:.2f}'.format(top5.avg.cpu().numpy())
126 |                                   })
127 | # if idx % opt.print_freq == 0:
128 | # print('Test: [{0}/{1}]\t'
129 | # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
130 | # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
131 | # 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
132 | # 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
133 | # idx, len(val_loader), batch_time=batch_time, loss=losses,
134 | # top1=top1, top5=top5))
135 |
136 | print('Val_Acc@1 {top1.avg:.3f} Val_Acc@5 {top5.avg:.3f}'
137 | .format(top1=top1, top5=top5))
138 | print(embeddings.size())
139 | print(classes.size())
140 |
141 | np.save("embeddings.npy", embeddings.detach().cpu().numpy())
142 | np.save("classes.npy", classes.detach().cpu().numpy())
143 |
144 |
145 |
146 | # with tqdm(val_loader, total=len(val_loader)) as pbar:
147 | # end = time.time()
148 | # for idx, (input, target, _) in enumerate(pbar):
149 |
150 | # if(opt.simclr):
151 | # input = input[0].float()
152 | # else:
153 | # input = input.float()
154 |
155 | # if torch.cuda.is_available():
156 | # input = input.cuda()
157 | # target = target.cuda()
158 |
159 | # generated_data = torch.cat((x, x_180),0)
160 | # # compute output
161 | # # output = model(input)
162 | # (_,_,_,_, feat), (output, rot_logits) = model(input, rot=True)
163 | # # loss = criterion(output, target)
164 |
165 | # # measure accuracy and record loss
166 | # acc1, acc5 = accuracy(output, target, topk=(1, 5))
167 | # # losses.update(loss.item(), input.size(0))
168 | # top1.update(acc1[0], input.size(0))
169 | # top5.update(acc5[0], input.size(0))
170 |
171 | # if(idx==0):
172 | # embeddings = output
173 | # classes = target
174 | # else:
175 | # embeddings = torch.cat((embeddings, output),0)
176 | # classes = torch.cat((classes, target),0)
177 |
178 |
179 | # # measure elapsed time
180 | # batch_time.update(time.time() - end)
181 | # end = time.time()
182 |
183 | # pbar.set_postfix({"Acc@1":'{0:.2f}'.format(top1.avg.cpu().numpy()),
184 | # "Acc@5":'{0:.2f}'.format(top1.avg.cpu().numpy(),2)
185 | # })
186 | # # if idx % opt.print_freq == 0:
187 | # # print('Test: [{0}/{1}]\t'
188 | # # 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
189 | # # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
190 | # # 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
191 | # # 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
192 | # # idx, len(val_loader), batch_time=batch_time, loss=losses,
193 | # # top1=top1, top5=top5))
194 |
195 | # print('Val_Acc@1 {top1.avg:.3f} Val_Acc@5 {top5.avg:.3f}'
196 | # .format(top1=top1, top5=top5))
197 | # print(embeddings.size())
198 | # print(classes.size())
199 |
200 | # np.save("embeddings.npy", embeddings.detach().cpu().numpy())
201 | # np.save("classes.npy", classes.detach().cpu().numpy())
202 | return top1.avg, top5.avg, losses.avg
203 |
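The rotation block inside `embedding` builds the four rotated copies with transpose/flip rather than `torch.rot90`. A standalone sketch of the same augmentation, checking output shape and label replication:

```python
import torch

x = torch.randn(2, 3, 8, 8)            # toy NCHW batch
target = torch.tensor([5, 7])

x_90 = x.transpose(2, 3).flip(2)       # 90-degree rotation
x_180 = x.flip(2).flip(3)              # 180-degree rotation
x_270 = x.flip(2).transpose(2, 3)      # 270-degree rotation
generated = torch.cat((x, x_90, x_180, x_270), 0)
targets = target.repeat(4)             # labels repeat with the batch dim
# generated: (8, 3, 8, 8); targets: (8,)
```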
--------------------------------------------------------------------------------
/eval/meta_eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | import scipy
5 | from scipy.stats import t
6 | from tqdm import tqdm
7 |
8 | import torch
9 | from sklearn import metrics
10 | from sklearn.svm import SVC
11 | from sklearn.linear_model import LogisticRegression
12 | from sklearn.neighbors import KNeighborsClassifier
13 | from sklearn.ensemble import RandomForestClassifier
14 |
15 |
16 | import torch.nn as nn
17 | import sys, os
18 | from collections import Counter
19 |
20 |
21 | sys.path.append(os.path.abspath('..'))
22 |
23 | from util import accuracy
24 |
25 |
26 | def mean_confidence_interval(data, confidence=0.95):
27 | a = 100.0 * np.array(data)
28 | n = len(a)
29 | m, se = np.mean(a), scipy.stats.sem(a)
30 |     h = se * t.ppf((1 + confidence) / 2., n - 1)  # t.ppf is the public API; t._ppf is private
31 | return m, h
32 |
33 |
34 | def normalize(x):
35 | norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
36 | out = x.div(norm)
37 | return out
38 |
39 |
40 | def meta_test(net, testloader, use_logit=False, is_norm=True, classifier='LR'):
41 | net = net.eval()
42 | acc = []
43 |
44 | with torch.no_grad():
45 | with tqdm(testloader, total=len(testloader)) as pbar:
46 | for idx, data in enumerate(pbar):
47 | support_xs, support_ys, query_xs, query_ys = data
48 |
49 | support_xs = support_xs.cuda()
50 | query_xs = query_xs.cuda()
51 |                 batch_size, _, channel, height, width = support_xs.size()
52 |                 support_xs = support_xs.view(-1, channel, height, width)
53 |                 query_xs = query_xs.view(-1, channel, height, width)
54 |
55 |
56 |
57 | # batch_size = support_xs.size()[0]
58 | # x = support_xs
59 | # x_90 = x.transpose(2,3).flip(2)
60 | # x_180 = x.flip(2).flip(3)
61 | # x_270 = x.flip(2).transpose(2,3)
62 | # generated_data = torch.cat((x, x_90, x_180, x_270),0)
63 | # support_ys = support_ys.repeat(1,4)
64 | # support_xs = generated_data
65 |
66 | # print(support_xs.size())
67 | # print(support_ys.size())
68 |
69 |
70 |
71 | if use_logit:
72 | support_features = net(support_xs).view(support_xs.size(0), -1)
73 | query_features = net(query_xs).view(query_xs.size(0), -1)
74 | else:
75 | feat_support, _ = net(support_xs, is_feat=True)
76 | support_features = feat_support[-1].view(support_xs.size(0), -1)
77 | feat_query, _ = net(query_xs, is_feat=True)
78 | query_features = feat_query[-1].view(query_xs.size(0), -1)
79 |
80 | # feat_support, _ = net(support_xs)
81 | # support_features = feat_support.view(support_xs.size(0), -1)
82 | # feat_query, _ = net(query_xs)
83 | # query_features = feat_query.view(query_xs.size(0), -1)
84 |
85 |
86 | if is_norm:
87 | support_features = normalize(support_features)
88 | query_features = normalize(query_features)
89 |
90 | support_features = support_features.detach().cpu().numpy()
91 | query_features = query_features.detach().cpu().numpy()
92 |
93 | support_ys = support_ys.view(-1).numpy()
94 | query_ys = query_ys.view(-1).numpy()
95 |
96 |
97 |
98 | if classifier == 'LR':
99 | clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000, penalty='l2',
100 | multi_class='multinomial')
101 | clf.fit(support_features, support_ys)
102 | query_ys_pred = clf.predict(query_features)
103 | elif classifier == 'NN':
104 | query_ys_pred = NN(support_features, support_ys, query_features)
105 | elif classifier == 'Cosine':
106 | query_ys_pred = Cosine(support_features, support_ys, query_features)
107 | else:
108 | raise NotImplementedError('classifier not supported: {}'.format(classifier))
109 |
110 |
111 | # bs = query_features.shape[0]//opt.n_aug_support_samples
112 | # a = np.reshape(query_ys_pred[:bs], (-1,1))
113 | # c = query_ys[:bs]
114 | # for i in range(1,opt.n_aug_support_samples):
115 | # a = np.hstack([a, np.reshape(query_ys_pred[i*bs:(i+1)*bs], (-1,1))])
116 |
117 | # d = []
118 | # for i in range(a.shape[0]):
119 | # b = Counter(a[i,:])
120 | # d.append(b.most_common(1)[0][0])
121 |
122 | # # (values,counts) = np.unique(a,axis=1, return_counts=True)
123 | # # print(counts)
124 | # # ind=np.argmax(counts)
125 | # # print values[ind] # pr
126 |
127 |
128 | # # # a = np.argmax
129 | # # print(a.shape)
130 | # # print(c.shape)
131 |
132 | acc.append(metrics.accuracy_score(query_ys, query_ys_pred))
133 |
134 | pbar.set_postfix({"FSL_Acc":'{0:.2f}'.format(metrics.accuracy_score(query_ys, query_ys_pred))})
135 |
136 | return mean_confidence_interval(acc)
137 |
138 |
139 |
140 |
141 | def meta_test_tune(net, testloader, use_logit=False, is_norm=True, classifier='LR', lamda=0.2):
142 | net = net.eval()
143 | acc = []
144 |
145 | with tqdm(testloader, total=len(testloader)) as pbar:
146 | for idx, data in enumerate(pbar):
147 | support_xs, support_ys, query_xs, query_ys, support_ts, query_ts = data
148 |
149 | support_xs = support_xs.cuda()
150 | support_ys = support_ys.cuda()
151 | query_ys = query_ys.cuda()
152 | query_xs = query_xs.cuda()
153 |             batch_size, _, channel, height, width = support_xs.size()
154 |             support_xs = support_xs.view(-1, channel, height, width)
155 |             support_ys = support_ys.view(-1, 1)
156 |             query_ys = query_ys.view(-1)
157 |             query_xs = query_xs.view(-1, channel, height, width)
158 |
159 | if use_logit:
160 | support_features = net(support_xs).view(support_xs.size(0), -1)
161 | query_features = net(query_xs).view(query_xs.size(0), -1)
162 | else:
163 | feat_support, _ = net(support_xs, is_feat=True)
164 | support_features = feat_support[-1].view(support_xs.size(0), -1)
165 | feat_query, _ = net(query_xs, is_feat=True)
166 | query_features = feat_query[-1].view(query_xs.size(0), -1)
167 |
168 | if is_norm:
169 | support_features = normalize(support_features)
170 | query_features = normalize(query_features)
171 |
172 | y_onehot = torch.FloatTensor(support_ys.size()[0], 5).cuda()
173 |
174 | # In your for loop
175 | y_onehot.zero_()
176 | y_onehot.scatter_(1, support_ys, 1)
177 |
178 |
179 | X = support_features
180 | XTX = torch.matmul(torch.t(X),X)
181 |
182 |             B = torch.matmul((XTX + lamda * torch.eye(X.size(1)).cuda()).inverse(), torch.matmul(torch.t(X), y_onehot.float()))  # use the actual feature dim instead of hard-coding 640
183 | # print(B.size())
184 | m = nn.Sigmoid()
185 | Y_pred = m(torch.matmul(query_features, B))
186 |
187 |
188 | # print(Y_pred, query_ys)
189 | # model = nn.Sequential(nn.Linear(64, 10),nn.LogSoftmax(dim=1))
190 | # optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
191 | # criterion = nn.CrossEntropyLoss()
192 |
193 | # model.cuda()
194 | # criterion.cuda()
195 | # model.train()
196 |
197 | # for i in range(5):
198 | # output = model(support_features)
199 | # loss = criterion(output, support_ys)
200 | # optimizer.zero_grad()
201 | # loss.backward(retain_graph=True) # auto-grad
202 | # optimizer.step() # update weights
203 |
204 | # model.eval()
205 | # query_ys_pred = model(query_features)
206 |
207 | acc1, acc5 = accuracy(Y_pred, query_ys, topk=(1, 1))
208 |
209 |
210 | # support_features = support_features.detach().cpu().numpy()
211 | # query_features = query_features.detach().cpu().numpy()
212 |
213 | # support_ys = support_ys.view(-1).numpy()
214 | # query_ys = query_ys.view(-1).numpy()
215 |
216 | # if classifier == 'LR':
217 | # clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000,
218 | # multi_class='multinomial')
219 | # clf.fit(support_features, support_ys)
220 | # query_ys_pred = clf.predict(query_features)
221 | # elif classifier == 'NN':
222 | # query_ys_pred = NN(support_features, support_ys, query_features)
223 | # elif classifier == 'Cosine':
224 | # query_ys_pred = Cosine(support_features, support_ys, query_features)
225 | # else:
226 | # raise NotImplementedError('classifier not supported: {}'.format(classifier))
227 |
228 | acc.append(acc1.item()/100.0)
229 |
230 | pbar.set_postfix({"FSL_Acc":'{0:.4f}'.format(np.mean(acc))})
231 |
232 |
233 | return mean_confidence_interval(acc)
234 |
235 |
236 |
237 | def meta_test_ensamble(net, testloader, use_logit=True, is_norm=True, classifier='LR'):
238 | for n in net:
239 | n = n.eval()
240 | acc = []
241 |
242 | with torch.no_grad():
243 | with tqdm(testloader, total=len(testloader)) as pbar:
244 | for idx, data in enumerate(pbar):
245 | support_xs, support_ys, query_xs, query_ys = data
246 |
247 | support_xs = support_xs.cuda()
248 | query_xs = query_xs.cuda()
249 |                 batch_size, _, channel, height, width = support_xs.size()
250 |                 support_xs = support_xs.view(-1, channel, height, width)
251 |                 query_xs = query_xs.view(-1, channel, height, width)
252 |
253 | if use_logit:
254 | support_features = net[0](support_xs).view(support_xs.size(0), -1)
255 | query_features = net[0](query_xs).view(query_xs.size(0), -1)
256 | for n in net[1:]:
257 | support_features += n(support_xs).view(support_xs.size(0), -1)
258 | query_features += n(query_xs).view(query_xs.size(0), -1)
259 |                 else:
260 |                     feat_support, _ = net[0](support_xs, is_feat=True)  # net is a list of models here
261 |                     support_features = feat_support[-1].view(support_xs.size(0), -1)
262 |                     feat_query, _ = net[0](query_xs, is_feat=True)
263 |                     query_features = feat_query[-1].view(query_xs.size(0), -1)
264 |
265 | if is_norm:
266 | support_features = normalize(support_features)
267 | query_features = normalize(query_features)
268 |
269 | support_features = support_features.detach().cpu().numpy()
270 | query_features = query_features.detach().cpu().numpy()
271 |
272 | support_ys = support_ys.view(-1).numpy()
273 | query_ys = query_ys.view(-1).numpy()
274 |
275 | if classifier == 'LR':
276 | clf = LogisticRegression(random_state=0, solver='lbfgs', max_iter=1000,
277 | multi_class='multinomial')
278 | clf.fit(support_features, support_ys)
279 | query_ys_pred = clf.predict(query_features)
280 | elif classifier == 'NN':
281 | query_ys_pred = NN(support_features, support_ys, query_features)
282 | elif classifier == 'Cosine':
283 | query_ys_pred = Cosine(support_features, support_ys, query_features)
284 | else:
285 | raise NotImplementedError('classifier not supported: {}'.format(classifier))
286 |
287 | acc.append(metrics.accuracy_score(query_ys, query_ys_pred))
288 |
289 | pbar.set_postfix({"FSL_Acc":'{0:.2f}'.format(metrics.accuracy_score(query_ys, query_ys_pred))})
290 |
291 | return mean_confidence_interval(acc)
292 |
293 |
294 | def NN(support, support_ys, query):
295 | """nearest classifier"""
296 | support = np.expand_dims(support.transpose(), 0)
297 | query = np.expand_dims(query, 2)
298 |
299 | diff = np.multiply(query - support, query - support)
300 | distance = diff.sum(1)
301 | min_idx = np.argmin(distance, axis=1)
302 | pred = [support_ys[idx] for idx in min_idx]
303 | return pred
304 |
305 |
306 | def Cosine(support, support_ys, query):
307 | """Cosine classifier"""
308 | support_norm = np.linalg.norm(support, axis=1, keepdims=True)
309 | support = support / support_norm
310 | query_norm = np.linalg.norm(query, axis=1, keepdims=True)
311 | query = query / query_norm
312 |
313 | cosine_distance = query @ support.transpose()
314 | max_idx = np.argmax(cosine_distance, axis=1)
315 | pred = [support_ys[idx] for idx in max_idx]
316 | return pred
317 |
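The `Cosine` classifier above amounts to row-normalized dot products plus an argmax. A toy run on a trivially separable support set (the data values are made up for the example): each query inherits the label of its most-aligned support vector.

```python
import numpy as np

support = np.array([[1.0, 0.0], [0.0, 1.0]])   # one prototype per class
support_ys = np.array([0, 1])
query = np.array([[0.9, 0.1], [0.2, 0.8]])

support_n = support / np.linalg.norm(support, axis=1, keepdims=True)
query_n = query / np.linalg.norm(query, axis=1, keepdims=True)
max_idx = np.argmax(query_n @ support_n.transpose(), axis=1)
pred = [support_ys[idx] for idx in max_idx]
# pred → [0, 1]
```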
--------------------------------------------------------------------------------
/eval/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 | def __init__(self):
8 | self.reset()
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
22 |
23 | def accuracy(output, target, topk=(1,)):
24 | """Computes the accuracy over the k top predictions for the specified values of k"""
25 | with torch.no_grad():
26 | maxk = max(topk)
27 | batch_size = target.size(0)
28 |
29 | _, pred = output.topk(maxk, 1, True, True)
30 | pred = pred.t()
31 | correct = pred.eq(target.view(1, -1).expand_as(pred))
32 |
33 | res = []
34 | for k in topk:
35 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
36 | res.append(correct_k.mul_(100.0 / batch_size))
37 | return res
38 |
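The `accuracy` helper above can be traced on a tiny batch. This sketch repeats its topk/eq arithmetic by hand (using `reshape` in place of `view`, which is equivalent here): top-1 counts exact argmax matches, top-2 also credits the second-best prediction.

```python
import torch

output = torch.tensor([[0.1, 0.9],    # predicts class 1
                       [0.8, 0.2],    # predicts class 0
                       [0.3, 0.7]])   # predicts class 1
target = torch.tensor([1, 0, 0])

maxk = 2
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()                                   # (maxk, batch)
correct = pred.eq(target.view(1, -1).expand_as(pred))
top1 = correct[:1].reshape(-1).float().sum() * 100.0 / target.size(0)
top2 = correct[:2].reshape(-1).float().sum() * 100.0 / target.size(0)
# top1 ≈ 66.67 (rows 1-2 correct), top2 = 100.0 (row 3 recovered at k=2)
```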
--------------------------------------------------------------------------------
/eval_fewshot.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import socket
5 | import time
6 | import os
7 | import mkl
8 |
9 |
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torch.utils.data import DataLoader
13 |
14 | from models import model_pool
15 | from models.util import create_model
16 |
17 | from dataset.mini_imagenet import MetaImageNet
18 | from dataset.tiered_imagenet import MetaTieredImageNet
19 | from dataset.cifar import MetaCIFAR100
20 | from dataset.transform_cfg import transforms_test_options, transforms_list
21 |
22 | from eval.meta_eval import meta_test, meta_test_tune
23 | from eval.cls_eval import validate, embedding
24 | from dataloader import get_dataloaders
25 |
26 | mkl.set_num_threads(2)
27 |
28 |
29 |
30 | def parse_option():
31 |
32 |     parser = argparse.ArgumentParser('argument for few-shot evaluation')
33 |
34 | # load pretrained model
35 | parser.add_argument('--model', type=str, default='resnet12', choices=model_pool)
36 | parser.add_argument('--model_path', type=str, default="", help='absolute path to .pth model')
37 | # parser.add_argument('--model_path', type=str, default="/raid/data/IncrementLearn/imagenet/neurips20/model/maml_miniimagenet_test_5shot_step_5_5ways_5shots/pretrain_maml_miniimagenet_test_5shot_step_5_5ways_5shots.pt", help='absolute path to .pth model')
38 |
39 | # dataset
40 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
41 | 'CIFAR-FS', 'FC100', "toy"])
42 | parser.add_argument('--transform', type=str, default='A', choices=transforms_list)
43 |
44 | # specify data_root
45 | parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')
46 | parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation')
47 |
48 | # meta setting
49 | parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
50 | help='Number of test runs')
51 | parser.add_argument('--n_ways', type=int, default=5, metavar='N',
52 | help='Number of classes for doing each classification run')
53 | parser.add_argument('--n_shots', type=int, default=1, metavar='N',
54 | help='Number of shots in test')
55 | parser.add_argument('--n_queries', type=int, default=15, metavar='N',
56 | help='Number of query in test')
57 | parser.add_argument('--n_aug_support_samples', default=5, type=int,
58 | help='The number of augmented samples for each meta test sample')
59 | parser.add_argument('--num_workers', type=int, default=3, metavar='N',
60 | help='Number of workers for dataloader')
61 | parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
62 |                         help='Size of test batch')
63 |
64 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
65 |
66 | opt = parser.parse_args()
67 |
68 | if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
69 | opt.transform = 'D'
70 |
71 | if 'trainval' in opt.model_path:
72 | opt.use_trainval = True
73 | else:
74 | opt.use_trainval = False
75 |
76 | # set the path according to the environment
77 | if not opt.data_root:
78 | opt.data_root = './data/{}'.format(opt.dataset)
79 | else:
80 | if(opt.dataset=="toy"):
81 | opt.data_root = '{}/{}'.format(opt.data_root, "CIFAR-FS")
82 | else:
83 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
84 | opt.data_aug = True
85 |
86 | return opt
87 |
88 |
89 | def main():
90 |
91 | opt = parse_option()
92 |
93 | opt.n_test_runs = 600
94 | train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt)
95 |
96 | # load model
97 | model = create_model(opt.model, n_cls, opt.dataset)
98 | ckpt = torch.load(opt.model_path)
99 | model.load_state_dict(ckpt["model"])
100 |
101 | if torch.cuda.is_available():
102 | model = model.cuda()
103 | cudnn.benchmark = True
104 |
105 | start = time.time()
106 |     test_acc, test_std = meta_test(model, meta_testloader, use_logit=True)
107 | test_time = time.time() - start
108 | print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))
109 |
110 |
111 | start = time.time()
112 | test_acc_feat, test_std_feat = meta_test(model, meta_testloader, use_logit=False)
113 | test_time = time.time() - start
114 | print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc_feat, test_std_feat, test_time))
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
119 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .convnet import convnet4
2 | from .resnet import resnet12
3 | from .resnet_ssl import resnet12_ssl
4 | from .resnet_sd import resnet12_sd
5 | from .resnet_selfdist import multi_resnet12_kd
6 | from .resnet import seresnet12
7 | from .wresnet import wrn_28_10
8 |
9 | from .resnet_new import resnet50
10 |
11 | model_pool = [
12 | 'convnet4',
13 | 'resnet12',
14 | 'resnet12_ssl',
15 | 'resnet12_kd',
16 | 'resnet12_sd',
17 | 'seresnet12',
18 | 'wrn_28_10',
19 | ]
20 |
21 | model_dict = {
22 | 'wrn_28_10': wrn_28_10,
23 | 'convnet4': convnet4,
24 | 'resnet12': resnet12,
25 | 'resnet12_ssl': resnet12_ssl,
26 | 'resnet12_kd': multi_resnet12_kd,
27 | 'resnet12_sd': resnet12_sd,
28 | 'seresnet12': seresnet12,
29 | 'resnet50': resnet50,
30 | }
31 |
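`model_dict` is consumed by name lookup (in `models/util.py`, not shown in this excerpt). The sketch below imitates that pattern with a stand-in constructor so it runs on its own; `toy_convnet4` is a hypothetical placeholder, not a model from the repo.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real constructors registered above.
def toy_convnet4(num_classes=-1):
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, num_classes))

model_dict = {'convnet4': toy_convnet4}   # stand-in for the dict above
model = model_dict['convnet4'](num_classes=64)   # look up by name, then call
out = model(torch.randn(2, 3, 8, 8))
# out: (2, 64)
```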
--------------------------------------------------------------------------------
/models/convnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ConvNet(nn.Module):
8 |
9 | def __init__(self, num_classes=-1):
10 | super(ConvNet, self).__init__()
11 | self.layer1 = nn.Sequential(
12 | nn.Conv2d(3, 64, kernel_size=3, padding=1),
13 | nn.BatchNorm2d(64),
14 | nn.ReLU(),
15 | nn.MaxPool2d(2))
16 | self.layer2 = nn.Sequential(
17 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
18 | nn.BatchNorm2d(64),
19 | nn.ReLU(),
20 | nn.MaxPool2d(2))
21 | self.layer3 = nn.Sequential(
22 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
23 | nn.BatchNorm2d(64),
24 | nn.ReLU(),
25 | nn.MaxPool2d(2))
26 | self.layer4 = nn.Sequential(
27 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
28 | # nn.BatchNorm2d(64, momentum=1, affine=True, track_running_stats=False),
29 | nn.BatchNorm2d(64),
30 | nn.ReLU())
31 |
32 | self.avgpool = nn.AdaptiveAvgPool2d(1)
33 |
34 | self.num_classes = num_classes
35 | if self.num_classes > 0:
36 | self.classifier = nn.Linear(64, self.num_classes)
37 |
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d):
40 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
41 | elif isinstance(m, nn.BatchNorm2d):
42 | nn.init.constant_(m.weight, 1)
43 | nn.init.constant_(m.bias, 0)
44 |
45 | def forward(self, x, is_feat=False):
46 | out = self.layer1(x)
47 | f0 = out
48 | out = self.layer2(out)
49 | f1 = out
50 | out = self.layer3(out)
51 | f2 = out
52 | out = self.layer4(out)
53 | f3 = out
54 | out = self.avgpool(out)
55 | out = out.view(out.size(0), -1)
56 | feat = out
57 |
58 | if self.num_classes > 0:
59 | out = self.classifier(out)
60 |
61 | if is_feat:
62 | return [f0, f1, f2, f3, feat], out
63 | else:
64 | return out
65 |
66 |
67 | def convnet4(**kwargs):
68 | """Four layer ConvNet
69 | """
70 | model = ConvNet(**kwargs)
71 | return model
72 |
73 |
74 | if __name__ == '__main__':
75 | model = convnet4(num_classes=64)
76 | data = torch.randn(2, 3, 84, 84)
77 | feat, logit = model(data, is_feat=True)
78 | print(feat[-1].shape)
79 | print(logit.shape)
80 |
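Note: for an 84x84 input, the 3x3/padding-1 convs keep the spatial size, so only the three `MaxPool2d(2)` stages shrink it (floor-halving) before the adaptive average pool collapses it to 1x1, giving a fixed 64-d embedding. The arithmetic, as a torch-free sanity check:

```python
# Spatial sizes through convnet4 for an 84x84 input: layers 1-3 each end
# in MaxPool2d(2), which floor-halves the height/width.
def convnet4_sizes(hw=84, n_pools=3):
    sizes = [hw]
    for _ in range(n_pools):
        hw = hw // 2
        sizes.append(hw)
    return sizes

print(convnet4_sizes())  # [84, 42, 21, 10]
```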
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.distributions import Bernoulli
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | class SELayer(nn.Module):
14 | def __init__(self, channel, reduction=16):
15 | super(SELayer, self).__init__()
16 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
17 | self.fc = nn.Sequential(
18 | nn.Linear(channel, channel // reduction),
19 | nn.ReLU(inplace=True),
20 | nn.Linear(channel // reduction, channel),
21 | nn.Sigmoid()
22 | )
23 |
24 | def forward(self, x):
25 | b, c, _, _ = x.size()
26 | y = self.avg_pool(x).view(b, c)
27 | y = self.fc(y).view(b, c, 1, 1)
28 | return x * y
29 |
30 |
31 | class DropBlock(nn.Module):
32 | def __init__(self, block_size):
33 | super(DropBlock, self).__init__()
34 |
35 | self.block_size = block_size
36 | #self.gamma = gamma
37 | #self.bernouli = Bernoulli(gamma)
38 |
39 | def forward(self, x, gamma):
40 | # shape: (bsize, channels, height, width)
41 |
42 | if self.training:
43 | batch_size, channels, height, width = x.shape
44 |
45 | bernoulli = Bernoulli(gamma)
46 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
47 | block_mask = self._compute_block_mask(mask)
48 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
49 | count_ones = block_mask.sum()
50 |
51 | return block_mask * x * (countM / count_ones)
52 | else:
53 | return x
54 |
55 | def _compute_block_mask(self, mask):
56 | left_padding = int((self.block_size-1) / 2)
57 | right_padding = int(self.block_size / 2)
58 |
59 | batch_size, channels, height, width = mask.shape
60 | #print ("mask", mask[0][0])
61 | non_zero_idxs = mask.nonzero()
62 | nr_blocks = non_zero_idxs.shape[0]
63 |
64 | offsets = torch.stack(
65 | [
66 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding,
67 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding
68 | ]
69 | ).t().cuda()
70 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
71 |
72 | if nr_blocks > 0:
73 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
74 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
75 | offsets = offsets.long()
76 |
77 | block_idxs = non_zero_idxs + offsets
78 | #block_idxs += left_padding
79 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
80 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
81 | else:
82 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
83 |
84 | block_mask = 1 - padded_mask#[:height, :width]
85 | return block_mask
86 |
87 |
88 | class BasicBlock(nn.Module):
89 | expansion = 1
90 |
91 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False,
92 | block_size=1, use_se=False):
93 | super(BasicBlock, self).__init__()
94 | self.conv1 = conv3x3(inplanes, planes)
95 | self.bn1 = nn.BatchNorm2d(planes)
96 | self.relu = nn.LeakyReLU(0.1)
97 | self.conv2 = conv3x3(planes, planes)
98 | self.bn2 = nn.BatchNorm2d(planes)
99 | self.conv3 = conv3x3(planes, planes)
100 | self.bn3 = nn.BatchNorm2d(planes)
101 | self.maxpool = nn.MaxPool2d(stride)
102 | self.downsample = downsample
103 | self.stride = stride
104 | self.drop_rate = drop_rate
105 | self.num_batches_tracked = 0
106 | self.drop_block = drop_block
107 | self.block_size = block_size
108 | self.DropBlock = DropBlock(block_size=self.block_size)
109 | self.use_se = use_se
110 | if self.use_se:
111 | self.se = SELayer(planes, 4)
112 |
113 | def forward(self, x):
114 | self.num_batches_tracked += 1
115 |
116 | residual = x
117 |
118 | out = self.conv1(x)
119 | out = self.bn1(out)
120 | out = self.relu(out)
121 |
122 | out = self.conv2(out)
123 | out = self.bn2(out)
124 | out = self.relu(out)
125 |
126 | out = self.conv3(out)
127 | out = self.bn3(out)
128 | if self.use_se:
129 | out = self.se(out)
130 |
131 | if self.downsample is not None:
132 | residual = self.downsample(x)
133 | out += residual
134 | out = self.relu(out)
135 | out = self.maxpool(out)
136 |
137 | if self.drop_rate > 0:
138 | if self.drop_block == True:
139 | feat_size = out.size()[2]
140 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
141 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
142 | out = self.DropBlock(out, gamma=gamma)
143 | else:
144 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
145 |
146 | return out
147 |
148 |
149 | class ResNet(nn.Module):
150 |
151 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0,
152 | dropblock_size=5, num_classes=-1, use_se=False):
153 | super(ResNet, self).__init__()
154 |
155 | self.inplanes = 3
156 | self.use_se = use_se
157 | self.layer1 = self._make_layer(block, n_blocks[0], 64,
158 | stride=2, drop_rate=drop_rate)
159 | self.layer2 = self._make_layer(block, n_blocks[1], 160,
160 | stride=2, drop_rate=drop_rate)
161 | self.layer3 = self._make_layer(block, n_blocks[2], 320,
162 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
163 | self.layer4 = self._make_layer(block, n_blocks[3], 640,
164 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
165 | if avg_pool:
166 | # self.avgpool = nn.AvgPool2d(5, stride=1)
167 | self.avgpool = nn.AdaptiveAvgPool2d(1)
168 | self.keep_prob = keep_prob
169 | self.keep_avg_pool = avg_pool
170 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
171 | self.drop_rate = drop_rate
172 |
173 | for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
176 | elif isinstance(m, nn.BatchNorm2d):
177 | nn.init.constant_(m.weight, 1)
178 | nn.init.constant_(m.bias, 0)
179 |
180 | self.num_classes = num_classes
181 | if self.num_classes > 0:
182 | self.classifier = nn.Linear(640, self.num_classes)
183 |
184 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
185 | downsample = None
186 | if stride != 1 or self.inplanes != planes * block.expansion:
187 | downsample = nn.Sequential(
188 | nn.Conv2d(self.inplanes, planes * block.expansion,
189 | kernel_size=1, stride=1, bias=False),
190 | nn.BatchNorm2d(planes * block.expansion),
191 | )
192 |
193 | layers = []
194 | if n_block == 1:
195 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se)
196 | else:
197 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se)
198 | layers.append(layer)
199 | self.inplanes = planes * block.expansion
200 |
201 | for i in range(1, n_block):
202 | if i == n_block - 1:
203 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block,
204 | block_size=block_size, use_se=self.use_se)
205 | else:
206 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se)
207 | layers.append(layer)
208 |
209 | return nn.Sequential(*layers)
210 |
211 | def forward(self, x, is_feat=False):
212 | x = self.layer1(x)
213 | f0 = x
214 | x = self.layer2(x)
215 | f1 = x
216 | x = self.layer3(x)
217 | f2 = x
218 | x = self.layer4(x)
219 | f3 = x
220 | if self.keep_avg_pool:
221 | x = self.avgpool(x)
222 | x = x.view(x.size(0), -1)
223 | feat = x
224 |
225 | if self.num_classes > 0:
226 | x = self.classifier(x)
227 |
228 | if is_feat:
229 | return [f0, f1, f2, f3, feat], x
230 | else:
231 | return x
232 |
233 |
234 | def resnet12(keep_prob=1.0, avg_pool=False, **kwargs):
235 | """Constructs a ResNet-12 model.
236 | """
237 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
238 | return model
239 |
240 |
241 | def resnet18(keep_prob=1.0, avg_pool=False, **kwargs):
242 | """Constructs a ResNet-18 model.
243 | """
244 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
245 | return model
246 |
247 |
248 | def resnet24(keep_prob=1.0, avg_pool=False, **kwargs):
249 | """Constructs a ResNet-24 model.
250 | """
251 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
252 | return model
253 |
254 |
255 | def resnet50(keep_prob=1.0, avg_pool=False, **kwargs):
256 | """Constructs a ResNet-50 model.
257 | indeed, only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers
258 | """
259 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
260 | return model
261 |
262 |
263 | def resnet101(keep_prob=1.0, avg_pool=False, **kwargs):
264 | """Constructs a ResNet-101 model.
265 | indeed, only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers
266 | """
267 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
268 | return model
269 |
270 |
271 | def seresnet12(keep_prob=1.0, avg_pool=False, **kwargs):
272 | """Constructs a ResNet-12 model.
273 | """
274 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
275 | return model
276 |
277 |
278 | def seresnet18(keep_prob=1.0, avg_pool=False, **kwargs):
279 | """Constructs a ResNet-18 model.
280 | """
281 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
282 | return model
283 |
284 |
285 | def seresnet24(keep_prob=1.0, avg_pool=False, **kwargs):
286 | """Constructs a ResNet-24 model.
287 | """
288 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
289 | return model
290 |
291 |
292 | def seresnet50(keep_prob=1.0, avg_pool=False, **kwargs):
293 | """Constructs a ResNet-50 model.
294 | indeed, only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers
295 | """
296 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
297 | return model
298 |
299 |
300 | def seresnet101(keep_prob=1.0, avg_pool=False, **kwargs):
301 | """Constructs a ResNet-101 model.
302 | indeed, only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers
303 | """
304 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
305 | return model
306 |
307 |
308 | if __name__ == '__main__':
309 |
310 | import argparse
311 |
312 | parser = argparse.ArgumentParser('argument for training')
313 | parser.add_argument('--model', type=str, choices=['resnet12', 'resnet18', 'resnet24', 'resnet50', 'resnet101',
314 | 'seresnet12', 'seresnet18', 'seresnet24', 'seresnet50',
315 | 'seresnet101'])
316 | args = parser.parse_args()
317 |
318 | model_dict = {
319 | 'resnet12': resnet12,
320 | 'resnet18': resnet18,
321 | 'resnet24': resnet24,
322 | 'resnet50': resnet50,
323 | 'resnet101': resnet101,
324 | 'seresnet12': seresnet12,
325 | 'seresnet18': seresnet18,
326 | 'seresnet24': seresnet24,
327 | 'seresnet50': seresnet50,
328 | 'seresnet101': seresnet101,
329 | }
330 |
331 | model = model_dict[args.model](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=64)
332 | data = torch.randn(2, 3, 84, 84)
333 | model = model.cuda()
334 | data = data.cuda()
335 | feat, logit = model(data, is_feat=True)
336 | print(feat[-1].shape)
337 | print(logit.shape)
338 |
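Note: the `keep_rate`/`gamma` computation inside `BasicBlock.forward` is easy to misread. Extracted as plain Python (the `20*2000` annealing horizon is the hard-coded constant from the code above): keep_rate decays linearly with the number of forward passes until it reaches `1 - drop_rate`, and gamma converts the target drop rate into the per-position Bernoulli rate DropBlock samples from.

```python
# DropBlock schedule from BasicBlock.forward, as a standalone function.
def dropblock_gamma(drop_rate, block_size, feat_size, num_batches_tracked):
    # Linear anneal of keep_rate over the first 20*2000 forward passes,
    # clamped at 1 - drop_rate.
    keep_rate = max(1.0 - drop_rate / (20 * 2000) * num_batches_tracked,
                    1.0 - drop_rate)
    # Bernoulli rate per valid mask position, corrected for block area
    # and the shrunken sampling grid (feat_size - block_size + 1).
    gamma = ((1 - keep_rate) / block_size ** 2
             * feat_size ** 2 / (feat_size - block_size + 1) ** 2)
    return keep_rate, gamma

print(dropblock_gamma(0.1, 5, 10, 40000))
```

At step 0 the schedule yields `keep_rate = 1.0` and `gamma = 0.0`, so DropBlock is effectively off at the start of training.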
--------------------------------------------------------------------------------
/models/resnet_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 |
6 | __all__ = ['ResNet', 'resnet50']
7 |
8 | model_urls = {
9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 |
23 | class Normalize(nn.Module):
24 |
25 | def __init__(self, power=2):
26 | super(Normalize, self).__init__()
27 | self.power = power
28 |
29 | def forward(self, x):
30 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
31 | out = x.div(norm)
32 | return out
33 |
34 |
35 | class BasicBlock(nn.Module):
36 | expansion = 1
37 |
38 | def __init__(self, inplanes, planes, stride=1, downsample=None):
39 | super(BasicBlock, self).__init__()
40 | self.conv1 = conv3x3(inplanes, planes, stride)
41 | self.bn1 = nn.BatchNorm2d(planes)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv2 = conv3x3(planes, planes)
44 | self.bn2 = nn.BatchNorm2d(planes)
45 | self.downsample = downsample
46 | self.stride = stride
47 |
48 | def forward(self, x):
49 | residual = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | out = self.bn2(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | out += residual
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class Bottleneck(nn.Module):
68 | expansion = 4
69 |
70 | def __init__(self, inplanes, planes, stride=1, downsample=None):
71 | super(Bottleneck, self).__init__()
72 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
73 | self.bn1 = nn.BatchNorm2d(planes)
74 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
75 | padding=1, bias=False)
76 | self.bn2 = nn.BatchNorm2d(planes)
77 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
78 | self.bn3 = nn.BatchNorm2d(planes * 4)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | residual = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv3(out)
95 | out = self.bn3(out)
96 |
97 | if self.downsample is not None:
98 | residual = self.downsample(x)
99 |
100 | out += residual
101 | out = self.relu(out)
102 |
103 | return out
104 |
105 |
106 | class ResNet(nn.Module):
107 |
108 | def __init__(self, block, layers, in_channel=3, width=1, num_classes=64):
109 | self.inplanes = 64
110 | super(ResNet, self).__init__()
111 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3,
112 | bias=False)
113 | self.bn1 = nn.BatchNorm2d(64)
114 | self.relu = nn.ReLU(inplace=True)
115 |
116 | self.base = int(64 * width)
117 |
118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
119 | self.layer1 = self._make_layer(block, self.base, layers[0])
120 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2)
121 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2)
122 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2)
123 | self.avgpool = nn.AvgPool2d(3, stride=1)
124 | self.classifier = nn.Linear(self.base * 8 * block.expansion, num_classes)
125 |
126 | for m in self.modules():
127 | if isinstance(m, nn.Conv2d):
128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
129 | m.weight.data.normal_(0, math.sqrt(2. / n))
130 | elif isinstance(m, nn.BatchNorm2d):
131 | m.weight.data.fill_(1)
132 | m.bias.data.zero_()
133 |
134 | def _make_layer(self, block, planes, blocks, stride=1):
135 | downsample = None
136 | if stride != 1 or self.inplanes != planes * block.expansion:
137 | downsample = nn.Sequential(
138 | nn.Conv2d(self.inplanes, planes * block.expansion,
139 | kernel_size=1, stride=stride, bias=False),
140 | nn.BatchNorm2d(planes * block.expansion),
141 | )
142 |
143 | layers = []
144 | layers.append(block(self.inplanes, planes, stride, downsample))
145 | self.inplanes = planes * block.expansion
146 | for i in range(1, blocks):
147 | layers.append(block(self.inplanes, planes))
148 |
149 | return nn.Sequential(*layers)
150 |
151 | def forward(self, x, is_feat=False):
152 | x = self.conv1(x)
153 | x = self.bn1(x)
154 | x = self.relu(x)
155 | x = self.maxpool(x)
156 | x = self.layer1(x)
157 | x = self.layer2(x)
158 | x = self.layer3(x)
159 | x = self.layer4(x)
160 | x = self.avgpool(x)
161 | x = x.view(x.size(0), -1)
162 |
163 | if is_feat:
164 | return [x], x
165 |
166 | x = self.classifier(x)
167 | return x
168 |
169 |
170 | def resnet50(pretrained=False, **kwargs):
171 | """Constructs a ResNet-50 model.
172 | Args:
173 | pretrained (bool): If True, returns a model pre-trained on ImageNet
174 | """
175 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
176 | if pretrained:
177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
178 | return model
179 |
180 |
181 | if __name__ == '__main__':
182 | model = resnet50(num_classes=200)
183 |
184 |     data = torch.randn(2, 3, 84, 84)
185 |     logit = model(data)
186 |     print(logit.shape)
187 |
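Note: the embedding this model exposes before `classifier` has size `base * 8 * block.expansion`; with `width=1` and `Bottleneck` (expansion 4) that is 64 * 8 * 4 = 2048. A small sketch of that arithmetic (hedged: `embed_dim` is an illustrative helper, not part of the module):

```python
# Embedding size ahead of the final Linear in resnet_new.ResNet:
# base = int(64 * width); layer4 outputs base * 8 channels, scaled by
# the block's expansion factor (4 for Bottleneck, 1 for BasicBlock).
def embed_dim(width=1, expansion=4):
    base = int(64 * width)
    return base * 8 * expansion

print(embed_dim())             # Bottleneck, e.g. resnet50 -> 2048
print(embed_dim(expansion=1))  # BasicBlock variants -> 512
```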
--------------------------------------------------------------------------------
/models/resnet_sd.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | import torch
4 |
5 | __all__ = ['ResNet_StoDepth_lineardecay', 'resnet18_StoDepth_lineardecay', 'resnet34_StoDepth_lineardecay', 'resnet50_StoDepth_lineardecay', 'resnet101_StoDepth_lineardecay',
6 | 'resnet152_StoDepth_lineardecay']
7 |
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | def conv1x1(in_planes, out_planes, stride=1):
25 | """1x1 convolution"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
27 |
28 |
29 | class StoDepth_BasicBlock(nn.Module):
30 | expansion = 1
31 |
32 | def __init__(self, prob, multFlag, inplanes, planes, stride=1, downsample=None):
33 | super(StoDepth_BasicBlock, self).__init__()
34 | self.conv1 = conv3x3(inplanes, planes, stride)
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.downsample = downsample
40 | self.stride = stride
41 | self.prob = prob
42 | self.m = torch.distributions.bernoulli.Bernoulli(torch.Tensor([self.prob]))
43 | self.multFlag = multFlag
44 |
45 | def forward(self, x):
46 |
47 | identity = x.clone()
48 |
49 | if self.training:
50 | if torch.equal(self.m.sample(),torch.ones(1)):
51 |
52 | self.conv1.weight.requires_grad = True
53 | self.conv2.weight.requires_grad = True
54 |
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu(out)
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 |
61 | if self.downsample is not None:
62 | identity = self.downsample(x)
63 |
64 | out += identity
65 | else:
66 | # Resnet does not use bias terms
67 | self.conv1.weight.requires_grad = False
68 | self.conv2.weight.requires_grad = False
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out = identity
74 | else:
75 |             # inference: always run the residual branch; if multFlag is
76 |             # set, its output is rescaled by the survival prob below
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 |
83 | if self.downsample is not None:
84 | identity = self.downsample(x)
85 |
86 | if self.multFlag:
87 | out = self.prob*out + identity
88 | else:
89 | out = out + identity
90 |
91 | out = self.relu(out)
92 |
93 | return out
94 |
95 |
96 | class StoDepth_Bottleneck(nn.Module):
97 | expansion = 4
98 |
99 | def __init__(self, prob, multFlag, inplanes, planes, stride=1, downsample=None):
100 | super(StoDepth_Bottleneck, self).__init__()
101 | self.conv1 = conv1x1(inplanes, planes)
102 | self.bn1 = nn.BatchNorm2d(planes)
103 | self.conv2 = conv3x3(planes, planes, stride)
104 | self.bn2 = nn.BatchNorm2d(planes)
105 | self.conv3 = conv1x1(planes, planes * self.expansion)
106 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.downsample = downsample
109 | self.stride = stride
110 | self.prob = prob
111 | self.m = torch.distributions.bernoulli.Bernoulli(torch.Tensor([self.prob]))
112 | self.multFlag = multFlag
113 |
114 | def forward(self, x):
115 |
116 | identity = x.clone()
117 |
118 | if self.training:
119 | if torch.equal(self.m.sample(),torch.ones(1)):
120 | self.conv1.weight.requires_grad = True
121 | self.conv2.weight.requires_grad = True
122 | self.conv3.weight.requires_grad = True
123 |
124 | out = self.conv1(x)
125 | out = self.bn1(out)
126 | out = self.relu(out)
127 |
128 | out = self.conv2(out)
129 | out = self.bn2(out)
130 | out = self.relu(out)
131 |
132 | out = self.conv3(out)
133 | out = self.bn3(out)
134 |
135 | if self.downsample is not None:
136 | identity = self.downsample(x)
137 |
138 | out += identity
139 | else:
140 | # Resnet does not use bias terms
141 | self.conv1.weight.requires_grad = False
142 | self.conv2.weight.requires_grad = False
143 | self.conv3.weight.requires_grad = False
144 |
145 | if self.downsample is not None:
146 | identity = self.downsample(x)
147 |
148 | out = identity
149 | else:
150 | out = self.conv1(x)
151 | out = self.bn1(out)
152 | out = self.relu(out)
153 |
154 | out = self.conv2(out)
155 | out = self.bn2(out)
156 | out = self.relu(out)
157 |
158 | out = self.conv3(out)
159 | out = self.bn3(out)
160 |
161 | if self.downsample is not None:
162 | identity = self.downsample(x)
163 |
164 | if self.multFlag:
165 | out = self.prob*out + identity
166 | else:
167 | out = out + identity
168 |
169 | out = self.relu(out)
170 |
171 | return out
172 |
173 |
174 | class ResNet_StoDepth_lineardecay(nn.Module):
175 |
176 | def __init__(self, block, prob_0_L, multFlag, layers, num_classes=1000, zero_init_residual=False):
177 | super(ResNet_StoDepth_lineardecay, self).__init__()
178 | self.inplanes = 64
179 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
180 | bias=False)
181 | self.bn1 = nn.BatchNorm2d(64)
182 | self.relu = nn.ReLU(inplace=True)
183 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
184 |
185 | self.multFlag = multFlag
186 | self.prob_now = prob_0_L[0]
187 | self.prob_delta = prob_0_L[0]-prob_0_L[1]
188 | self.prob_step = self.prob_delta/(sum(layers)-1)
189 |
190 | self.layer1 = self._make_layer(block, 64, layers[0])
191 | self.layer2 = self._make_layer(block, 160, layers[1], stride=2)
192 | self.layer3 = self._make_layer(block, 320, layers[2], stride=2)
193 | self.layer4 = self._make_layer(block, 640, layers[3], stride=2)
194 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
195 | self.fc = nn.Linear(640 * block.expansion, 64)
196 |
197 | for m in self.modules():
198 | if isinstance(m, nn.Conv2d):
199 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
200 | elif isinstance(m, nn.BatchNorm2d):
201 | nn.init.constant_(m.weight, 1)
202 | nn.init.constant_(m.bias, 0)
203 |
204 | # Zero-initialize the last BN in each residual branch,
205 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
206 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
207 | if zero_init_residual:
208 | for m in self.modules():
209 |                 if isinstance(m, StoDepth_Bottleneck):
210 |                     nn.init.constant_(m.bn3.weight, 0)
211 |                 elif isinstance(m, StoDepth_BasicBlock):
212 |                     nn.init.constant_(m.bn2.weight, 0)
213 |
214 | def _make_layer(self, block, planes, blocks, stride=1):
215 | downsample = None
216 | if stride != 1 or self.inplanes != planes * block.expansion:
217 | downsample = nn.Sequential(
218 | conv1x1(self.inplanes, planes * block.expansion, stride),
219 | nn.BatchNorm2d(planes * block.expansion),
220 | )
221 |
222 | layers = []
223 | layers.append(block(self.prob_now, self.multFlag, self.inplanes, planes, stride, downsample))
224 | self.prob_now = self.prob_now - self.prob_step
225 | self.inplanes = planes * block.expansion
226 | for _ in range(1, blocks):
227 | layers.append(block(self.prob_now, self.multFlag, self.inplanes, planes))
228 | self.prob_now = self.prob_now - self.prob_step
229 |
230 | return nn.Sequential(*layers)
231 |
232 | def forward(self, x, is_feat=False):
233 | x = self.conv1(x)
234 | x = self.bn1(x)
235 | x = self.relu(x)
236 | x = self.maxpool(x)
237 |
238 | x = self.layer1(x)
239 | x = self.layer2(x)
240 | x = self.layer3(x)
241 | x = self.layer4(x)
242 |
243 |         x = self.avgpool(x)
244 |         x = x.view(x.size(0), -1)
245 |         feat = x
246 |         x = self.fc(x)
247 |         return ([feat], x) if is_feat else x
248 |
249 |
250 |
251 | def resnet12_sd(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs):
252 |     """Constructs a ResNet_StoDepth_lineardecay-12 model.
253 | Args:
254 | pretrained (bool): If True, returns a model pre-trained on ImageNet
255 | """
256 | model = ResNet_StoDepth_lineardecay(StoDepth_BasicBlock, prob_0_L, multFlag, [1, 1, 1, 1], **kwargs)
257 | return model
258 |
259 |
260 | def resnet18_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs):
261 | """Constructs a ResNet_StoDepth_lineardecay-18 model.
262 | Args:
263 | pretrained (bool): If True, returns a model pre-trained on ImageNet
264 | """
265 | model = ResNet_StoDepth_lineardecay(StoDepth_BasicBlock, prob_0_L, multFlag, [2, 2, 2, 2], **kwargs)
266 | if pretrained:
267 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
268 | return model
269 |
270 |
271 | def resnet34_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs):
272 | """Constructs a ResNet_StoDepth_lineardecay-34 model.
273 | Args:
274 | pretrained (bool): If True, returns a model pre-trained on ImageNet
275 | """
276 | model = ResNet_StoDepth_lineardecay(StoDepth_BasicBlock, prob_0_L, multFlag, [3, 4, 6, 3], **kwargs)
277 | if pretrained:
278 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
279 | return model
280 |
281 |
282 | def resnet50_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs):
283 | """Constructs a ResNet_StoDepth_lineardecay-50 model.
284 | Args:
285 | pretrained (bool): If True, returns a model pre-trained on ImageNet
286 | """
287 | model = ResNet_StoDepth_lineardecay(StoDepth_Bottleneck, prob_0_L, multFlag, [3, 4, 6, 3], **kwargs)
288 | if pretrained:
289 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
290 | return model
291 |
292 |
293 | def resnet101_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs):
294 | """Constructs a ResNet_StoDepth_lineardecay-101 model.
295 | Args:
296 | pretrained (bool): If True, returns a model pre-trained on ImageNet
297 | """
298 | model = ResNet_StoDepth_lineardecay(StoDepth_Bottleneck, prob_0_L, multFlag, [3, 4, 23, 3], **kwargs)
299 | if pretrained:
300 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
301 | return model
302 |
303 |
304 | def resnet152_StoDepth_lineardecay(pretrained=False, prob_0_L=[1,0.5], multFlag=True, **kwargs):
305 | """Constructs a ResNet_StoDepth_lineardecay-152 model.
306 | Args:
307 | pretrained (bool): If True, returns a model pre-trained on ImageNet
308 | """
309 | model = ResNet_StoDepth_lineardecay(StoDepth_Bottleneck, prob_0_L, multFlag, [3, 8, 36, 3], **kwargs)
310 | if pretrained:
311 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
312 | return model
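Note: the `prob_now`/`prob_step` bookkeeping in `__init__` and `_make_layer` implements a linear decay of the per-block survival probability, from `prob_0_L[0]` at the first residual block down to `prob_0_L[1]` at the last. The schedule, reproduced as a standalone sketch:

```python
# Linear survival-probability decay from ResNet_StoDepth_lineardecay:
# block i of L total blocks survives with probability
# prob_0_L[0] - i * (prob_0_L[0] - prob_0_L[1]) / (L - 1).
def survival_probs(prob_0_L, layers):
    total = sum(layers)
    step = (prob_0_L[0] - prob_0_L[1]) / (total - 1)
    return [prob_0_L[0] - i * step for i in range(total)]

# resnet12_sd default: four blocks decaying from 1.0 to 0.5
print(survival_probs([1.0, 0.5], [1, 1, 1, 1]))
```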
--------------------------------------------------------------------------------
/models/resnet_selfdist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def conv3x3(in_planes, out_planes, stride=1):
5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3,
6 | stride=stride, padding=1, bias=False)
7 |
8 | def conv1x1(in_planes, planes, stride=1):
9 | return nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
10 |
11 | def branchBottleNeck(channel_in, channel_out, kernel_size):
12 | middle_channel = channel_out//4
13 | return nn.Sequential(
14 | nn.Conv2d(channel_in, middle_channel, kernel_size=1, stride=1),
15 | nn.BatchNorm2d(middle_channel),
16 | nn.ReLU(),
17 |
18 | nn.Conv2d(middle_channel, middle_channel, kernel_size=kernel_size, stride=kernel_size),
19 | nn.BatchNorm2d(middle_channel),
20 | nn.ReLU(),
21 |
22 | nn.Conv2d(middle_channel, channel_out, kernel_size=1, stride=1),
23 | nn.BatchNorm2d(channel_out),
24 | nn.ReLU(),
25 | )
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 | def __init__(self, inplanes, planes, stride=1, downsample=None):
29 | super(BasicBlock, self).__init__()
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | output = self.conv1(x)
42 | output = self.bn1(output)
43 | output = self.relu(output)
44 |
45 | output = self.conv2(output)
46 | output = self.bn2(output)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | output += residual
52 | output = self.relu(output)
53 | return output
54 |
55 | class BottleneckBlock(nn.Module):
56 | expansion = 4
57 | def __init__(self, inplanes, planes, stride=1, downsample=None):
58 | super(BottleneckBlock, self).__init__()
59 | self.conv1 = conv1x1(inplanes, planes)
60 | self.bn1 = nn.BatchNorm2d(planes)
61 | self.relu = nn.ReLU(inplace=True)
62 |
63 | self.conv2 = conv3x3(planes, planes, stride)
64 | self.bn2 = nn.BatchNorm2d(planes)
65 |
66 | self.conv3 = conv1x1(planes, planes*self.expansion)
67 | self.bn3 = nn.BatchNorm2d(planes*self.expansion)
68 |
69 | self.downsample = downsample
70 | self.stride = stride
71 |
72 | def forward(self, x):
73 | residual = x
74 |
75 | output = self.conv1(x)
76 | output = self.bn1(output)
77 | output = self.relu(output)
78 |
79 | output = self.conv2(output)
80 | output = self.bn2(output)
81 | output = self.relu(output)
82 |
83 | output = self.conv3(output)
84 | output = self.bn3(output)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | output += residual
90 | output = self.relu(output)
91 |
92 | return output
93 |
94 | class Multi_ResNet(nn.Module):
95 |     """ResNet backbone with auxiliary self-distillation branches.
96 | 
97 |     Args:
98 |         block (class): block type, BasicBlock or BottleneckBlock
99 |         layers (int list): number of blocks in each of the four stages
100 |         num_classes (int): number of classes
101 |     """
102 |
103 | def __init__(self, block, layers, num_classes=1000):
104 | super(Multi_ResNet, self).__init__()
105 | self.inplanes = 64
106 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
107 | self.bn1 = nn.BatchNorm2d(self.inplanes)
108 | self.relu = nn.ReLU(inplace=True)
109 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
110 |
111 | self.layer1 = self._make_layer(block, 64, layers[0])
112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
115 |
116 | self.downsample1_1 = nn.Sequential(
117 | conv1x1(64 * block.expansion, 512 * block.expansion, stride=8),
118 | nn.BatchNorm2d(512 * block.expansion),
119 | )
120 | self.bottleneck1_1 = branchBottleNeck(64 * block.expansion, 512 * block.expansion, kernel_size=8)
121 | self.avgpool1 = nn.AdaptiveAvgPool2d((1,1))
122 | self.middle_fc1 = nn.Linear(512 * block.expansion, num_classes)
123 |
124 |
125 | self.downsample2_1 = nn.Sequential(
126 | conv1x1(128 * block.expansion, 512 * block.expansion, stride=4),
127 | nn.BatchNorm2d(512 * block.expansion),
128 | )
129 | self.bottleneck2_1 = branchBottleNeck(128 * block.expansion, 512 * block.expansion, kernel_size=4)
130 | self.avgpool2 = nn.AdaptiveAvgPool2d((1,1))
131 | self.middle_fc2 = nn.Linear(512 * block.expansion, num_classes)
132 |
133 |
134 | self.downsample3_1 = nn.Sequential(
135 | conv1x1(256 * block.expansion, 512 * block.expansion, stride=2),
136 | nn.BatchNorm2d(512 * block.expansion),
137 | )
138 | self.bottleneck3_1 = branchBottleNeck(256 * block.expansion, 512 * block.expansion, kernel_size=2)
139 | self.avgpool3 = nn.AdaptiveAvgPool2d((1,1))
140 | self.middle_fc3 = nn.Linear(512 * block.expansion, num_classes)
141 |
142 | self.avgpool = nn.AdaptiveAvgPool2d((1,1))
143 | self.fc = nn.Linear(512 * block.expansion, num_classes)
144 |
145 | for m in self.modules():
146 | if isinstance(m, nn.Conv2d):
147 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
148 | elif isinstance(m, nn.BatchNorm2d):
149 | nn.init.constant_(m.weight, 1)
150 | nn.init.constant_(m.bias, 0)
151 |
152 | def _make_layer(self, block, planes, layers, stride=1):
153 |         """Build one stage containing `layers` blocks.
154 | 
155 |         Args:
156 |             block (class): block type
157 |             planes (int): base width; output channels = planes * expansion
158 |             layers (int): number of blocks in this stage
159 |             stride (int): stride of the first block in the stage
160 |         """
161 | downsample = None
162 |         if stride != 1 or self.inplanes != planes * block.expansion:
163 | downsample = nn.Sequential(
164 | conv1x1(self.inplanes, planes * block.expansion, stride),
165 | nn.BatchNorm2d(planes * block.expansion),
166 | )
167 | layer = []
168 | layer.append(block(self.inplanes, planes, stride=stride, downsample=downsample))
169 | self.inplanes = planes * block.expansion
170 | for i in range(1, layers):
171 | layer.append(block(self.inplanes, planes))
172 |
173 | return nn.Sequential(*layer)
174 |
175 | def forward(self, x, is_feat=False, is_dist=False):
176 | x = self.conv1(x)
177 | x = self.bn1(x)
178 | x = self.relu(x)
179 | # x = self.maxpool(x)
180 |
181 | x = self.layer1(x)
182 | middle_output1 = self.bottleneck1_1(x)
183 | middle_output1 = self.avgpool1(middle_output1)
184 | middle1_fea = middle_output1
185 | middle_output1 = torch.flatten(middle_output1, 1)
186 | middle_output1 = self.middle_fc1(middle_output1)
187 |
188 | x = self.layer2(x)
189 | middle_output2 = self.bottleneck2_1(x)
190 | middle_output2 = self.avgpool2(middle_output2)
191 | middle2_fea = middle_output2
192 | middle_output2 = torch.flatten(middle_output2, 1)
193 | middle_output2 = self.middle_fc2(middle_output2)
194 |
195 | x = self.layer3(x)
196 | middle_output3 = self.bottleneck3_1(x)
197 | middle_output3 = self.avgpool3(middle_output3)
198 | middle3_fea = middle_output3
199 | middle_output3 = torch.flatten(middle_output3, 1)
200 | middle_output3 = self.middle_fc3(middle_output3)
201 |
202 | x = self.layer4(x)
203 | x = self.avgpool(x)
204 | final_fea = x
205 | x = torch.flatten(x, 1)
206 | x = self.fc(x)
207 |
208 |
209 | if is_dist:
210 | return x, middle_output1, middle_output2, middle_output3, final_fea, middle1_fea, middle2_fea, middle3_fea
211 |
212 | if is_feat:
213 |             return [final_fea, final_fea, final_fea, final_fea, final_fea], x  # repeat the final feature to match the 5-feature is_feat interface
214 | else:
215 | return x
216 |
217 |
218 | def multi_resnet50_kd(num_classes=1000):
219 | return Multi_ResNet(BottleneckBlock, [3,4,6,3], num_classes=num_classes)
220 |
221 | def multi_resnet18_kd(num_classes=1000):
222 | return Multi_ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes)
223 |
224 |
225 | def multi_resnet12_kd(num_classes=64):
226 | return Multi_ResNet(BasicBlock, [1,1,1,1], num_classes=num_classes)
227 |
228 |
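Editorial note: since the stem keeps stride 1 (the maxpool is commented out), a 32×32 input yields 32/16/8/4 feature maps from layer1..layer4, and the branch downsample strides (8, 4, 2) map each intermediate map onto layer4's 4×4 grid before the shared-size classifiers. A pure-python sketch of that alignment, assuming a 32×32 input; helper names are ours:

```python
def stage_sizes(input_size, stage_strides=(1, 2, 2, 2)):
    """Spatial size after each of layer1..layer4 (stride-1 stem, no maxpool)."""
    sizes, s = [], input_size
    for stride in stage_strides:
        s //= stride
        sizes.append(s)
    return sizes

sizes = stage_sizes(32)                    # [32, 16, 8, 4]
branch_strides = (8, 4, 2)                 # downsample1_1 / 2_1 / 3_1
aligned = [s // k for s, k in zip(sizes[:3], branch_strides)]
print(sizes, aligned)                      # every branch lands on the 4x4 grid
```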
--------------------------------------------------------------------------------
/models/resnet_ssl.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.distributions import Bernoulli
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | """3x3 convolution with padding"""
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | class SELayer(nn.Module):
14 | def __init__(self, channel, reduction=16):
15 | super(SELayer, self).__init__()
16 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
17 | self.fc = nn.Sequential(
18 | nn.Linear(channel, channel // reduction),
19 | nn.ReLU(inplace=True),
20 | nn.Linear(channel // reduction, channel),
21 | nn.Sigmoid()
22 | )
23 |
24 | def forward(self, x):
25 | b, c, _, _ = x.size()
26 | y = self.avg_pool(x).view(b, c)
27 | y = self.fc(y).view(b, c, 1, 1)
28 | return x * y
29 |
30 |
31 | class DropBlock(nn.Module):
32 | def __init__(self, block_size):
33 | super(DropBlock, self).__init__()
34 |
35 | self.block_size = block_size
36 | #self.gamma = gamma
37 | #self.bernouli = Bernoulli(gamma)
38 |
39 | def forward(self, x, gamma):
40 | # shape: (bsize, channels, height, width)
41 |
42 | if self.training:
43 | batch_size, channels, height, width = x.shape
44 |
45 | bernoulli = Bernoulli(gamma)
46 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
47 | block_mask = self._compute_block_mask(mask)
48 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
49 | count_ones = block_mask.sum()
50 |
51 | return block_mask * x * (countM / count_ones)
52 | else:
53 | return x
54 |
55 | def _compute_block_mask(self, mask):
56 | left_padding = int((self.block_size-1) / 2)
57 | right_padding = int(self.block_size / 2)
58 |
59 | batch_size, channels, height, width = mask.shape
60 | #print ("mask", mask[0][0])
61 | non_zero_idxs = mask.nonzero()
62 | nr_blocks = non_zero_idxs.shape[0]
63 |
64 | offsets = torch.stack(
65 | [
66 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding,
67 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding
68 | ]
69 | ).t().cuda()
70 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
71 |
72 | if nr_blocks > 0:
73 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
74 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
75 | offsets = offsets.long()
76 |
77 | block_idxs = non_zero_idxs + offsets
78 | #block_idxs += left_padding
79 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
80 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
81 | else:
82 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
83 |
84 | block_mask = 1 - padded_mask#[:height, :width]
85 | return block_mask
86 |
87 |
88 | class BasicBlock(nn.Module):
89 | expansion = 1
90 |
91 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False,
92 | block_size=1, use_se=False):
93 | super(BasicBlock, self).__init__()
94 | self.conv1 = conv3x3(inplanes, planes)
95 | self.bn1 = nn.BatchNorm2d(planes)
96 | self.relu = nn.LeakyReLU(0.1)
97 | self.conv2 = conv3x3(planes, planes)
98 | self.bn2 = nn.BatchNorm2d(planes)
99 | self.conv3 = conv3x3(planes, planes)
100 | self.bn3 = nn.BatchNorm2d(planes)
101 | self.maxpool = nn.MaxPool2d(stride)
102 | self.downsample = downsample
103 | self.stride = stride
104 | self.drop_rate = drop_rate
105 | self.num_batches_tracked = 0
106 | self.drop_block = drop_block
107 | self.block_size = block_size
108 | self.DropBlock = DropBlock(block_size=self.block_size)
109 | self.use_se = use_se
110 | if self.use_se:
111 | self.se = SELayer(planes, 4)
112 |
113 | def forward(self, x):
114 | self.num_batches_tracked += 1
115 |
116 | residual = x
117 |
118 | out = self.conv1(x)
119 | out = self.bn1(out)
120 | out = self.relu(out)
121 |
122 | out = self.conv2(out)
123 | out = self.bn2(out)
124 | out = self.relu(out)
125 |
126 | out = self.conv3(out)
127 | out = self.bn3(out)
128 | if self.use_se:
129 | out = self.se(out)
130 |
131 | if self.downsample is not None:
132 | residual = self.downsample(x)
133 | out += residual
134 | out = self.relu(out)
135 | out = self.maxpool(out)
136 |
137 | if self.drop_rate > 0:
138 |             if self.drop_block:
139 | feat_size = out.size()[2]
140 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
141 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
142 | out = self.DropBlock(out, gamma=gamma)
143 | else:
144 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
145 |
146 | return out
147 |
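The `keep_rate`/`gamma` schedule above anneals the effective DropBlock drop probability from near zero up to `drop_rate` over 20*2000 tracked batches, using the DropBlock gamma formula. A pure-python sketch of the same arithmetic (the function name is ours):

```python
def dropblock_gamma(drop_rate, num_batches_tracked, feat_size, block_size,
                    anneal_steps=20 * 2000):
    """Replicates the keep_rate/gamma computation in BasicBlock.forward."""
    keep_rate = max(1.0 - drop_rate / anneal_steps * num_batches_tracked,
                    1.0 - drop_rate)
    return ((1 - keep_rate) / block_size ** 2
            * feat_size ** 2 / (feat_size - block_size + 1) ** 2)

early = dropblock_gamma(0.1, 100, feat_size=10, block_size=5)       # tiny
capped = dropblock_gamma(0.1, 10 ** 6, feat_size=10, block_size=5)  # at the cap
print(early, capped)
```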
148 |
149 | class ResNet(nn.Module):
150 |
151 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0,
152 | dropblock_size=5, num_classes=-1, use_se=False):
153 | super(ResNet, self).__init__()
154 |
155 | self.inplanes = 3
156 | self.use_se = use_se
157 | self.layer1 = self._make_layer(block, n_blocks[0], 64,
158 | stride=2, drop_rate=drop_rate)
159 | self.layer2 = self._make_layer(block, n_blocks[1], 160,
160 | stride=2, drop_rate=drop_rate)
161 | self.layer3 = self._make_layer(block, n_blocks[2], 320,
162 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
163 | self.layer4 = self._make_layer(block, n_blocks[3], 640,
164 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
165 | if avg_pool:
166 | # self.avgpool = nn.AvgPool2d(5, stride=1)
167 | self.avgpool = nn.AdaptiveAvgPool2d(1)
168 | self.keep_prob = keep_prob
169 | self.keep_avg_pool = avg_pool
170 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
171 | self.drop_rate = drop_rate
172 |
173 | for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
176 | elif isinstance(m, nn.BatchNorm2d):
177 | nn.init.constant_(m.weight, 1)
178 | nn.init.constant_(m.bias, 0)
179 |
180 | self.num_classes = num_classes
181 | if self.num_classes > 0:
182 | self.classifier = nn.Linear(640, self.num_classes)
183 | self.rot_classifier = nn.Linear(self.num_classes, 4)
184 | # self.rot_classifier1 = nn.Linear(self.num_classes, 32)
185 | # self.rot_classifier2 = nn.Linear(32, 16)
186 | # self.rot_classifier3 = nn.Linear(16, 4)
187 |
188 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
189 | downsample = None
190 | if stride != 1 or self.inplanes != planes * block.expansion:
191 | downsample = nn.Sequential(
192 | nn.Conv2d(self.inplanes, planes * block.expansion,
193 | kernel_size=1, stride=1, bias=False),
194 | nn.BatchNorm2d(planes * block.expansion),
195 | )
196 |
197 | layers = []
198 | if n_block == 1:
199 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se)
200 | else:
201 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se)  # keyword, so use_se is not swallowed by the drop_block slot
202 | layers.append(layer)
203 | self.inplanes = planes * block.expansion
204 |
205 | for i in range(1, n_block):
206 | if i == n_block - 1:
207 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block,
208 | block_size=block_size, use_se=self.use_se)
209 | else:
210 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se)
211 | layers.append(layer)
212 |
213 | return nn.Sequential(*layers)
214 |
215 | def forward(self, x, is_feat=False, rot=False):
216 | x = self.layer1(x)
217 | f0 = x
218 | x = self.layer2(x)
219 | f1 = x
220 | x = self.layer3(x)
221 | f2 = x
222 | x = self.layer4(x)
223 | f3 = x
224 | if self.keep_avg_pool:
225 | x = self.avgpool(x)
226 | x = x.view(x.size(0), -1)
227 | feat = x
228 |
229 | xx = self.classifier(x)
230 |
231 |         if rot:
232 | # xy1 = self.rot_classifier1(xx)
233 | # xy2 = self.rot_classifier2(xy1)
234 | xy = self.rot_classifier(xx)
235 | return [f0, f1, f2, f3, feat], (xx, xy)
236 |
237 | if is_feat:
238 | return [f0, f1, f2, f3, feat], xx
239 | else:
240 | return xx
241 |
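For reference, each BasicBlock above ends in `MaxPool2d(stride)`, so with the default stride of 2 every stage floor-halves the spatial size: an 84×84 miniImageNet image gives f0..f3 of 42/21/10/5. A pure-python sketch of that trace (helper name is ours):

```python
def ssl_stage_sizes(input_size, num_stages=4):
    """Each stage ends with MaxPool2d(2): floor-halve the spatial size per stage."""
    sizes, s = [], input_size
    for _ in range(num_stages):
        s //= 2
        sizes.append(s)
    return sizes

print(ssl_stage_sizes(84))  # [42, 21, 10, 5]  f0..f3 for miniImageNet
print(ssl_stage_sizes(32))  # [16, 8, 4, 2]    CIFAR-FS / FC100 inputs
```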
242 |
243 | def resnet12_ssl(keep_prob=1.0, avg_pool=False, **kwargs):
244 | """Constructs a ResNet-12 model.
245 | """
246 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
247 | return model
248 |
249 |
250 | def resnet18(keep_prob=1.0, avg_pool=False, **kwargs):
251 | """Constructs a ResNet-18 model.
252 | """
253 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
254 | return model
255 |
256 |
257 | def resnet24(keep_prob=1.0, avg_pool=False, **kwargs):
258 | """Constructs a ResNet-24 model.
259 | """
260 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
261 | return model
262 |
263 |
264 | def resnet50(keep_prob=1.0, avg_pool=False, **kwargs):
265 | """Constructs a ResNet-50 model.
266 | indeed, only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers
267 | """
268 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
269 | return model
270 |
271 |
272 | def resnet101(keep_prob=1.0, avg_pool=False, **kwargs):
273 | """Constructs a ResNet-101 model.
274 | indeed, only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers
275 | """
276 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
277 | return model
278 |
279 |
280 | def seresnet12(keep_prob=1.0, avg_pool=False, **kwargs):
281 | """Constructs a ResNet-12 model.
282 | """
283 | model = ResNet(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
284 | return model
285 |
286 |
287 | def seresnet18(keep_prob=1.0, avg_pool=False, **kwargs):
288 | """Constructs a ResNet-18 model.
289 | """
290 | model = ResNet(BasicBlock, [1, 1, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
291 | return model
292 |
293 |
294 | def seresnet24(keep_prob=1.0, avg_pool=False, **kwargs):
295 | """Constructs a ResNet-24 model.
296 | """
297 | model = ResNet(BasicBlock, [2, 2, 2, 2], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
298 | return model
299 |
300 |
301 | def seresnet50(keep_prob=1.0, avg_pool=False, **kwargs):
302 | """Constructs a ResNet-50 model.
303 | indeed, only (3 + 4 + 6 + 3) * 3 + 1 = 49 layers
304 | """
305 | model = ResNet(BasicBlock, [3, 4, 6, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
306 | return model
307 |
308 |
309 | def seresnet101(keep_prob=1.0, avg_pool=False, **kwargs):
310 | """Constructs a ResNet-101 model.
311 | indeed, only (3 + 4 + 23 + 3) * 3 + 1 = 100 layers
312 | """
313 | model = ResNet(BasicBlock, [3, 4, 23, 3], keep_prob=keep_prob, avg_pool=avg_pool, use_se=True, **kwargs)
314 | return model
315 |
316 |
317 | if __name__ == '__main__':
318 |
319 | import argparse
320 |
321 | parser = argparse.ArgumentParser('argument for training')
322 | parser.add_argument('--model', type=str, choices=['resnet12', 'resnet18', 'resnet24', 'resnet50', 'resnet101',
323 | 'seresnet12', 'seresnet18', 'seresnet24', 'seresnet50',
324 | 'seresnet101'])
325 | args = parser.parse_args()
326 |
327 | model_dict = {
328 |         'resnet12': resnet12_ssl,  # only resnet12_ssl is defined in this file
329 | 'resnet18': resnet18,
330 | 'resnet24': resnet24,
331 | 'resnet50': resnet50,
332 | 'resnet101': resnet101,
333 | 'seresnet12': seresnet12,
334 | 'seresnet18': seresnet18,
335 | 'seresnet24': seresnet24,
336 | 'seresnet50': seresnet50,
337 | 'seresnet101': seresnet101,
338 | }
339 |
340 | model = model_dict[args.model](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=64)
341 | data = torch.randn(2, 3, 84, 84)
342 | model = model.cuda()
343 | data = data.cuda()
344 | feat, logit = model(data, is_feat=True)
345 | print(feat[-1].shape)
346 | print(logit.shape)
347 |
--------------------------------------------------------------------------------
/models/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 |
4 | from . import model_dict
5 |
6 |
7 | def create_model(name, n_cls, dataset='miniImageNet', dropout=0.1):
8 | """create model by name"""
9 | if dataset == 'miniImageNet' or dataset == 'tieredImageNet':
10 | if name.endswith('v2') or name.endswith('v3'):
11 | model = model_dict[name](num_classes=n_cls)
12 |         elif name.endswith('kd'):
13 |             model = model_dict[name](num_classes=n_cls)
14 |         elif name.endswith('ssl'):
15 |             model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=n_cls)
16 | elif name.startswith('resnet50'):
17 | print('use imagenet-style resnet50')
18 | model = model_dict[name](num_classes=n_cls)
19 | elif name.startswith('resnet') or name.startswith('seresnet'):
20 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=5, num_classes=n_cls)
21 | elif name.startswith('wrn'):
22 | model = model_dict[name](num_classes=n_cls)
23 | elif name.startswith('convnet'):
24 | model = model_dict[name](num_classes=n_cls)
25 | else:
26 |             raise NotImplementedError('model {} not supported in dataset {}'.format(name, dataset))
27 |     elif dataset == 'CIFAR-FS' or dataset == 'FC100' or dataset == 'toy':
28 |
29 |         print('creating model:', name)
30 |         if name.endswith('kd'):
31 |             model = model_dict[name](num_classes=n_cls)
32 |         elif name.endswith('sd'):
33 |             model = model_dict[name]()
34 |         elif name.endswith('ssl'):
35 |             model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=2, num_classes=n_cls)
36 | elif name.startswith('resnet') or name.startswith('seresnet'):
37 | model = model_dict[name](avg_pool=True, drop_rate=0.1, dropblock_size=2, num_classes=n_cls)
38 | elif name.startswith('convnet'):
39 | model = model_dict[name](num_classes=n_cls)
40 | else:
41 |             raise NotImplementedError('model {} not supported in dataset {}'.format(name, dataset))
42 | else:
43 | raise NotImplementedError('dataset not supported: {}'.format(dataset))
44 |
45 | return model
46 |
47 |
48 | def get_teacher_name(model_path):
49 | """parse to get teacher model name"""
50 | segments = model_path.split('/')[-2].split('_')
51 | if ':' in segments[0]:
52 | return segments[0].split(':')[-1]
53 | else:
54 | if segments[0] != 'wrn':
55 | return segments[0]
56 | else:
57 | return segments[0] + '_' + segments[1] + '_' + segments[2]
58 |
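`get_teacher_name` recovers the architecture from the checkpoint's parent-directory name (underscore-separated, with a special case for `wrn_depth_width` names and an optional `prefix:name` form on the first segment). A standalone sketch of the same parsing rule, exercised on hypothetical paths that mirror the run.sh save-directory naming:

```python
def teacher_name(model_path):
    """Same parsing rule as models/util.py:get_teacher_name (sketch)."""
    segments = model_path.split('/')[-2].split('_')
    if ':' in segments[0]:
        return segments[0].split(':')[-1]
    if segments[0] != 'wrn':
        return segments[0]
    return '_'.join(segments[:3])  # wrn names keep depth and width, e.g. wrn_28_10

# hypothetical checkpoint paths
print(teacher_name('save/resnet12_ssl_CIFAR-FS_lr_0.05/model.pth'))  # resnet12
print(teacher_name('save/wrn_28_10_miniImageNet_trial_1/ckpt.pth'))  # wrn_28_10
print(teacher_name('save/S:resnet12_T:wrn/ckpt.pth'))                # resnet12
```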
--------------------------------------------------------------------------------
/models/wresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | import sys
8 | import numpy as np
9 |
10 |
11 | def conv3x3(in_planes, out_planes, stride=1):
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
13 |
14 |
15 | def conv_init(m):
16 | classname = m.__class__.__name__
17 | if classname.find('Conv') != -1:
18 |         init.xavier_uniform_(m.weight, gain=np.sqrt(2))  # trailing-underscore variants; the old names are deprecated
19 |         init.constant_(m.bias, 0)
20 |     elif classname.find('BatchNorm') != -1:
21 |         init.constant_(m.weight, 1)
22 |         init.constant_(m.bias, 0)
23 |
24 |
25 | class wide_basic(nn.Module):
26 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
27 | super(wide_basic, self).__init__()
28 | self.bn1 = nn.BatchNorm2d(in_planes)
29 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
30 | self.dropout = nn.Dropout(p=dropout_rate)
31 | self.bn2 = nn.BatchNorm2d(planes)
32 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
33 |
34 | self.shortcut = nn.Sequential()
35 | if stride != 1 or in_planes != planes:
36 | self.shortcut = nn.Sequential(
37 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
38 | )
39 |
40 | def forward(self, x):
41 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
42 | out = self.conv2(F.relu(self.bn2(out)))
43 | out += self.shortcut(x)
44 |
45 | return out
46 |
47 |
48 | class Wide_ResNet(nn.Module):
49 | def __init__(self, depth, widen_factor, dropout_rate, num_classes=-1):
50 | super(Wide_ResNet, self).__init__()
51 | self.in_planes = 16
52 |
53 |         assert (depth - 4) % 6 == 0, 'Wide-ResNet depth should be 6n+4'
54 | n = (depth-4) // 6
55 | k = widen_factor
56 |
57 | print('| Wide-Resnet %dx%d' %(depth, k))
58 | nStages = [16, 16*k, 32*k, 64*k]
59 |
60 | self.conv1 = conv3x3(3,nStages[0])
61 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
62 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
63 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
64 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
65 |
66 | self.num_classes = num_classes
67 | if self.num_classes > 0:
68 | self.classifier = nn.Linear(64*k, self.num_classes)
69 |
70 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
71 | strides = [stride] + [1]*(num_blocks-1)
72 | layers = []
73 |
74 | for stride in strides:
75 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
76 | self.in_planes = planes
77 |
78 | return nn.Sequential(*layers)
79 |
80 | def forward(self, x, is_feat=False):
81 | out = self.conv1(x)
82 | out = self.layer1(out)
83 | out = self.layer2(out)
84 | out = self.layer3(out)
85 | out = F.relu(self.bn1(out))
86 | out = F.adaptive_avg_pool2d(out, 1)
87 | out = out.view(out.size(0), -1)
88 | feat = out
89 | if self.num_classes > 0:
90 | out = self.classifier(out)
91 |
92 | if is_feat:
93 | return [feat], out
94 | else:
95 | return out
96 |
97 |
98 | def wrn_28_10(dropout_rate=0.3, num_classes=-1):
99 | return Wide_ResNet(28, 10, dropout_rate, num_classes)
100 |
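The depth assertion encodes WRN-d-k's structure: depth = 6n + 4, i.e. three groups of n two-conv `wide_basic` blocks (6n convs) plus the stem conv, with group widths 16, 16k, 32k, 64k feeding a 64k-wide classifier. A quick arithmetic sketch (the function name is ours):

```python
def wrn_config(depth, widen_factor):
    """Blocks per group and per-group widths for a Wide-ResNet depth/widen_factor."""
    assert (depth - 4) % 6 == 0, 'Wide-ResNet depth should be 6n+4'
    n = (depth - 4) // 6
    k = widen_factor
    return n, [16, 16 * k, 32 * k, 64 * k]

print(wrn_config(28, 10))  # (4, [16, 160, 320, 640]) -> classifier input is 640
```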
101 |
102 | if __name__ == '__main__':
103 |     net = Wide_ResNet(28, 10, 0.3)
104 |     y = net(torch.randn(1, 3, 32, 32))  # Variable is deprecated; plain tensors work
105 |
106 | print(y.size())
107 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.2.1
2 | mkl==2019.0
3 | numpy==1.18.4
4 | Pillow==7.1.2
5 | scikit_learn==0.23.1
6 | scipy==1.4.1
7 | torch==1.5.0
8 | torchvision==0.6.0
9 | tqdm==4.46.0
10 | wandb==0.8.36
11 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | ######################################################################################################################################################
2 | ######################################################################################################################################################
3 | ####################################################### CIFAR-FS ##########################################################
4 | ######################################################################################################################################################
5 | ######################################################################################################################################################
6 |
7 | # # # # self supervision
8 | # python3 train_supervised_ssl.py \
9 | # --tags cifarfs,may30 \
10 | # --model resnet12_ssl \
11 | # --model_path save/backup \
12 | # --dataset CIFAR-FS \
13 | # --data_root ../../Datasets/CIFAR_FS/ \
14 | # --n_aug_support_samples 5 \
15 | # --n_ways 5 \
16 | # --n_shots 1 \
17 | # --epochs 65 \
18 | # --lr_decay_epochs 60 \
19 | # --gamma 2.0 &
20 |
21 |
22 |
23 | # for i in {0..0}; do
24 | # python3 train_distillation.py \
25 | # --tags cifarfs,gen1,may30 \
26 | # --model_s resnet12_ssl \
27 | # --model_t resnet12_ssl \
28 | # --path_t save/backup/resnet12_ssl_CIFAR-FS_lr_0.05_decay_0.0005_trans_D_trial_1/model_firm-sun-394.pth \
29 | # --model_path save/backup \
30 | # --dataset CIFAR-FS \
31 | # --data_root ../../Datasets/CIFAR_FS/ \
32 | # --n_aug_support_samples 5 \
33 | # --n_ways 5 \
34 | # --n_shots 1 \
35 | # --epochs 65 \
36 | # --lr_decay_epochs 60 \
37 | # --gamma 0.1 &
38 | # sleep 1m
39 | # done
40 |
41 |
42 |
43 |
44 |
45 |
46 | # # # evaluation
47 | # CUDA_VISIBLE_DEVICES=0 python3 eval_fewshot.py \
48 | # --model resnet12_ssl \
49 | # --model_path save/backup2/resnet12_ssl_toy_lr_0.05_decay_0.0005_trans_A_trial_1/model_upbeat-dew-17.pth \
50 | # --dataset toy \
51 | # --data_root ../../Datasets/CIFAR_FS/ \
52 | # --n_aug_support_samples 5 \
53 | # --n_ways 5 \
54 | # --n_shots 1
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 | ######################################################################################################################################################
66 | ######################################################################################################################################################
67 | ####################################################### FC100 ##########################################################
68 | ######################################################################################################################################################
69 | ######################################################################################################################################################
70 |
71 |
72 | # # # # GEN0
73 | # CUDA_VISIBLE_DEVICES=1 python3 train_supervised_ssl.py \
74 | # --model resnet12_ssl \
75 | # --model_path save/backup \
76 | # --tags fc100 \
77 | # --dataset FC100 \
78 | # --data_root ../Datasets/neurips2020/ \
79 | # --n_aug_support_samples 5 \
80 | # --n_ways 5 \
81 | # --n_shots 1 \
82 | # --epochs 65 \
83 | # --lr_decay_epochs 60 \
84 | # --gamma 2 &
85 |
86 |
87 | # # # # GEN1
88 | # for i in {0..0}; do
89 | # python3 train_distillation5.py \
90 | # --model_s resnet12_ssl \
91 | # --model_t resnet12_ssl \
92 | # --path_t save/backup/resnet12_ssl_FC100_lr_0.05_decay_0.0005_trans_D_trial_1/model_effortless-wood-315.pth \
93 | # --tags fc100 \
94 | # --model_path save/neurips2020 \
95 | # --dataset FC100 \
96 | # --data_root ../Datasets/FC100/ \
97 | # --n_aug_support_samples 5 \
98 | # --n_ways 5 \
99 | # --n_shots 1 \
100 | # --batch_size 64 \
101 | # --epochs 8 \
102 | # --lr_decay_epochs 3 \
103 | # --gamma 0.2 \
104 | # # sleep 4m
105 | # done
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 | ######################################################################################################################################################
124 | ######################################################################################################################################################
125 | ####################################################### miniImagenet ##########################################################
126 | ######################################################################################################################################################
127 | ######################################################################################################################################################
128 |
129 |
130 |
131 |
132 | # # # # GEN0
133 | # python3 train_supervised_ssl.py \
134 | # --model resnet12_ssl \
135 | # --model_path save/neurips2020 \
136 | # --tags miniimagenet,gen0 \
137 | # --dataset miniImageNet \
138 | # --data_root ../Datasets/MiniImagenet/ \
139 | # --n_aug_support_samples 5 \
140 | # --n_ways 5 \
141 | # --n_shots 1 \
142 | # --epochs 65 \
143 | # --lr_decay_epochs 60 \
144 | # --gamma 2.0 &
145 |
146 |
147 | # params=( 0.025 0.05 0.075 0.1 0.15 0.2 0.25 0.3 0.4 0.5 )
148 |
149 | # # # # GEN1
150 | # for i in {0..2}; do
151 | # python3 train_distillation.py \
152 | # --model_s resnet12_ssl \
153 | # --model_t resnet12_ssl \
154 | # --path_t save/neurips2020/resnet12_ssl_miniImageNet_lr_0.05_decay_0.0005_trans_A_trial_1/model_swift-lake-4.pth \
155 | # --tags miniimagenet,gen1,beta \
156 | # --model_path save/neurips2020 \
157 | # --dataset miniImageNet \
158 | # --data_root ../Datasets/MiniImagenet/ \
159 | # --n_aug_support_samples 5 \
160 | # --n_ways 5 \
161 | # --n_shots 1 \
162 | # --batch_size 64 \
163 | # --epochs 8 \
164 | # --lr_decay_epochs 5 \
165 | # --gamma ${params[i]} &
166 | # sleep 30m
167 | # done
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 | ######################################################################################################################################################
185 | ######################################################################################################################################################
186 | ####################################################### tieredImageNet ##########################################################
187 | ######################################################################################################################################################
188 | ######################################################################################################################################################
189 |
190 |
191 |
192 |
193 | # # # # GEN0
194 | # python3 train_supervised_ssl.py \
195 | # --model resnet12_ssl \
196 | # --model_path save/backup \
197 | # --tags tieredimageNet \
198 | # --dataset tieredImageNet \
199 | # --data_root ../Datasets/TieredImagenet/ \
200 | # --n_aug_support_samples 5 \
201 | # --n_ways 5 \
202 | # --n_shots 1 \
203 | # --epochs 60 \
204 | # --lr_decay_epochs 30,40,50 \
205 | # --gamma 2
206 |
207 |
208 |
209 |
210 | # # # # GEN1  (note: this block still points at the FC100 teacher/dataset and an undefined ${params[i]}; adjust before running on tieredImageNet)
211 | # for i in {0..6}; do
212 | # python3 train_distillation5.py \
213 | # --model_s resnet12_ssl \
214 | # --model_t resnet12_ssl \
215 | # --path_t save/backup/resnet12_ssl_FC100_lr_0.05_decay_0.0005_trans_D_trial_1/model_effortless-wood-315.pth \
216 | # --tags fc100 \
217 | # --model_path save/backup \
218 | # --dataset FC100 \
219 | # --data_root ../../Datasets/FC100/ \
220 | # --n_aug_support_samples 5 \
221 | # --n_ways 5 \
222 | # --n_shots 1 \
223 | # --batch_size 64 \
224 | # --epochs 8 \
225 | # --lr_decay_epochs 3 \
226 | # --gamma ${params[i]} \
227 | # # sleep 4m
228 | # done
229 |
230 |
--------------------------------------------------------------------------------
/train_distillation.py:
--------------------------------------------------------------------------------
1 | """
2 | The general training framework for knowledge distillation.
3 | """
4 |
5 | from __future__ import print_function
6 |
7 | import os
8 | import argparse
9 | import socket
10 | import time
11 | import sys
12 | from tqdm import tqdm
13 | import mkl
14 |
15 | import torch
16 | import torch.optim as optim
17 | import torch.nn as nn
18 | import torch.backends.cudnn as cudnn
19 | from torch.utils.data import DataLoader
20 | import torch.nn.functional as F
21 |
22 | from models import model_pool
23 | from models.util import create_model, get_teacher_name
24 |
25 | from distill.util import Embed
26 | from distill.criterion import DistillKL, NCELoss, Attention, HintLoss
27 |
28 | from dataset.mini_imagenet import ImageNet, MetaImageNet
29 | from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet
30 | from dataset.cifar import CIFAR100, MetaCIFAR100
31 | from dataset.transform_cfg import transforms_options, transforms_list
32 |
33 | from util import adjust_learning_rate, accuracy, AverageMeter
34 | from eval.meta_eval import meta_test, meta_test_tune
35 | from eval.cls_eval import validate
36 |
37 | from models.resnet import resnet12
38 | import numpy as np
39 | from util import Logger
40 | import wandb
41 | from dataloader import get_dataloaders
42 | import copy
43 |
44 | def get_freer_gpu():
45 |     os.system('nvidia-smi -q -d Memory | grep -A4 GPU | grep Free > tmp')
46 |     with open('tmp', 'r') as f:
47 |         return np.argmax([int(x.split()[2]) for x in f.readlines()])
48 |
49 | os.environ["CUDA_VISIBLE_DEVICES"]=str(get_freer_gpu())
50 | # os.environ['OPENBLAS_NUM_THREADS'] = '4'
51 | mkl.set_num_threads(2)
52 |
53 |
54 | class Wrapper(nn.Module):
55 |
56 | def __init__(self, model, args):
57 | super(Wrapper, self).__init__()
58 |
59 | self.model = model
60 | self.feat = torch.nn.Sequential(*list(self.model.children())[:-2])
61 |
62 | self.last = torch.nn.Linear(list(self.model.children())[-2].in_features, 64)
63 |
64 | def forward(self, images):
65 | feat = self.feat(images)
66 | feat = feat.view(images.size(0), -1)
67 | out = self.last(feat)
68 |
69 | return feat, out
70 |
71 |
72 |
73 | def parse_option():
74 |
75 | parser = argparse.ArgumentParser('argument for training')
76 |
77 | parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency')
78 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
79 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
80 | parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
81 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
82 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
83 | parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
84 |
85 | # optimization
86 | parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate')
87 | parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list')
88 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
89 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
90 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
91 |
92 | # dataset and model
93 | parser.add_argument('--model_s', type=str, default='resnet12', choices=model_pool)
94 | parser.add_argument('--model_t', type=str, default='resnet12', choices=model_pool)
95 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
96 | 'CIFAR-FS', 'FC100'])
97 | parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation')
98 | parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning')
99 | parser.add_argument('--tags', type=str, default="gen1, ssl", help='add tags for the experiment')
100 | parser.add_argument('--transform', type=str, default='A', choices=transforms_list)
101 |
102 | # path to teacher model
103 | parser.add_argument('--path_t', type=str, default="", help='teacher model snapshot')
104 |
105 | # distillation
106 | parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'contrast', 'hint', 'attention'])
107 | parser.add_argument('--trial', type=str, default='1', help='trial id')
108 |
109 | parser.add_argument('-r', '--gamma', type=float, default=1, help='weight for classification')
110 | parser.add_argument('-a', '--alpha', type=float, default=0, help='weight balance for KD')
111 | parser.add_argument('-b', '--beta', type=float, default=0, help='weight balance for other losses')
112 |
113 | # KL distillation
114 | parser.add_argument('--kd_T', type=float, default=2, help='temperature for KD distillation')
115 | # NCE distillation
116 | parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
117 | parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
118 | parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
119 | parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')
120 |
121 | # cosine annealing
122 | parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
123 |
124 | # specify folder
125 | parser.add_argument('--model_path', type=str, default='save/', help='path to save model')
126 | parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard')
127 | parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')
128 |
129 | # setting for meta-learning
130 | parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
131 | help='Number of test runs')
132 | parser.add_argument('--n_ways', type=int, default=5, metavar='N',
133 | help='Number of classes for doing each classification run')
134 | parser.add_argument('--n_shots', type=int, default=1, metavar='N',
135 | help='Number of shots in test')
136 | parser.add_argument('--n_queries', type=int, default=15, metavar='N',
137 |                         help='Number of queries in test')
138 | parser.add_argument('--n_aug_support_samples', default=5, type=int,
139 | help='The number of augmented samples for each meta test sample')
140 | parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
141 |                         help='Size of test batch')
142 |
143 | opt = parser.parse_args()
144 |
145 | if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
146 | opt.transform = 'D'
147 |
148 | if 'trainval' in opt.path_t:
149 | opt.use_trainval = True
150 | else:
151 | opt.use_trainval = False
152 |
153 | if opt.use_trainval:
154 | opt.trial = opt.trial + '_trainval'
155 |
156 | # set the path according to the environment
157 | if not opt.model_path:
158 | opt.model_path = './models_distilled'
159 | if not opt.tb_path:
160 | opt.tb_path = './tensorboard'
161 | if not opt.data_root:
162 | opt.data_root = './data/{}'.format(opt.dataset)
163 | else:
164 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
165 | opt.data_aug = True
166 |
167 | tags = opt.tags.split(',')
168 | opt.tags = list([])
169 | for it in tags:
170 | opt.tags.append(it)
171 |
172 | iterations = opt.lr_decay_epochs.split(',')
173 | opt.lr_decay_epochs = list([])
174 | for it in iterations:
175 | opt.lr_decay_epochs.append(int(it))
176 |
177 | opt.model_name = 'S:{}_T:{}_{}_{}_r:{}_a:{}_b:{}_trans_{}'.format(opt.model_s, opt.model_t, opt.dataset,
178 | opt.distill, opt.gamma, opt.alpha, opt.beta,
179 | opt.transform)
180 |
181 | if opt.cosine:
182 | opt.model_name = '{}_cosine'.format(opt.model_name)
183 |
184 | opt.model_name = '{}_{}'.format(opt.model_name, opt.trial)
185 |
186 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
187 | if not os.path.isdir(opt.tb_folder):
188 | os.makedirs(opt.tb_folder)
189 |
190 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
191 | if not os.path.isdir(opt.save_folder):
192 | os.makedirs(opt.save_folder)
193 |
194 | #extras
195 | opt.fresh_start = True
196 |
197 |
198 | return opt
199 |
200 |
201 |
202 |
203 |
204 | def load_teacher(model_path, model_name, n_cls, dataset='miniImageNet'):
205 | """load the teacher model"""
206 | print('==> loading teacher model')
207 | print(model_name)
208 | model = create_model(model_name, n_cls, dataset)
209 | model.load_state_dict(torch.load(model_path)['model'])
210 | print('==> done')
211 | return model
212 |
213 |
214 | def main():
215 | best_acc = 0
216 |
217 | opt = parse_option()
218 | wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags)
219 | wandb.config.update(opt)
220 | wandb.save('*.py')
221 | wandb.run.save()
222 |
223 |
224 | # dataloader
225 | train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt)
226 |
227 | # model
228 | model_t = []
229 | if("," in opt.path_t):
230 | for path in opt.path_t.split(","):
231 | model_t.append(load_teacher(path, opt.model_t, n_cls, opt.dataset))
232 | else:
233 | model_t.append(load_teacher(opt.path_t, opt.model_t, n_cls, opt.dataset))
234 |
235 | # model_s = create_model(opt.model_s, n_cls, opt.dataset, dropout=0.4)
236 | # model_s = Wrapper(model_, opt)
237 | model_s = copy.deepcopy(model_t[0])
238 |
239 | criterion_cls = nn.CrossEntropyLoss()
240 | criterion_div = DistillKL(opt.kd_T)
241 | criterion_kd = DistillKL(opt.kd_T)
242 |
243 | optimizer = optim.SGD(model_s.parameters(),
244 | lr=opt.learning_rate,
245 | momentum=opt.momentum,
246 | weight_decay=opt.weight_decay)
247 |
248 |
249 |
250 |
251 | if torch.cuda.is_available():
252 | for m in model_t:
253 | m.cuda()
254 | model_s.cuda()
255 | criterion_cls = criterion_cls.cuda()
256 | criterion_div = criterion_div.cuda()
257 | criterion_kd = criterion_kd.cuda()
258 | cudnn.benchmark = True
259 |     if opt.cosine:  # create the scheduler before the loop; it is stepped below
260 |         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, opt.learning_rate * (opt.lr_decay_rate ** 3), -1)
261 | meta_test_acc = 0
262 | meta_test_std = 0
263 | # routine: supervised model distillation
264 | for epoch in range(1, opt.epochs + 1):
265 |
266 | if opt.cosine:
267 | scheduler.step()
268 | else:
269 | adjust_learning_rate(epoch, opt, optimizer)
270 | print("==> training...")
271 |
272 | time1 = time.time()
273 |         train_acc, train_loss = train(epoch, train_loader, model_s, model_t, criterion_cls, criterion_div, criterion_kd, optimizer, opt)
274 | time2 = time.time()
275 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
276 |
277 | val_acc = 0
278 | val_loss = 0
279 | meta_val_acc = 0
280 | meta_val_std = 0
281 | # val_acc, val_acc_top5, val_loss = validate(val_loader, model_s, criterion_cls, opt)
282 |
283 |
284 | # #evaluate
285 | # start = time.time()
286 | # meta_val_acc, meta_val_std = meta_test(model_s, meta_valloader)
287 | # test_time = time.time() - start
288 | # print('Meta Val Acc: {:.4f}, Meta Val std: {:.4f}, Time: {:.1f}'.format(meta_val_acc, meta_val_std, test_time))
289 |
290 | #evaluate
291 |
292 | start = time.time()
293 | meta_test_acc, meta_test_std = meta_test(model_s, meta_testloader, use_logit=False)
294 | test_time = time.time() - start
295 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.format(meta_test_acc, meta_test_std, test_time))
296 |
297 |
298 | # regular saving
299 | if epoch % opt.save_freq == 0 or epoch==opt.epochs:
300 | print('==> Saving...')
301 | state = {
302 | 'epoch': epoch,
303 | 'model': model_s.state_dict(),
304 | }
305 | save_file = os.path.join(opt.save_folder, 'model_'+str(wandb.run.name)+'.pth')
306 | torch.save(state, save_file)
307 |
308 | #wandb saving
309 | torch.save(state, os.path.join(wandb.run.dir, "model.pth"))
310 |
311 | wandb.log({'epoch': epoch,
312 | 'Train Acc': train_acc,
313 | 'Train Loss':train_loss,
314 | 'Val Acc': val_acc,
315 | 'Val Loss':val_loss,
316 | 'Meta Test Acc': meta_test_acc,
317 | 'Meta Test std': meta_test_std,
318 | 'Meta Val Acc': meta_val_acc,
319 | 'Meta Val std': meta_val_std
320 | })
321 |
322 | #final report
323 | generate_final_report(model_s, opt, wandb)
324 |
325 |     # remove the wandb output.log file
326 | output_log_file = os.path.join(wandb.run.dir, "output.log")
327 | if os.path.isfile(output_log_file):
328 | os.remove(output_log_file)
329 | else: ## Show an error ##
330 | print("Error: %s file not found" % output_log_file)
331 |
332 |
333 |
334 |
335 | def train(epoch, train_loader, model_s, model_t , criterion_cls, criterion_div, criterion_kd, optimizer, opt):
336 | """One epoch training"""
337 | model_s.train()
338 | for m in model_t:
339 | m.eval()
340 |
341 | batch_time = AverageMeter()
342 | data_time = AverageMeter()
343 | losses = AverageMeter()
344 | top1 = AverageMeter()
345 | top5 = AverageMeter()
346 |
347 | end = time.time()
348 |
349 | with tqdm(train_loader, total=len(train_loader)) as pbar:
350 | for idx, data in enumerate(pbar):
351 |
352 | inputs, targets, _ = data
353 | data_time.update(time.time() - end)
354 |
355 | inputs = inputs.float()
356 | if torch.cuda.is_available():
357 | inputs = inputs.cuda()
358 | targets = targets.cuda()
359 |
360 | batch_size = inputs.size()[0]
361 | x = inputs
362 |
363 |             # generate the three rotated copies of the batch (90, 180, 270 degrees)
364 |             x_90 = x.transpose(2,3).flip(2)
365 |             x_180 = x.flip(2).flip(3)
366 |             x_270 = x.flip(2).transpose(2,3)
367 |
368 |             # stack in rotation order so that logit_s_all[k*batch_size:(k+1)*batch_size]
369 |             # holds the logits for the k-th rotation of the batch
370 |             inputs_all = torch.cat((x, x_90, x_180, x_270), 0)
371 |
372 | # ===================forward=====================
373 |
374 | with torch.no_grad():
375 | (_,_,_,_, feat_t), (logit_t, rot_t) = model_t[0](inputs_all[:batch_size], rot=True)
376 |
377 | (_,_,_,_, feat_s_all), (logit_s_all, rot_s_all) = model_s(inputs_all[:4*batch_size], rot=True)
378 |
379 | loss_div = criterion_div(logit_s_all[:batch_size], logit_t[:batch_size])
380 |
381 | d_90 = logit_s_all[batch_size:2*batch_size] - logit_s_all[:batch_size]
382 | loss_a = torch.mean(torch.sqrt(torch.sum((d_90)**2, dim=1)))
383 | # d_180 = logit_s_all[2*batch_size:3*batch_size] - logit_s_all[:batch_size]
384 | # loss_a += torch.mean(torch.sqrt(torch.sum((d_180)**2, dim=1)))
385 | # d_270 = logit_s_all[3*batch_size:4*batch_size] - logit_s_all[:batch_size]
386 | # loss_a += torch.mean(torch.sqrt(torch.sum((d_270)**2, dim=1)))
387 |
388 |
389 |             if torch.isnan(loss_a).any():
390 |                 print('Warning: NaN in rotation-consistency loss; stopping the epoch early')
391 |                 break
392 |             loss = loss_div + opt.gamma * loss_a / 3
393 |
394 |
395 | acc1, acc5 = accuracy(logit_s_all[:batch_size], targets, topk=(1, 5))
396 | losses.update(loss.item(), inputs.size(0))
397 | top1.update(acc1[0], inputs.size(0))
398 | top5.update(acc5[0], inputs.size(0))
399 |
400 | # ===================backward=====================
401 | optimizer.zero_grad()
402 | loss.backward()
403 | optimizer.step()
404 |
405 | # ===================meters=====================
406 | batch_time.update(time.time() - end)
407 | end = time.time()
408 |
409 |             pbar.set_postfix({"Acc@1": '{0:.2f}'.format(top1.avg.cpu().numpy()),
410 |                               "Acc@5": '{0:.2f}'.format(top5.avg.cpu().numpy()),
411 |                               "Loss": '{0:.2f}'.format(losses.avg),
412 |                               })
413 |
414 |
415 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
416 | .format(top1=top1, top5=top5))
417 |
418 | return top1.avg, losses.avg
419 |
420 |
421 | def generate_final_report(model, opt, wandb):
422 |
423 |
424 | opt.n_shots = 1
425 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
426 |
427 | #validate
428 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True)
429 |
430 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False)
431 |
432 | #evaluate
433 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True)
434 |
435 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False)
436 |
437 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std))
438 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat))
439 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std))
440 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat))
441 |
442 |
443 | wandb.log({'Final Meta Test Acc @1': meta_test_acc,
444 | 'Final Meta Test std @1': meta_test_std,
445 | 'Final Meta Test Acc (feat) @1': meta_test_acc_feat,
446 | 'Final Meta Test std (feat) @1': meta_test_std_feat,
447 | 'Final Meta Val Acc @1': meta_val_acc,
448 | 'Final Meta Val std @1': meta_val_std,
449 | 'Final Meta Val Acc (feat) @1': meta_val_acc_feat,
450 | 'Final Meta Val std (feat) @1': meta_val_std_feat
451 | })
452 |
453 |
454 | opt.n_shots = 5
455 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
456 |
457 | #validate
458 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True)
459 |
460 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False)
461 |
462 | #evaluate
463 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True)
464 |
465 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False)
466 |
467 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std))
468 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat))
469 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std))
470 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat))
471 |
472 | wandb.log({'Final Meta Test Acc @5': meta_test_acc,
473 | 'Final Meta Test std @5': meta_test_std,
474 | 'Final Meta Test Acc (feat) @5': meta_test_acc_feat,
475 | 'Final Meta Test std (feat) @5': meta_test_std_feat,
476 | 'Final Meta Val Acc @5': meta_val_acc,
477 | 'Final Meta Val std @5': meta_val_std,
478 | 'Final Meta Val Acc (feat) @5': meta_val_acc_feat,
479 | 'Final Meta Val std (feat) @5': meta_val_std_feat
480 | })
481 |
482 |
483 | if __name__ == '__main__':
484 | main()
485 |
--------------------------------------------------------------------------------
/train_selfsupervison.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import socket
6 | import time
7 | import sys
8 | from tqdm import tqdm
9 | import mkl
10 |
11 | # import tensorboard_logger as tb_logger
12 | import torch
13 | import torch.optim as optim
14 | import torch.nn as nn
15 | import torch.backends.cudnn as cudnn
16 | from torch.utils.data import DataLoader
17 | import torch.nn.functional as F
18 | from torch.autograd import Variable
19 |
20 | from models import model_pool
21 | from models.util import create_model
22 |
23 | from dataset.mini_imagenet import ImageNet, MetaImageNet
24 | from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet
25 | from dataset.cifar import CIFAR100, MetaCIFAR100
26 | from dataset.transform_cfg import transforms_options, transforms_test_options, transforms_list
27 |
28 | from util import adjust_learning_rate, accuracy, AverageMeter
29 | from eval.meta_eval import meta_test, meta_test_tune
30 | from eval.cls_eval import validate
31 |
32 | from models.resnet import resnet12
33 | import numpy as np
34 | from util import Logger
35 | import wandb
36 | from dataloader import get_dataloaders
37 |
38 | def get_freer_gpu():
39 |     os.system('nvidia-smi -q -d Memory | grep -A4 GPU | grep Free > tmp')
40 |     with open('tmp', 'r') as f:
41 |         return np.argmax([int(x.split()[2]) for x in f.readlines()])
42 |
43 | os.environ["CUDA_VISIBLE_DEVICES"]=str(get_freer_gpu())
44 | mkl.set_num_threads(2)
45 |
46 |
47 | def parse_option():
48 |
49 | parser = argparse.ArgumentParser('argument for training')
50 |
51 | parser.add_argument('--eval_freq', type=int, default=10, help='meta-eval frequency')
52 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
53 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
54 | parser.add_argument('--save_freq', type=int, default=10, help='save frequency')
55 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
56 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
57 | parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
58 |
59 | # optimization
60 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
61 | parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr, can be a list')
62 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
63 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
64 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
65 | parser.add_argument('--adam', action='store_true', help='use adam optimizer')
66 | parser.add_argument('--simclr', type=bool, default=False, help='use simple contrastive learning representation')
67 | parser.add_argument('--ssl', type=bool, default=True, help='use self supervised learning')
68 | parser.add_argument('--tags', type=str, default="gen0, ssl", help='add tags for the experiment')
69 |
70 |
71 | # dataset
72 | parser.add_argument('--model', type=str, default='resnet12', choices=model_pool)
73 | parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'tieredImageNet',
74 | 'CIFAR-FS', 'FC100'])
75 | parser.add_argument('--transform', type=str, default='A', choices=transforms_list)
76 | parser.add_argument('--use_trainval', type=bool, help='use trainval set')
77 |
78 | # cosine annealing
79 | parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
80 |
81 | # specify folder
82 | parser.add_argument('--model_path', type=str, default='save/', help='path to save model')
83 | parser.add_argument('--tb_path', type=str, default='tb/', help='path to tensorboard')
84 | parser.add_argument('--data_root', type=str, default='/raid/data/IncrementLearn/imagenet/Datasets/MiniImagenet/', help='path to data root')
85 |
86 | # meta setting
87 | parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
88 | help='Number of test runs')
89 | parser.add_argument('--n_ways', type=int, default=5, metavar='N',
90 | help='Number of classes for doing each classification run')
91 | parser.add_argument('--n_shots', type=int, default=1, metavar='N',
92 | help='Number of shots in test')
93 | parser.add_argument('--n_queries', type=int, default=15, metavar='N',
94 |                         help='Number of queries in test')
95 | parser.add_argument('--n_aug_support_samples', default=5, type=int,
96 | help='The number of augmented samples for each meta test sample')
97 | parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
98 |                         help='Size of test batch')
99 |
100 | parser.add_argument('-t', '--trial', type=str, default='1', help='the experiment id')
101 |
102 |
103 |
104 | #hyper parameters
105 |     parser.add_argument('--gamma', type=float, default=2, help='loss coefficient for the ssl loss')
106 |
107 | opt = parser.parse_args()
108 |
109 | if opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
110 | opt.transform = 'D'
111 |
112 | if opt.use_trainval:
113 | opt.trial = opt.trial + '_trainval'
114 |
115 | # set the path according to the environment
116 | if not opt.model_path:
117 | opt.model_path = './models_pretrained'
118 | if not opt.tb_path:
119 | opt.tb_path = './tensorboard'
120 | if not opt.data_root:
121 | opt.data_root = './data/{}'.format(opt.dataset)
122 | else:
123 | opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
124 | opt.data_aug = True
125 |
126 | iterations = opt.lr_decay_epochs.split(',')
127 | opt.lr_decay_epochs = list([])
128 | for it in iterations:
129 | opt.lr_decay_epochs.append(int(it))
130 |
131 | tags = opt.tags.split(',')
132 | opt.tags = list([])
133 | for it in tags:
134 | opt.tags.append(it)
135 |
136 | opt.model_name = '{}_{}_lr_{}_decay_{}_trans_{}'.format(opt.model, opt.dataset, opt.learning_rate,
137 | opt.weight_decay, opt.transform)
138 |
139 | if opt.cosine:
140 | opt.model_name = '{}_cosine'.format(opt.model_name)
141 |
142 | if opt.adam:
143 | opt.model_name = '{}_useAdam'.format(opt.model_name)
144 |
145 | opt.model_name = '{}_trial_{}'.format(opt.model_name, opt.trial)
146 |
147 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
148 | if not os.path.isdir(opt.tb_folder):
149 | os.makedirs(opt.tb_folder)
150 |
151 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
152 | if not os.path.isdir(opt.save_folder):
153 | os.makedirs(opt.save_folder)
154 |
155 | opt.n_gpu = torch.cuda.device_count()
156 |
157 |
158 | #extras
159 | opt.fresh_start = True
160 |
161 |
162 | return opt
163 |
164 |
165 | def main():
166 |
167 | opt = parse_option()
168 | wandb.init(project=opt.model_path.split("/")[-1], tags=opt.tags)
169 | wandb.config.update(opt)
170 | wandb.save('*.py')
171 | wandb.run.save()
172 |
173 |
174 | train_loader, val_loader, meta_testloader, meta_valloader, n_cls = get_dataloaders(opt)
175 |
176 | # model
177 | model = create_model(opt.model, n_cls, opt.dataset)
178 | wandb.watch(model)
179 |
180 | # optimizer
181 | if opt.adam:
182 | print("Adam")
183 | optimizer = torch.optim.Adam(model.parameters(),
184 | lr=opt.learning_rate,
185 | weight_decay=0.0005)
186 | else:
187 | print("SGD")
188 | optimizer = optim.SGD(model.parameters(),
189 | lr=opt.learning_rate,
190 | momentum=opt.momentum,
191 | weight_decay=opt.weight_decay)
192 |
193 |
194 |
195 | criterion = nn.CrossEntropyLoss()
196 |
197 | if torch.cuda.is_available():
198 | if opt.n_gpu > 1:
199 | model = nn.DataParallel(model)
200 | model = model.cuda()
201 | criterion = criterion.cuda()
202 | cudnn.benchmark = True
203 |
204 | # set cosine annealing scheduler
205 | if opt.cosine:
206 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
207 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.epochs, eta_min, -1)
208 |
209 | # routine: supervised pre-training
210 | for epoch in range(1, opt.epochs + 1):
211 | if opt.cosine:
212 | scheduler.step()
213 | else:
214 | adjust_learning_rate(epoch, opt, optimizer)
215 | print("==> training...")
216 |
217 |
218 | time1 = time.time()
219 | train_acc, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
220 | time2 = time.time()
221 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
222 |
223 |
224 | val_acc, val_acc_top5, val_loss = 0,0,0 #validate(val_loader, model, criterion, opt)
225 |
226 |
227 | #validate
228 | start = time.time()
229 | meta_val_acc, meta_val_std = 0,0#meta_test(model, meta_valloader)
230 | test_time = time.time() - start
231 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}, Time: {:.1f}'.format(meta_val_acc, meta_val_std, test_time))
232 |
233 | #evaluate
234 | start = time.time()
235 | meta_test_acc, meta_test_std = 0,0#meta_test(model, meta_testloader)
236 | test_time = time.time() - start
237 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}, Time: {:.1f}'.format(meta_test_acc, meta_test_std, test_time))
238 |
239 |
240 | # regular saving
241 | if epoch % opt.save_freq == 0 or epoch==opt.epochs:
242 | print('==> Saving...')
243 | state = {
244 | 'epoch': epoch,
245 | 'optimizer': optimizer.state_dict(),
246 | 'model': model.state_dict(),
247 | }
248 | save_file = os.path.join(opt.save_folder, 'model_'+str(wandb.run.name)+'.pth')
249 | torch.save(state, save_file)
250 |
251 | #wandb saving
252 | torch.save(state, os.path.join(wandb.run.dir, "model.pth"))
253 |
254 | ## onnx saving
255 | #dummy_input = torch.autograd.Variable(torch.randn(1, 3, 32, 32)).cuda()
256 | #torch.onnx.export(model, dummy_input, os.path.join(wandb.run.dir, "model.onnx"))
257 |
258 | wandb.log({'epoch': epoch,
259 | 'Train Acc': train_acc,
260 | 'Train Loss':train_loss,
261 | 'Val Acc': val_acc,
262 | 'Val Loss':val_loss,
263 | 'Meta Test Acc': meta_test_acc,
264 | 'Meta Test std': meta_test_std,
265 | 'Meta Val Acc': meta_val_acc,
266 | 'Meta Val std': meta_val_std
267 | })
268 |
269 | #final report
270 | generate_final_report(model, opt, wandb)
271 |
272 |     # remove the wandb output.log file
273 | output_log_file = os.path.join(wandb.run.dir, "output.log")
274 | if os.path.isfile(output_log_file):
275 | os.remove(output_log_file)
276 | else: ## Show an error ##
277 | print("Error: %s file not found" % output_log_file)
278 |
279 |
280 |
281 | def train(epoch, train_loader, model, criterion, optimizer, opt):
282 | """One epoch training"""
283 | model.train()
284 |
285 | batch_time = AverageMeter()
286 | data_time = AverageMeter()
287 | losses = AverageMeter()
288 | top1 = AverageMeter()
289 | top5 = AverageMeter()
290 |
291 | end = time.time()
292 | with tqdm(train_loader, total=len(train_loader)) as pbar:
293 | for idx, (input, target, _) in enumerate(pbar):
294 | data_time.update(time.time() - end)
295 |
296 | input = input.float()
297 | if torch.cuda.is_available():
298 | input = input.cuda()
299 | target = target.cuda()
300 |
301 |
302 | batch_size = input.size()[0]
303 | x = input
304 | x_90 = x.transpose(2,3).flip(2)
305 | x_180 = x.flip(2).flip(3)
306 | x_270 = x.flip(2).transpose(2,3)
307 | generated_data = torch.cat((x, x_90, x_180, x_270),0)
308 | train_targets = target.repeat(4)
309 |
310 |             # rotation labels: 0 -> 0deg, 1 -> 90deg, 2 -> 180deg, 3 -> 270deg,
311 |             # matching the concatenation order of generated_data above
312 |             rot_labels = torch.arange(4, device=input.device).repeat_interleave(batch_size)
320 |
321 | # ===================forward=====================
322 |
323 | (_,_,_,_, feat), (train_logit, rot_logits) = model(generated_data, rot=True)
324 |
325 | rot_labels = F.one_hot(rot_labels.to(torch.int64), 4).float()
326 |             loss_ss = F.binary_cross_entropy_with_logits(input=rot_logits, target=rot_labels)
327 | loss_ce = criterion(train_logit, train_targets)
328 |
329 | loss = opt.gamma * loss_ss + loss_ce
330 |
331 | acc1, acc5 = accuracy(train_logit, train_targets, topk=(1, 5))
332 | losses.update(loss.item(), input.size(0))
333 | top1.update(acc1[0], input.size(0))
334 | top5.update(acc5[0], input.size(0))
335 |
336 | # ===================backward=====================
337 | optimizer.zero_grad()
338 | loss.backward()
339 | optimizer.step()
340 |
341 | # ===================meters=====================
342 | batch_time.update(time.time() - end)
343 | end = time.time()
344 |
345 |
346 |             pbar.set_postfix({"Acc@1": '{0:.2f}'.format(top1.avg.cpu().numpy()),
347 |                               "Acc@5": '{0:.2f}'.format(top5.avg.cpu().numpy()),
348 |                               "Loss": '{0:.2f}'.format(losses.avg),
349 |                               })
350 |
351 | print('Train_Acc@1 {top1.avg:.3f} Train_Acc@5 {top5.avg:.3f}'
352 | .format(top1=top1, top5=top5))
353 |
354 | return top1.avg, losses.avg
355 |
356 |
357 |
358 | def generate_final_report(model, opt, wandb):
359 |
360 |
361 | opt.n_shots = 1
362 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
363 |
364 | #validate
365 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True)
366 |
367 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False)
368 |
369 | #evaluate
370 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True)
371 |
372 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False)
373 |
374 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std))
375 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat))
376 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std))
377 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat))
378 |
379 |
380 | wandb.log({'Final Meta Test Acc @1': meta_test_acc,
381 | 'Final Meta Test std @1': meta_test_std,
382 | 'Final Meta Test Acc (feat) @1': meta_test_acc_feat,
383 | 'Final Meta Test std (feat) @1': meta_test_std_feat,
384 | 'Final Meta Val Acc @1': meta_val_acc,
385 | 'Final Meta Val std @1': meta_val_std,
386 | 'Final Meta Val Acc (feat) @1': meta_val_acc_feat,
387 | 'Final Meta Val std (feat) @1': meta_val_std_feat
388 | })
389 |
390 |
391 | opt.n_shots = 5
392 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
393 |
394 | #validate
395 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader, use_logit=True)
396 |
397 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False)
398 |
399 | #evaluate
400 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader, use_logit=True)
401 |
402 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False)
403 |
404 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std))
405 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat))
406 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std))
407 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat))
408 |
409 | wandb.log({'Final Meta Test Acc @5': meta_test_acc,
410 | 'Final Meta Test std @5': meta_test_std,
411 | 'Final Meta Test Acc (feat) @5': meta_test_acc_feat,
412 | 'Final Meta Test std (feat) @5': meta_test_std_feat,
413 | 'Final Meta Val Acc @5': meta_val_acc,
414 | 'Final Meta Val std @5': meta_val_std,
415 | 'Final Meta Val Acc (feat) @5': meta_val_acc_feat,
416 | 'Final Meta Val std (feat) @5': meta_val_std_feat
417 | })
418 |
419 |
420 | if __name__ == '__main__':
421 | main()
422 |
--------------------------------------------------------------------------------
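The training loop above builds its rotation self-supervision batch with `transpose`/`flip` tricks and a per-sample labeling loop. A minimal NumPy sketch (not part of the repo; `rotate_batch` and `rotation_labels` are hypothetical names) of the same transformations, assuming an NCHW batch layout:

```python
import numpy as np

def rotate_batch(x):
    """Concatenate an NCHW batch with its 90/180/270-degree rotations,
    mirroring x.transpose(2,3).flip(2), x.flip(2).flip(3), x.flip(2).transpose(2,3)."""
    x_90 = np.flip(x.transpose(0, 1, 3, 2), axis=2)    # rotate 90
    x_180 = np.flip(np.flip(x, axis=2), axis=3)        # rotate 180
    x_270 = np.flip(x, axis=2).transpose(0, 1, 3, 2)   # rotate 270
    return np.concatenate([x, x_90, x_180, x_270], axis=0)

def rotation_labels(batch_size):
    """Vectorized equivalent of the per-sample loop: one label block per rotation."""
    return np.repeat(np.arange(4), batch_size)
```

In PyTorch the label loop could likewise be collapsed to `torch.arange(4).repeat_interleave(batch_size)`, which produces the same `[0...0, 1...1, 2...2, 3...3]` ordering as the concatenation.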
/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import os
7 | import sys
8 | from dataloader import get_dataloaders
9 |
10 |
11 | class LabelSmoothing(nn.Module):
12 | """
13 | NLL loss with label smoothing.
14 | """
15 | def __init__(self, smoothing=0.0):
16 | """
17 | Constructor for the LabelSmoothing module.
18 | :param smoothing: label smoothing factor
19 | """
20 | super(LabelSmoothing, self).__init__()
21 | self.confidence = 1.0 - smoothing
22 | self.smoothing = smoothing
23 |
24 | def forward(self, x, target):
25 | logprobs = torch.nn.functional.log_softmax(x, dim=-1)
26 |
27 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
28 | nll_loss = nll_loss.squeeze(1)
29 | smooth_loss = -logprobs.mean(dim=-1)
30 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
31 | return loss.mean()
32 |
33 |
34 | class BCEWithLogitsLoss(nn.Module):
35 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None, num_classes=64):
36 | super(BCEWithLogitsLoss, self).__init__()
37 | self.num_classes = num_classes
38 | self.criterion = nn.BCEWithLogitsLoss(weight=weight,
39 | size_average=size_average,
40 | reduce=reduce,
41 | reduction=reduction,
42 | pos_weight=pos_weight)
43 | def forward(self, input, target):
44 |         target_onehot = torch.nn.functional.one_hot(target, num_classes=self.num_classes).float()  # BCEWithLogitsLoss expects float targets
45 | return self.criterion(input, target_onehot)
46 |
47 |
48 | class AverageMeter(object):
49 | """Computes and stores the average and current value"""
50 | def __init__(self):
51 | self.reset()
52 |
53 | def reset(self):
54 | self.val = 0
55 | self.avg = 0
56 | self.sum = 0
57 | self.count = 0
58 |
59 | def update(self, val, n=1):
60 | self.val = val
61 | self.sum += val * n
62 | self.count += n
63 | self.avg = self.sum / self.count
64 |
65 |
66 | def adjust_learning_rate(epoch, opt, optimizer):
67 | """Sets the learning rate to the initial LR decayed by decay rate every steep step"""
68 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
69 | if steps > 0:
70 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
71 | for param_group in optimizer.param_groups:
72 | param_group['lr'] = new_lr
73 |
74 |
75 | def accuracy(output, target, topk=(1,)):
76 | """Computes the accuracy over the k top predictions for the specified values of k"""
77 | with torch.no_grad():
78 | maxk = max(topk)
79 | batch_size = target.size(0)
80 |
81 | _, pred = output.topk(maxk, 1, True, True)
82 | pred = pred.t()
83 | correct = pred.eq(target.view(1, -1).expand_as(pred))
84 |
85 | res = []
86 | for k in topk:
87 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
88 | res.append(correct_k.mul_(100.0 / batch_size))
89 | return res
90 |
91 |
92 |
93 |
94 | class Logger(object):
95 | '''Save training process to log file with simple plot function.'''
96 | def __init__(self, fpath, title=None, resume=False):
97 | self.file = None
98 | self.resume = resume
99 |         self.title = '' if title is None else title
100 | if fpath is not None:
101 | if resume:
102 | self.file = open(fpath, 'r')
103 | name = self.file.readline()
104 | self.names = name.rstrip().split('\t')
105 | self.numbers = {}
106 | for _, name in enumerate(self.names):
107 | self.numbers[name] = []
108 |
109 | for numbers in self.file:
110 | numbers = numbers.rstrip().split('\t')
111 | for i in range(0, len(numbers)):
112 | self.numbers[self.names[i]].append(numbers[i])
113 | self.file.close()
114 | self.file = open(fpath, 'a')
115 | else:
116 | self.file = open(fpath, 'w')
117 |
118 | def set_names(self, names):
119 | if self.resume:
120 |             # names and numbers were already read back from the existing log file
121 |             return
122 | self.numbers = {}
123 | self.names = names
124 | for _, name in enumerate(self.names):
125 | self.file.write(name)
126 | self.file.write('\t')
127 | self.numbers[name] = []
128 | self.file.write('\n')
129 | self.file.flush()
130 |
131 |
132 | def append(self, numbers):
133 | assert len(self.names) == len(numbers), 'Numbers do not match names'
134 | for index, num in enumerate(numbers):
135 | self.file.write("{0:.6f}".format(num))
136 | self.file.write('\t')
137 | self.numbers[self.names[index]].append(num)
138 | self.file.write('\n')
139 | self.file.flush()
140 |
141 | def plot(self, names=None):
142 |         names = self.names if names is None else names
143 | numbers = self.numbers
144 | for _, name in enumerate(names):
145 | x = np.arange(len(numbers[name]))
146 | plt.plot(x, np.asarray(numbers[name]))
147 | plt.legend([self.title + '(' + name + ')' for name in names])
148 | plt.grid(True)
149 |
150 |
151 | def close(self):
152 | if self.file is not None:
153 | self.file.close()
154 |
155 |
156 |
157 |
158 |
159 |
160 | def generate_final_report(model, opt, wandb):
161 | from eval.meta_eval import meta_test
162 |
163 | opt.n_shots = 1
164 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
165 |
166 | #validate
167 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader)
168 |
169 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False)
170 |
171 | #evaluate
172 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader)
173 |
174 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False)
175 |
176 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std))
177 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat))
178 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std))
179 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat))
180 |
181 |
182 | wandb.log({'Final Meta Test Acc @1': meta_test_acc,
183 | 'Final Meta Test std @1': meta_test_std,
184 | 'Final Meta Test Acc (feat) @1': meta_test_acc_feat,
185 | 'Final Meta Test std (feat) @1': meta_test_std_feat,
186 | 'Final Meta Val Acc @1': meta_val_acc,
187 | 'Final Meta Val std @1': meta_val_std,
188 | 'Final Meta Val Acc (feat) @1': meta_val_acc_feat,
189 | 'Final Meta Val std (feat) @1': meta_val_std_feat
190 | })
191 |
192 |
193 | opt.n_shots = 5
194 | train_loader, val_loader, meta_testloader, meta_valloader, _ = get_dataloaders(opt)
195 |
196 | #validate
197 | meta_val_acc, meta_val_std = meta_test(model, meta_valloader)
198 |
199 | meta_val_acc_feat, meta_val_std_feat = meta_test(model, meta_valloader, use_logit=False)
200 |
201 | #evaluate
202 | meta_test_acc, meta_test_std = meta_test(model, meta_testloader)
203 |
204 | meta_test_acc_feat, meta_test_std_feat = meta_test(model, meta_testloader, use_logit=False)
205 |
206 | print('Meta Val Acc : {:.4f}, Meta Val std: {:.4f}'.format(meta_val_acc, meta_val_std))
207 | print('Meta Val Acc (feat): {:.4f}, Meta Val std (feat): {:.4f}'.format(meta_val_acc_feat, meta_val_std_feat))
208 | print('Meta Test Acc: {:.4f}, Meta Test std: {:.4f}'.format(meta_test_acc, meta_test_std))
209 | print('Meta Test Acc (feat): {:.4f}, Meta Test std (feat): {:.4f}'.format(meta_test_acc_feat, meta_test_std_feat))
210 |
211 | wandb.log({'Final Meta Test Acc @5': meta_test_acc,
212 | 'Final Meta Test std @5': meta_test_std,
213 | 'Final Meta Test Acc (feat) @5': meta_test_acc_feat,
214 | 'Final Meta Test std (feat) @5': meta_test_std_feat,
215 | 'Final Meta Val Acc @5': meta_val_acc,
216 | 'Final Meta Val std @5': meta_val_std,
217 | 'Final Meta Val Acc (feat) @5': meta_val_acc_feat,
218 | 'Final Meta Val std (feat) @5': meta_val_std_feat
219 | })
--------------------------------------------------------------------------------
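The `LabelSmoothing` module in `/util.py` blends the usual NLL term with a uniform penalty over all classes. A minimal NumPy sketch of the same computation (not part of the repo; `label_smoothing_loss` is a hypothetical name), useful for checking the module against a reference:

```python
import numpy as np

def label_smoothing_loss(logits, targets, smoothing=0.1):
    """Label-smoothed NLL: (1 - s) * nll + s * mean negative log-prob."""
    # numerically stable log-softmax
    z = logits - logits.max(axis=-1, keepdims=True)
    logprobs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    nll = -logprobs[np.arange(len(targets)), targets]   # gather true-class log-probs
    smooth = -logprobs.mean(axis=-1)                    # uniform term over classes
    return ((1.0 - smoothing) * nll + smoothing * smooth).mean()
```

With `smoothing=0` this reduces to plain cross-entropy, and with `smoothing=1` to the mean negative log-probability over all classes, matching the `confidence`/`smoothing` split in the module above.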
/utils/figs/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/main.png
--------------------------------------------------------------------------------
/utils/figs/results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/results1.png
--------------------------------------------------------------------------------
/utils/figs/results2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/results2.png
--------------------------------------------------------------------------------
/utils/figs/training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brjathu/SKD/0a28dc54b1ed648a82270cbdc22cd15887cdd3e2/utils/figs/training.png
--------------------------------------------------------------------------------