├── losses │ ├── __init__.py │ └── losses.py ├── networks │ ├── sub_network │ │ ├── __init__.py │ │ └── resnet_layer.py │ ├── memory_bank.py │ ├── wrn_big.py │ ├── vgg_big.py │ ├── resnet_big.py │ └── efficient_big.py ├── utils │ ├── __init__.py │ ├── imagenet100.txt │ ├── imagenet.py │ ├── tinyimagenet.py │ └── util.py ├── requirements.txt ├── scripts │ ├── 1stage_train.sh │ ├── supcon_represent.sh │ ├── selfcon_represent.sh │ └── selfcon_represent_imagenet.sh ├── LICENSE ├── .gitignore ├── README.md ├── main_represent.py ├── main_linear.py └── main_ce.py /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * -------------------------------------------------------------------------------- /networks/sub_network/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_layer import * -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from .imagenet import * 3 | from .tinyimagenet import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.5 2 | torchvision==0.9.1 3 | torch==1.8.1 4 | apex==0.9.10dev 5 | tensorboard_logger==0.1.0 6 | git+https://github.com/ildoonet/pytorch-randaugment 7 | opencv-python -------------------------------------------------------------------------------- /scripts/1stage_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | seed="0" 4 | data="cifar100" 5 | bsz="1024" 6 | method="ce" 7 | model="resnet18" 8 | lr="0.8" 9 | 10 | python main_ce.py \ 11 | --seed $seed \ 12 | --dataset $data \ 13 | --batch_size $bsz \ 14 | --method $method \ 15 | --model $model \ 16 | --learning_rate $lr \ 17 | --epochs 500 \ 18 | --cosine 19 | -------------------------------------------------------------------------------- /networks/sub_network/resnet_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = ['resnet_sub_layer', 'wrn_sub_layer'] 5 | 6 | 7 | def resnet_sub_layer(block, in_planes, planes, num_blocks, stride): 8 | strides = [stride] + [1] * (num_blocks - 1) 9 | layers = [] 10 | for i in range(num_blocks): 11 | stride = strides[i] 12 | layers.append(block(in_planes, planes, stride)) 13 | in_planes = planes * block.expansion 14 | return nn.Sequential(*layers) 15 | 16 | def wrn_sub_layer(block, in_planes, planes, num_blocks, dropout_rate, stride): 17 | strides = [stride] + [1]*(int(num_blocks)-1) 18 | layers = [] 19 | 20 | for stride in strides: 21 | layers.append(block(in_planes, planes, dropout_rate, stride)) 22 | in_planes = planes 23 | 24 | return nn.Sequential(*layers) 25 | -------------------------------------------------------------------------------- /scripts/supcon_represent.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | seed="0" 4 | data="cifar100" 5 | method="SupCon" 6 | model="resnet18" 7 | bsz="1024" 8 | lr="0.5" 9 | label="True" 10 | multiview="True" 11 | 12 | python main_represent.py \ 13 | --seed $seed \ 14 | --method $method \ 15 | --dataset $data \ 16 | --model $model \ 17 | --batch_size $bsz \ 18 | --learning_rate $lr \ 19 | --temp 0.1 \ 20 | --epochs 1000 \ 21 
| --multiview \ 22 | --cosine \ 23 | --precision 24 | 25 | python main_linear.py --batch_size 512 \ 26 | --dataset $data \ 27 | --model $model \ 28 | --learning_rate 3 \ 29 | --weight_decay 0 \ 30 | --epochs 100 \ 31 | --lr_decay_epochs '60,80' \ 32 | --lr_decay_rate 0.1 \ 33 | --ckpt ./save/representation/${method}/${data}_models/${method}_${data}_${model}_lr_${lr}_multiview_${multiview}_label_${label}_decay_0.0001_bsz_${bsz}_temp_0.1_seed_${seed}_cosine_warm/last.pth 34 | -------------------------------------------------------------------------------- /networks/memory_bank.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MemoryBank(nn.Module): 6 | def __init__(self, dim, K, n_cls): 7 | super(MemoryBank, self).__init__() 8 | 9 | self.K = K 10 | 11 | self.register_buffer("queue", torch.randn(dim, K)) 12 | self.register_buffer("q_label", torch.randint(n_cls, (1, K))) 13 | self.queue = nn.functional.normalize(self.queue, dim=0) 14 | 15 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 16 | 17 | @torch.no_grad() 18 | def _dequeue_and_enqueue(self, keys, labels): 19 | batch_size = keys.shape[0] 20 | 21 | ptr = int(self.queue_ptr) 22 | assert self.K % batch_size == 0 # for simplicity 23 | 24 | self.queue[:, ptr:ptr + batch_size] = keys.T 25 | self.q_label[:, ptr:ptr + batch_size] = labels.unsqueeze(1).T 26 | ptr = (ptr + batch_size) % self.K # move pointer 27 | 28 | self.queue_ptr[0] = ptr 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 SangminBae 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /scripts/selfcon_represent.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | seed="0" 4 | method="SelfCon" 5 | data="cifar100" 6 | model="resnet18" 7 | arch="resnet" 8 | size="fc" 9 | pos="[False,True,False]" 10 | bsz="1024" 11 | lr="0.5" 12 | label="True" 13 | multiview="False" 14 | 15 | python main_represent.py --exp_name "${arch}_${size}_${pos}" \ 16 | --seed $seed \ 17 | --method $method \ 18 | --dataset $data \ 19 | --model $model \ 20 | --selfcon_pos $pos \ 21 | --selfcon_arch $arch \ 22 | --selfcon_size $size \ 23 | --batch_size $bsz \ 24 | --learning_rate $lr \ 25 | --temp 0.1 \ 26 | --epochs 1000 \ 27 | --cosine \ 28 | --precision 29 | 30 | python main_linear.py --batch_size 512 \ 31 | --dataset $data \ 32 | --model $model \ 33 | --learning_rate 3 \ 34 | --weight_decay 0 \ 35 | --selfcon_pos $pos \ 36 | --selfcon_arch $arch \ 37 | --selfcon_size $size \ 38 | --epochs 100 \ 39 | --lr_decay_epochs '60,80' \ 40 | --lr_decay_rate 0.1 \ 41 | --subnet \ 42 | --ckpt ./save/representation/${method}/${data}_models/${method}_${data}_${model}_lr_${lr}_multiview_${multiview}_label_${label}_decay_0.0001_bsz_${bsz}_temp_0.1_seed_${seed}_cosine_warm_${arch}_${size}_${pos}/last.pth 43 | 44 | -------------------------------------------------------------------------------- /scripts/selfcon_represent_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | seed="0" 4 | method="SelfCon" 5 | data="imagenet" 6 | model="resnet34" 7 | arch="resnet" 8 | size="same" 9 | pos="[False,True,False]" 10 | bsz="2048" 11 | lr="0.25" 12 | label="True" 13 | multiview="False" 14 | 15 | python main_represent.py --exp_name "${arch}_${size}_${pos}" \ 16 | --seed $seed \ 17 | --method $method \ 18 | --dataset $data \ 19 | --data_folder './data/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/' \ 20 | --model $model \ 21 | --selfcon_pos $pos \ 22 | --selfcon_arch $arch \ 23 | --selfcon_size $size \ 24 | --batch_size $bsz \ 25 | --learning_rate $lr \ 26 | --temp 0.1 \ 27 | --epochs 800 \ 28 | --cosine \ 29 | --precision 30 | 31 | python main_linear.py --batch_size 512 \ 32 | --dataset $data \ 33 | --data_folder './data/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/' \ 34 | --model $model \ 35 | --learning_rate 5 \ 36 | --weight_decay 0 \ 37 | --selfcon_pos $pos \ 38 | --selfcon_arch $arch \ 39 | --selfcon_size $size \ 40 | --epochs 40 \ 41 | --lr_decay_epochs '20,30' \ 42 | --lr_decay_rate 0.1 \ 43 | --subnet \ 44 | --ckpt ./save/representation/${method}/${data}_models/${method}_${data}_${model}_lr_${lr}_multiview_${multiview}_label_${label}_decay_0.0001_bsz_${bsz}_temp_0.1_seed_${seed}_cosine_warm_${arch}_${size}_${pos}/last.pth 45 | 46 | -------------------------------------------------------------------------------- /utils/imagenet100.txt: -------------------------------------------------------------------------------- 1 | n02869837 2 | n01749939 3 | n02488291 4 | n02107142 5 | n13037406 6 | n02091831 7 | n04517823 8 | n04589890 9 | n03062245 10 | n01773797 11 | n01735189 12 | n07831146 13 | n07753275 14 | n03085013 15 | n04485082 16 | n02105505 17 | n01983481 18 | n02788148 19 | n03530642 20 | n04435653 21 | n02086910 22 | n02859443 23 | n13040303 24 | n03594734 25 | n02085620 26 | n02099849 27 | n01558993 28 | n04493381 29 | n02109047 30 | n04111531 31 | n02877765 32 | n04429376 33 | n02009229 34 | n01978455 35 | n02106550 36 | n01820546 37 
| n01692333 38 | n07714571 39 | n02974003 40 | n02114855 41 | n03785016 42 | n03764736 43 | n03775546 44 | n02087046 45 | n07836838 46 | n04099969 47 | n04592741 48 | n03891251 49 | n02701002 50 | n03379051 51 | n02259212 52 | n07715103 53 | n03947888 54 | n04026417 55 | n02326432 56 | n03637318 57 | n01980166 58 | n02113799 59 | n02086240 60 | n03903868 61 | n02483362 62 | n04127249 63 | n02089973 64 | n03017168 65 | n02093428 66 | n02804414 67 | n02396427 68 | n04418357 69 | n02172182 70 | n01729322 71 | n02113978 72 | n03787032 73 | n02089867 74 | n02119022 75 | n03777754 76 | n04238763 77 | n02231487 78 | n03032252 79 | n02138441 80 | n02104029 81 | n03837869 82 | n03494278 83 | n04136333 84 | n03794056 85 | n03492542 86 | n02018207 87 | n04067472 88 | n03930630 89 | n03584829 90 | n02123045 91 | n04229816 92 | n02100583 93 | n03642806 94 | n04336792 95 | n03259280 96 | n02116738 97 | n02108089 98 | n03424325 99 | n01855672 100 | n02090622 -------------------------------------------------------------------------------- /utils/imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wouter Van Gansbeke, Simon Vandenhende 3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 4 | """ 5 | import os 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | from PIL import Image 10 | from torchvision import transforms as tf 11 | from glob import glob 12 | 13 | 14 | class ImageNetSubset(data.Dataset): 15 | def __init__(self, subset_file, root='', split='train', 16 | transform=None): 17 | super(ImageNetSubset, self).__init__() 18 | 19 | self.root = root 20 | self.transform = transform 21 | self.split = split 22 | 23 | # Read the subset of classes to include (sorted) 24 | with open(subset_file, 'r') as f: 25 | result = f.read().splitlines() 26 | subdirs = [] 27 | for line in result: 28 | subdirs.append(line) 29 | 30 | # Gather the files (sorted) 31 | imgs = [] 32 | for i, subdir in enumerate(subdirs): 33 | subdir_path = os.path.join(self.root, subdir) 34 | files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG'))) 35 | for f in files: 36 | imgs.append((f, i)) 37 | self.imgs = imgs 38 | 39 | # Resize 40 | self.resize = tf.Resize(256) 41 | 42 | def get_image(self, index): 43 | path, target = self.imgs[index] 44 | with open(path, 'rb') as f: 45 | img = Image.open(f).convert('RGB') 46 | img = self.resize(img) 47 | return img 48 | 49 | def __len__(self): 50 | return len(self.imgs) 51 | 52 | def __getitem__(self, index): 53 | path, target = self.imgs[index] 54 | with open(path, 'rb') as f: 55 | img = Image.open(f).convert('RGB') 56 | im_size = img.size 57 | img = self.resize(img) 58 | 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | 62 | return img, target 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python 
script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # folder 132 | results/ 133 | save/ 134 | run/ 135 | 136 | # model checkpoint 137 | *.pth 138 | output/ 139 | log/ -------------------------------------------------------------------------------- /utils/tinyimagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | 6 | EXTENSION = 'JPEG' 7 | NUM_IMAGES_PER_CLASS = 500 8 | CLASS_LIST_FILE = 'wnids.txt' 9 | VAL_ANNOTATION_FILE = 'val_annotations.txt' 10 | 11 | __all__ = ['TinyImageNet'] 12 | 13 | 14 | class TinyImageNet(Dataset): 15 | """Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`. 16 | Parameters 17 | ---------- 18 | root: string 19 | Root directory including `train`, `test` and `val` subdirectories. 20 | split: string 21 | Indicating which split to return as a data set. 22 | Valid option: [`train`, `test`, `val`] 23 | transform: torchvision.transforms 24 | A (series) of valid transformation(s). 25 | in_memory: bool 26 | Set to True if there is enough memory (about 5G) and want to minimize disk IO overhead. 
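    (Note: despite the `split` description above, the constructor below actually takes a boolean
    `train` flag and maps it to the 'train' or 'val' subdirectory internally; the `download`
    argument is accepted but currently unused.)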
27 | """ 28 | def __init__(self, root, train=True, transform=None, target_transform=None, in_memory=False, download=False): 29 | self.root = os.path.expanduser(root) 30 | self.train = train 31 | self.split = 'train' if train else 'val' 32 | self.transform = transform 33 | self.target_transform = target_transform 34 | self.in_memory = in_memory 35 | self.split_dir = os.path.join(root, self.split) 36 | self.image_paths = sorted(glob.iglob(os.path.join(self.split_dir, '**', '*.%s' % EXTENSION), recursive=True)) 37 | self.labels = {} # fname - label number mapping 38 | self.images = [] # used for in-memory processing 39 | 40 | # build class label - number mapping 41 | with open(os.path.join(self.root, CLASS_LIST_FILE), 'r') as fp: 42 | self.label_texts = sorted([text.strip() for text in fp.readlines()]) 43 | self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)} 44 | 45 | if self.split == 'train': 46 | for label_text, i in self.label_text_to_number.items(): 47 | for cnt in range(NUM_IMAGES_PER_CLASS): 48 | self.labels['%s_%d.%s' % (label_text, cnt, EXTENSION)] = i 49 | elif self.split == 'val': 50 | with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), 'r') as fp: 51 | for line in fp.readlines(): 52 | terms = line.split('\t') 53 | file_name, label_text = terms[0], terms[1] 54 | self.labels[file_name] = self.label_text_to_number[label_text] 55 | 56 | # read all images into torch tensor in memory to minimize disk IO overhead 57 | if self.in_memory: 58 | self.images = [self.read_image(path) for path in self.image_paths] 59 | 60 | def __len__(self): 61 | return len(self.image_paths) 62 | 63 | def __getitem__(self, index): 64 | file_path = self.image_paths[index] 65 | 66 | if self.in_memory: 67 | img = self.images[index] 68 | else: 69 | img = self.read_image(file_path) 70 | 71 | if self.split == 'test': 72 | return img 73 | else: 74 | # file_name = file_path.split('/')[-1] 75 | return img, self.labels[os.path.basename(file_path)] 76 | 77 | def __repr__(self): 78 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 79 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 80 | tmp = self.split 81 | fmt_str += ' Split: {}\n'.format(tmp) 82 | fmt_str += ' Root Location: {}\n'.format(self.root) 83 | tmp = ' Transforms (if any): ' 84 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 85 | tmp = ' Target Transforms (if any): ' 86 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 87 | return fmt_str 88 | 89 | def read_image(self, path): 90 | img = Image.open(path).convert('RGB') 91 | return self.transform(img) if self.transform else img -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import json 5 | import pickle 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | 11 | __all__ = ['AverageMeter', 'TwoCropTransform', 'adjust_learning_rate', 'warmup_learning_rate', 'accuracy', 'class_accuracy', 'set_optimizer', 'save_model', 'update_json', 'update_json_list'] 12 | 13 | 14 | class TwoCropTransform: 15 | """Create two crops of the same image""" 16 | def __init__(self, transform): 17 | self.transform = transform 18 | 19 | def __call__(self, x): 20 | return [self.transform(x), self.transform(x)] 21 | 22 | 23 | class 
AverageMeter(object): 24 | """Computes and stores the average and current value""" 25 | def __init__(self): 26 | self.reset() 27 | 28 | def reset(self): 29 | self.val = 0 30 | self.avg = 0 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def update(self, val, n=1): 35 | self.val = val 36 | self.sum += val * n 37 | self.count += n 38 | self.avg = self.sum / self.count 39 | 40 | 41 | def accuracy(output, target, topk=(1,)): 42 | """Computes the accuracy over the k top predictions for the specified values of k""" 43 | with torch.no_grad(): 44 | maxk = max(topk) 45 | batch_size = target.size(0) 46 | 47 | _, pred = output.topk(maxk, 1, True, True) 48 | pred = pred.t() 49 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 50 | 51 | res = [] 52 | for k in topk: 53 | #correct_k = correct[:k].reshape(-1, k).float().sum(1).sum(0, keepdim=True) 54 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 55 | res.append(correct_k.mul_(100.0 / batch_size)) 56 | return res 57 | 58 | 59 | def class_accuracy(output, target, cls, topk=(1,)): 60 | """Computes the accuracy over the k top predictions for the specified values of k""" 61 | with torch.no_grad(): 62 | maxk = max(topk) 63 | 64 | output = output[target == cls] 65 | target = target[target == cls] 66 | 67 | batch_size = target.size(0) 68 | 69 | _, pred = output.topk(maxk, 1, True, True) 70 | pred = pred.t() 71 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 72 | 73 | res = [] 74 | for k in topk: 75 | #correct_k = correct[:k].reshape(-1, k).float().sum(1).sum(0, keepdim=True) 76 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 77 | res.append(correct_k.mul_(100.0 / batch_size)) 78 | return res, batch_size 79 | 80 | 81 | def adjust_learning_rate(args, optimizer, epoch): 82 | lr = args.learning_rate 83 | if args.cosine: 84 | eta_min = lr * (args.lr_decay_rate ** 3) 85 | lr = eta_min + (lr - eta_min) * ( 86 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2 87 | else: 88 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs)) 89 | if steps > 0: 90 | lr = lr * (args.lr_decay_rate ** steps) 91 | for param_group in optimizer.param_groups: 92 | param_group['lr'] = lr 93 | 94 | 95 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer): 96 | if args.warm and epoch <= args.warm_epochs: 97 | p = (batch_id + (epoch - 1) * total_batches) / \ 98 | (args.warm_epochs * total_batches) 99 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from) 100 | 101 | for param_group in optimizer.param_groups: 102 | param_group['lr'] = lr 103 | 104 | 105 | def set_optimizer(opt, model, optimizer='sgd'): 106 | if optimizer == 'sgd': 107 | optimizer = optim.SGD(model.parameters(), 108 | lr=opt.learning_rate, 109 | momentum=opt.momentum, 110 | weight_decay=opt.weight_decay) 111 | elif optimizer == 'adam': 112 | optimizer = optim.Adam(model.parameters(), 113 | lr=opt.learning_rate) 114 | return optimizer 115 | 116 | 117 | def save_model(model, optimizer, opt, epoch, save_file): 118 | print('==> Saving...') 119 | state = { 120 | 'opt': opt, 121 | 'model': model.state_dict(), 122 | 'optimizer': optimizer.state_dict(), 123 | 'epoch': epoch, 124 | } 125 | torch.save(state, save_file) 126 | del state 127 | 128 | 129 | def update_json(exp_name, acc={}, path='./save/results.json'): 130 | for k, v in acc.items(): 131 | acc[k] = [round(a, 2) for a in v] 132 | if not os.path.exists(path): 133 | with open(path, 'w') as f: 134 | json.dump({}, f) 135 | 136 | with open(path, 'r', encoding="UTF-8") as f: 137 | result_dict = 
json.load(f) 138 | result_dict[exp_name] = acc 139 | 140 | with open(path, 'w') as f: 141 | json.dump(result_dict, f) 142 | 143 | print('best accuracy: {}'.format(acc)) 144 | print('results updated to %s' % path) 145 | 146 | 147 | def update_json_list(exp_name, acc=[0., 0.], path='./save/results.json'): 148 | acc = [round(a, 2) for a in acc] 149 | if not os.path.exists(path): 150 | with open(path, 'w') as f: 151 | json.dump({}, f) 152 | 153 | with open(path, 'r', encoding="UTF-8") as f: 154 | result_dict = json.load(f) 155 | result_dict[exp_name] = acc 156 | 157 | with open(path, 'w') as f: 158 | json.dump(result_dict, f) 159 | 160 | print('best accuracy: {}'.format(acc)) 161 | print('results updated to %s' % path) -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | refer to 3 | 1) Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf 4 | 2) SimCLR: https://arxiv.org/pdf/2002.05709.pdf 5 | """ 6 | from __future__ import print_function 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | 14 | eps = 1e-7 15 | 16 | 17 | class ConLoss(nn.Module): 18 | """Self-Contrastive Learning: https://arxiv.org/abs/2106.15499.""" 19 | def __init__(self, temperature=0.07, contrast_mode='all', base_temperature=0.07): 20 | super(ConLoss, self).__init__() 21 | self.temperature = temperature 22 | self.contrast_mode = contrast_mode 23 | self.base_temperature = base_temperature 24 | 25 | def forward(self, features, labels=None, mask=None, supcon_s=False, selfcon_s_FG=False, selfcon_m_FG=False): 26 | """ 27 | Args: 28 | features: hidden vector of shape [bsz, n_views, ...]. 29 | labels: ground truth of shape [bsz]. 30 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 31 | has the same class as sample i. Can be asymmetric. 32 | supcon_s: boolean for using single-viewed batch. 33 | selfcon_s_FG: exclude contrastive loss when the anchor is from F (backbone) and the pairs are from G (sub-network). 34 | selfcon_m_FG: exclude contrastive loss when the anchor is from F (backbone) and the pairs are from G (sub-network). 35 | Returns: 36 | A loss scalar. 
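        Note: `selfcon_s_FG` applies to a single-viewed batch, while `selfcon_m_FG` is its
        multi-viewed counterpart (`features` then holds the two views concatenated along the
        batch dimension).
        Example (a minimal sketch; `f1`, `f2` are two projected views of shape [bsz, d]):
            >>> features = torch.stack([f1, f2], dim=1)   # [bsz, 2, d]
            >>> criterion = ConLoss(temperature=0.1)
            >>> loss = criterion(features, labels=labels)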
37 | """ 38 | device = features.device 39 | 40 | if len(features.shape) < 3: 41 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 42 | 'at least 3 dimensions are required') 43 | if len(features.shape) > 3: 44 | features = features.view(features.shape[0], features.shape[1], -1) 45 | 46 | batch_size = features.shape[0] if not selfcon_m_FG else int(features.shape[0]/2) 47 | 48 | if labels is not None and mask is not None: 49 | raise ValueError('Cannot define both `labels` and `mask`') 50 | elif labels is None and mask is None: 51 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 52 | elif labels is not None: 53 | labels = labels.contiguous().view(-1, 1) 54 | if labels.shape[0] != batch_size: 55 | raise ValueError('Num of labels does not match num of features') 56 | mask = torch.eq(labels, labels.T).float().to(device) 57 | else: 58 | mask = mask.float().to(device) 59 | 60 | if not selfcon_s_FG and not selfcon_m_FG: 61 | contrast_count = features.shape[1] 62 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 63 | if self.contrast_mode == 'one': 64 | anchor_feature = features[:, 0] 65 | anchor_count = 1 66 | elif self.contrast_mode == 'all': 67 | anchor_feature = contrast_feature 68 | anchor_count = contrast_count 69 | else: 70 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 71 | elif selfcon_s_FG: 72 | contrast_count = features.shape[1] 73 | anchor_count = features.shape[1]-1 74 | 75 | anchor_feature, contrast_feature = torch.cat(torch.unbind(features, dim=1)[:-1], dim=0), torch.unbind(features, dim=1)[-1] 76 | contrast_feature = torch.cat([anchor_feature, contrast_feature], dim=0) 77 | elif selfcon_m_FG: 78 | contrast_count = int(features.shape[1] * 2) 79 | anchor_count = (features.shape[1]-1)*2 80 | 81 | anchor_feature, contrast_feature = torch.cat(torch.unbind(features, dim=1)[:-1], dim=0), torch.unbind(features, dim=1)[-1] 82 | contrast_feature = torch.cat([anchor_feature, contrast_feature], dim=0) 83 | 84 | # compute logits 85 | anchor_dot_contrast = torch.div( 86 | torch.matmul(anchor_feature, contrast_feature.T), 87 | self.temperature) 88 | 89 | # for numerical stability 90 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 91 | logits = anchor_dot_contrast - logits_max.detach() 92 | 93 | # tile mask 94 | mask = mask.repeat(anchor_count, contrast_count) 95 | 96 | # mask-out self-contrast cases 97 | logits_mask = torch.scatter( 98 | torch.ones_like(mask), 99 | 1, 100 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 101 | 0 102 | ) 103 | 104 | mask = mask * logits_mask 105 | if supcon_s: 106 | idx = mask.sum(1) != 0 107 | mask = mask[idx, :] 108 | logits_mask = logits_mask[idx, :] 109 | logits = logits[idx, :] 110 | batch_size = idx.sum() 111 | 112 | # compute log_prob 113 | exp_logits = torch.exp(logits) * logits_mask 114 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 115 | 116 | # compute mean of log-likelihood over positive 117 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 118 | 119 | # loss 120 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 121 | loss = loss.view(anchor_count, batch_size).mean() 122 | 123 | return loss 124 | 125 | 126 | class KLLoss(nn.Module): 127 | """Distilling the Knowledge in a Neural Network""" 128 | def __init__(self, T=3.0): 129 | super(KLLoss, self).__init__() 130 | self.T = T 131 | 132 | def forward(self, logit_s, logit_t): 133 | p_s = F.log_softmax(logit_s/self.T, dim=1) 134 | p_t = 
F.softmax(logit_t.clone().detach()/self.T, dim=1) 135 | loss = -pow(self.T, 2)*(p_s * p_t).sum(dim=1).mean() 136 | 137 | return loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Contrastive Learning: Single-viewed Supervised Contrastive Framework using Sub-network 2 | 3 |

4 | <!-- figure omitted from this text dump --> 5 |

6 | 7 | This repository contains the official PyTorch implementation of the following paper: 8 | 9 | > **Self-Contrastive Learning: Single-viewed Supervised Contrastive Framework using Sub-network** by 10 | > Sangmin Bae*, Sungnyun Kim*, Jongwoo Ko, Gihun Lee, Seungjong Noh, Se-Young Yun, [AAAI 2023](https://aaai.org/Conferences/AAAI-23/). 11 | > 12 | > **Paper**: https://arxiv.org/abs/2106.15499 13 | > **Video**: https://www.youtube.com/watch?v=VNv3LXzqX_4 14 | > 15 | > **Abstract:** *Contrastive loss has significantly improved performance in supervised classification tasks by using a multi-viewed framework that leverages augmentation and label information. The augmentation enables contrast with another view of a single image but enlarges training time and memory usage. To exploit the strength of multi-views while avoiding the high computation cost, we introduce a multi-exit architecture that outputs multiple features of a single image in a single-viewed framework. To this end, we propose Self-Contrastive (SelfCon) learning, which self-contrasts within multiple outputs from the different levels of a single network. The multi-exit architecture efficiently replaces multi-augmented images and leverages various information from different layers of a network. We demonstrate that SelfCon learning improves the classification performance of the encoder network, and empirically analyze its advantages in terms of the single-view and the sub-network. Furthermore, we provide theoretical evidence of the performance increase based on the mutual information bound. For ImageNet classification on ResNet-50, SelfCon improves accuracy by +0.6% with 59% memory and 48% time of Supervised Contrastive learning, and a simple ensemble of multi-exit outputs boosts performance up to +1.5%.* 16 | 17 | ## Table of Contents 18 | 19 | * [Installation](#installation) 20 | * [Usage](#usage) 21 | * [Parameters for Pretraining](#parameters-for-pretraining) 22 | * [Experimental Results](#experimental-results) 23 | * [License](#license) 24 | * [Contact](#contact) 25 | 26 | ## Installation 27 | We ran our experiments on eight RTX 3090 GPUs with CUDA 11.3. 28 | Please check the requirements below and install the packages from `requirements.txt`. 29 | 30 | ```bash 31 | $ pip install --upgrade pip 32 | $ pip install -r requirements.txt 33 | ``` 34 | 35 | ## Usage 36 | To pretrain the SelfCon model, run `main_represent.py` as in the following example. 37 | 38 | ```bash 39 | # Pretraining on [Dataset: CIFAR-100, Architecture: ResNet-18] 40 | python main_represent.py --exp_name "resnet_fc_[False,True,False]" \ 41 | --seed 2022 \ 42 | --method SelfCon \ 43 | --dataset cifar100 \ 44 | --model resnet18 \ 45 | --selfcon_pos "[False,True,False]" \ 46 | --selfcon_arch "resnet" \ 47 | --selfcon_size "fc" \ 48 | --batch_size 1024 \ 49 | --learning_rate 0.5 \ 50 | --temp 0.1 \ 51 | --epochs 1000 \ 52 | --cosine \ 53 | --precision 54 | ``` 55 | 56 | For linear evaluation, run `main_linear.py` with an appropriate `${SAVE_CKPT}`. 57 | For the above example, `${SAVE_CKPT}` is `./save/representation/SelfCon/cifar100_models/SelfCon_cifar100_resnet18_lr_0.5_multiview_False_label_True_decay_0.0001_bsz_1024_temp_0.1_seed_2022_cosine_warm_resnet_fc_[False,True,False]/last.pth`.
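For instance, you can export it as a shell variable before launching the evaluation (a convenience sketch; substitute the directory of your own run):

```bash
SAVE_CKPT=./save/representation/SelfCon/cifar100_models/SelfCon_cifar100_resnet18_lr_0.5_multiview_False_label_True_decay_0.0001_bsz_1024_temp_0.1_seed_2022_cosine_warm_resnet_fc_[False,True,False]/last.pth
```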
58 | 59 | ```bash 60 | # Finetuning on [Dataset: CIFAR-100, Architecture: ResNet-18] 61 | python main_linear.py --batch_size 512 \ 62 | --dataset cifar100 \ 63 | --model resnet18 \ 64 | --learning_rate 3 \ 65 | --weight_decay 0 \ 66 | --selfcon_pos "[False,True,False]" \ 67 | --selfcon_arch "resnet" \ 68 | --selfcon_size "fc" \ 69 | --epochs 100 \ 70 | --lr_decay_epochs '60,80' \ 71 | --lr_decay_rate 0.1 \ 72 | --subnet \ 73 | --ckpt ${SAVE_CKPT} 74 | ``` 75 | 76 | Also, refer to `./scripts/` for SupCon pretraining and 1-stage training examples. 77 | For ImageNet experiments, change `--dataset` to `imagenet`, specify `--data_folder`, and set the hyperparameters as described in the paper. 78 | 79 | ### Parameters for Pretraining 80 | | Parameter | Description | 81 | | ----------------------------- | ---------------------------------------- | 82 | | `model` | The model architecture. Default: `resnet50`. | 83 | | `dataset` | Dataset to use. Options: `cifar10`, `cifar100`, `tinyimagenet`, `imagenet100`, `imagenet`. | 84 | | `method` | Pretraining method. Options: `Con`, `SupCon`, `SelfCon`. | 85 | | `lr` | Learning rate for pretraining. Default: `0.5` for a batch size of 1024. | 86 | | `temp` | Temperature of the contrastive loss function. Default: `0.07`. | 87 | | `precision` | Whether to use mixed precision. Default: `False`. | 88 | | `cosine` | Whether to use cosine annealing scheduling. Default: `False`. | 89 | | `selfcon_pos` | Positions at which to attach the sub-networks. Default: `[False,True,False]` for ResNet architectures. | 90 | | `selfcon_arch` | Sub-network architecture. Options: `resnet`, `vgg`, `efficientnet`, `wrn`. Default: `resnet`. | 91 | | `selfcon_size` | Number of blocks in a sub-network. Options: `fc`, `small`, `same`. Default: `same`. | 92 | | `multiview` | Whether to use a multi-viewed batch. Default: `False`. | 93 | | `label` | Whether to use label information in the contrastive loss. Default: `False`. | 94 | 95 | 96 | ### Experimental Results 97 | See our paper for more details and extensive analyses. 98 | Here are some of our main results. 99 | 100 |

101 | <!-- result figures omitted from this text dump --> 102 | 103 | 104 | 105 |
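As a quick sanity check of the multi-exit encoder interface used throughout these experiments, the following minimal sketch (assuming the repository root is on `PYTHONPATH`; module and argument names follow `networks/resnet_big.py` below) prints one sub-exit feature and the backbone feature:

```python
# Minimal smoke test of the SelfCon encoder interface (a sketch, not part of the
# training pipeline). With selfcon_pos=[False,True,False], exactly one sub-network
# exit is attached after the second residual stage.
import torch
from networks.resnet_big import ConResNet

model = ConResNet(name='resnet18', head='mlp', feat_dim=128,
                  selfcon_pos=[False, True, False],
                  selfcon_arch='resnet', selfcon_size='fc',
                  dataset='cifar100')
x = torch.randn(4, 3, 32, 32)   # CIFAR-sized dummy batch
sub_feats, feat = model(x)      # list of sub-exit features, plus the backbone feature
print(len(sub_feats), sub_feats[0].shape, feat.shape)
# -> 1 torch.Size([4, 128]) torch.Size([4, 128])
```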

106 | 107 | ## Citing This Work 108 | 109 | If you find this repo useful for your research, please consider citing our paper: 110 | ``` 111 | @article{bae2021self, 112 | title={Self-Contrastive Learning: Single-viewed Supervised Contrastive Framework using Sub-network}, 113 | author={Bae, Sangmin and Kim, Sungnyun and Ko, Jongwoo and Lee, Gihun and Noh, Seungjong and Yun, Se-Young}, 114 | journal={arXiv preprint arXiv:2106.15499}, 115 | year={2021} 116 | } 117 | ``` 118 | 119 | ## License 120 | Distributed under the MIT License. 121 | 122 | ## Contact 123 | * Sangmin Bae: bsmn0223@kaist.ac.kr 124 | * Sungnyun Kim: ksn4397@kaist.ac.kr 125 | -------------------------------------------------------------------------------- /networks/wrn_big.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | from .sub_network import * 8 | 9 | import sys 10 | import numpy as np 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 14 | 15 | def conv7x7(in_planes, out_planes, stride=2): 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, padding=3, bias=True) 17 | 18 | def conv_init(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('Conv') != -1: 21 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 22 | init.constant_(m.bias, 0) 23 | elif classname.find('BatchNorm') != -1: 24 | init.constant_(m.weight, 1) 25 | init.constant_(m.bias, 0) 26 | 27 | class wide_basic(nn.Module): 28 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 29 | super(wide_basic, self).__init__() 30 | self.bn1 = nn.BatchNorm2d(in_planes) 31 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 32 | self.dropout = nn.Dropout(p=dropout_rate) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 35 | 36 | self.shortcut = nn.Sequential() 37 | if stride != 1 or in_planes != planes: 38 | self.shortcut = nn.Sequential( 39 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 40 | ) 41 | 42 | def forward(self, x): 43 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 44 | out = self.conv2(F.relu(self.bn2(out))) 45 | out += self.shortcut(x) 46 | 47 | return out 48 | 49 | class Wide_ResNet(nn.Module): 50 | def __init__(self, depth, widen_factor, dropout_rate, selfcon_pos=[False,False], selfcon_arch='wrn', selfcon_size='same', dataset=''): 51 | super(Wide_ResNet, self).__init__() 52 | self.in_planes = 16 53 | 54 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 55 | n = (depth-4)/6 56 | k = widen_factor 57 | self.dropout_rate = dropout_rate 58 | self.num_blocks = n 59 | 60 | print('| Wide-Resnet %dx%d' %(depth, k)) 61 | nStages = [16, 16*k, 32*k, 64*k] 62 | self.nStages = nStages 63 | 64 | if dataset in ['imagenet', 'imagenet100']: 65 | self.conv1 = conv7x7(3,nStages[0]) 66 | else: 67 | self.conv1 = conv3x3(3,nStages[0]) 68 | if 'imagenet' in dataset: 69 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 70 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 71 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 72 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
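# one sub-network ("SelfCon") exit is built per True entry in selfcon_pos via
# _make_sub_layer below (positions with False get None in the ModuleList)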
73 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 74 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 75 | 76 | self.selfcon_pos = selfcon_pos 77 | self.selfcon_arch = selfcon_arch 78 | self.selfcon_size = selfcon_size 79 | self.selfcon_layer = nn.ModuleList([self._make_sub_layer(idx, pos) for idx, pos in enumerate(selfcon_pos)]) 80 | self.dataset = dataset 81 | 82 | for m in self.modules(): 83 | conv_init(m) 84 | 85 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 86 | strides = [stride] + [1]*(int(num_blocks)-1) 87 | layers = [] 88 | 89 | for stride in strides: 90 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 91 | self.in_planes = planes 92 | 93 | return nn.Sequential(*layers) 94 | 95 | def _make_sub_layer(self, idx, pos): 96 | channels = [128, 256, 512] 97 | strides = [1, 2, 2] 98 | num_blocks = [self.num_blocks]*3 99 | if self.selfcon_size == 'same': 100 | num_blocks = num_blocks 101 | elif self.selfcon_size == 'small': 102 | num_blocks = [int((n+1)/2) for n in num_blocks] 103 | elif self.selfcon_size == 'large': 104 | num_blocks = [int(n*2) for n in num_blocks] 105 | elif self.selfcon_size == 'fc': 106 | pass 107 | else: 108 | raise NotImplementedError 109 | 110 | if not pos: 111 | return None 112 | else: 113 | if self.selfcon_size == 'fc': 114 | return nn.Linear(channels[idx], channels[-1]) 115 | else: 116 | if self.selfcon_arch == 'resnet': 117 | raise NotImplementedError 118 | elif self.selfcon_arch == 'vgg': 119 | raise NotImplementedError 120 | elif self.selfcon_arch == 'efficientnet': 121 | raise NotImplementedError 122 | elif self.selfcon_arch == 'wrn': 123 | layers = [] 124 | for i in range(idx+1, 3): 125 | in_planes = channels[i-1] 126 | layers.append(wrn_sub_layer(wide_basic, in_planes, channels[i], num_blocks[i], self.dropout_rate, strides[i])) 127 | 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, x): 131 | sub_out = [] 132 | 133 | x = self.conv1(x) 134 | # maxpool -> last map before avgpool is 4x4 135 | if 'imagenet' in self.dataset: 136 | x = self.maxpool(x) 137 | 138 | x = self.layer1(x) 139 | if self.selfcon_layer[0]: 140 | if self.selfcon_size != 'fc': 141 | out = self.selfcon_layer[0](x) 142 | out = torch.flatten(self.avgpool(out), 1) 143 | else: 144 | out = torch.flatten(self.avgpool(x), 1) 145 | out = self.selfcon_layer[0](out) 146 | sub_out.append(out) 147 | 148 | x = self.layer2(x) 149 | if self.selfcon_layer[1]: 150 | if self.selfcon_size != 'fc': 151 | out = self.selfcon_layer[1](x) 152 | out = torch.flatten(self.avgpool(out), 1) 153 | else: 154 | out = torch.flatten(self.avgpool(x), 1) 155 | out = self.selfcon_layer[1](out) 156 | sub_out.append(out) 157 | 158 | x = self.layer3(x) 159 | x = F.relu(self.bn1(x)) 160 | x = self.avgpool(x) 161 | 162 | x = x.view(x.size(0), -1) 163 | x = torch.flatten(x, 1) 164 | # out = self.linear(out) 165 | 166 | return sub_out, x 167 | 168 | 169 | def wrn_16_8(**kwargs): 170 | return Wide_ResNet(16, 8, 0.3, **kwargs) 171 | 172 | model_dict = { 173 | 'wrn_16_8': [wrn_16_8, 512], 174 | } 175 | 176 | class ConWRN(nn.Module): 177 | """backbone + projection head""" 178 | def __init__(self, name='wrn_16_8', head='mlp', feat_dim=128, selfcon_pos=[False,False], selfcon_arch='wrn', selfcon_size='same', dataset=''): 179 | super(ConWRN, self).__init__() 180 | model_fun, dim_in = model_dict[name] 181 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset) 182 | if head == 'linear': 183 | self.head = nn.Linear(dim_in, feat_dim)
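# note: in this 'linear' branch the sub-heads are collected in a plain Python list
# (not nn.ModuleList), so they are not registered as sub-modules of ConWRN;
# the 'mlp' branch below wraps them in nn.ModuleList instead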
184 | 185 | self.sub_heads = [] 186 | for pos in selfcon_pos: 187 | if pos: 188 | self.sub_heads.append(nn.Linear(dim_in, feat_dim)) 189 | elif head == 'mlp': 190 | self.head = nn.Sequential( 191 | nn.Linear(dim_in, dim_in), 192 | nn.ReLU(inplace=True), 193 | nn.Linear(dim_in, feat_dim) 194 | ) 195 | 196 | heads = [] 197 | for pos in selfcon_pos: 198 | if pos: 199 | heads.append(nn.Sequential( 200 | nn.Linear(dim_in, dim_in), 201 | nn.ReLU(inplace=True), 202 | nn.Linear(dim_in, feat_dim) 203 | )) 204 | self.sub_heads = nn.ModuleList(heads) 205 | else: 206 | raise NotImplementedError( 207 | 'head not supported: {}'.format(head)) 208 | 209 | def forward(self, x): 210 | sub_feat, feat = self.encoder(x) 211 | 212 | sh_feat = [] 213 | for sf, sub_head in zip(sub_feat, self.sub_heads): 214 | sh_feat.append(F.normalize(sub_head(sf), dim=1)) 215 | 216 | feat = F.normalize(self.head(feat), dim=1) 217 | return sh_feat, feat 218 | 219 | 220 | class CEWRN(nn.Module): 221 | """encoder + classifier""" 222 | def __init__(self, name='wrn_16_8', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False,False], selfcon_arch='wrn', selfcon_size='same', dataset=''): 223 | super(CEWRN, self).__init__() 224 | self.method = method 225 | 226 | model_fun, dim_in = model_dict[name] 227 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset) 228 | 229 | logit_fcs, feat_fcs = [], [] 230 | for pos in selfcon_pos: 231 | if pos: 232 | logit_fcs.append(nn.Linear(dim_in, num_classes)) 233 | feat_fcs.append(nn.Linear(dim_in, dim_out)) 234 | 235 | self.logit_fc = nn.ModuleList(logit_fcs) 236 | self.l_fc = nn.Linear(dim_in, num_classes) 237 | 238 | if method not in ['ce', 'subnet_ce', 'kd']: 239 | self.feat_fc = nn.ModuleList(feat_fcs) 240 | self.f_fc = nn.Linear(dim_in, dim_out) 241 | for param in self.f_fc.parameters(): 242 | param.requires_grad = False 243 | 244 | def forward(self, x): 245 | sub_feat, feat = self.encoder(x) 246 | 247 | feats, logits = [], [] 248 | 249 | for idx, sh_feat in enumerate(sub_feat): 250 | logits.append(self.logit_fc[idx](sh_feat)) 251 | if self.method not in ['ce', 'subnet_ce', 'kd']: 252 | out = self.feat_fc[idx](sh_feat) 253 | feats.append(F.normalize(out, dim=1)) 254 | 255 | if self.method not in ['ce', 'subnet_ce', 'kd']: 256 | return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)] 257 | else: 258 | return [logits, self.l_fc(feat)] 259 | 260 | 261 | class LinearClassifier_WRN(nn.Module): 262 | """Linear classifier""" 263 | def __init__(self, name='wrn_16_8', num_classes=10): 264 | super(LinearClassifier_WRN, self).__init__() 265 | _, feat_dim = model_dict[name] 266 | self.fc = nn.Linear(feat_dim, num_classes) 267 | 268 | def forward(self, features): 269 | return self.fc(features) 270 | 271 | 272 | if __name__ == '__main__': 273 | net = Wide_ResNet(16, 8, 0.3) 274 | sub_out, y = net(torch.randn(1, 3, 32, 32)) 275 | 276 | print(y.size()) 277 | -------------------------------------------------------------------------------- /networks/vgg_big.py: -------------------------------------------------------------------------------- 1 | ''' 2 | VGG in PyTorch 3 | Adapted from: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from typing import Union, List, Dict, Any, cast 10 | 11 | 12 | class VGG(nn.Module): 13 | 14 | def __init__( 15 | self, 16 | features: nn.Module, 17 | cfg: str = 'D', 18 | 
arch: str = 'vgg16_bn', 19 | init_weights: bool = True, 20 | selfcon_pos: List[bool] = [False,False,False,False], 21 | selfcon_arch: str = 'vgg', 22 | selfcon_size: str = 'small', 23 | dataset: str = '' 24 | ) -> None: 25 | super(VGG, self).__init__() 26 | features_lst, modules_lst = [], [] 27 | for module in features.modules(): 28 | if isinstance(module, nn.Sequential): 29 | continue 30 | modules_lst.append(module) 31 | if isinstance(module, nn.MaxPool2d): 32 | features_lst.append(modules_lst) 33 | modules_lst = [] 34 | self.block1 = nn.Sequential(*features_lst[0]) 35 | self.block2 = nn.Sequential(*features_lst[1]) 36 | self.block3 = nn.Sequential(*features_lst[2]) 37 | self.block4 = nn.Sequential(*features_lst[3]) 38 | self.block5 = nn.Sequential(*features_lst[4]) 39 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 40 | 41 | self.arch = arch 42 | self.selfcon_pos = selfcon_pos 43 | self.selfcon_arch = selfcon_arch 44 | self.selfcon_size = selfcon_size 45 | self.dataset = dataset 46 | self.selfcon_layer = nn.ModuleList([self._make_sub_layer(idx, pos, cfg) for idx, pos in enumerate(selfcon_pos)]) 47 | 48 | if init_weights: 49 | self._initialize_weights() 50 | 51 | def forward(self, x): 52 | sub_out = [] 53 | 54 | x = self.block1(x) 55 | if self.selfcon_layer[0]: 56 | if self.selfcon_size != 'fc': 57 | out = self.selfcon_layer[0](x) 58 | out = torch.flatten(self.avgpool(out), 1) 59 | else: 60 | out = torch.flatten(self.avgpool(x), 1) 61 | out = self.selfcon_layer[0](out) 62 | sub_out.append(out) 63 | 64 | x = self.block2(x) 65 | if self.selfcon_layer[1]: 66 | if self.selfcon_size != 'fc': 67 | out = self.selfcon_layer[1](x) 68 | out = torch.flatten(self.avgpool(out), 1) 69 | else: 70 | out = torch.flatten(self.avgpool(x), 1) 71 | out = self.selfcon_layer[1](out) 72 | sub_out.append(out) 73 | 74 | x = self.block3(x) 75 | if self.selfcon_layer[2]: 76 | if self.selfcon_size != 'fc': 77 | out = self.selfcon_layer[2](x) 78 | out = torch.flatten(self.avgpool(out), 1) 79 | else: 80 | out = torch.flatten(self.avgpool(x), 1) 81 | out = self.selfcon_layer[2](out) 82 | sub_out.append(out) 83 | 84 | x = self.block4(x) 85 | if self.selfcon_layer[3]: 86 | if self.selfcon_size != 'fc': 87 | out = self.selfcon_layer[3](x) 88 | out = torch.flatten(self.avgpool(out), 1) 89 | else: 90 | out = torch.flatten(self.avgpool(x), 1) 91 | out = self.selfcon_layer[3](out) 92 | sub_out.append(out) 93 | 94 | x = self.block5(x) 95 | x = self.avgpool(x) 96 | x = torch.flatten(x, 1) 97 | return sub_out, x 98 | 99 | def _initialize_weights(self) -> None: 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 103 | if m.bias is not None: 104 | nn.init.constant_(m.bias, 0) 105 | elif isinstance(m, nn.BatchNorm2d): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | elif isinstance(m, nn.Linear): 109 | nn.init.normal_(m.weight, 0, 0.01) 110 | nn.init.constant_(m.bias, 0) 111 | 112 | def _make_sub_layer(self, idx, pos, cfg): 113 | channels = [64, 128, 256, 512, 512] 114 | 115 | if not pos: 116 | return None 117 | else: 118 | if self.selfcon_arch == 'resnet': 119 | raise NotImplementedError 120 | elif self.selfcon_arch == 'vgg': 121 | if self.selfcon_size == 'fc': 122 | layers = [nn.Linear(channels[idx], channels[-1])] 123 | else: 124 | layers = [] 125 | if self.selfcon_size == 'same': 126 | num_blocks = 3 if cfg == 'D' else 2 127 | elif self.selfcon_size == 'small': 128 | num_blocks = 1 129 | elif self.selfcon_size == 
'large': 130 | raise NotImplementedError 131 | 132 | for i in range(idx+1, 5): 133 | in_planes = channels[i-1] 134 | v = channels[i] 135 | for b in range(num_blocks): 136 | if self.arch.endswith('_bn'): 137 | layers += [nn.Conv2d(in_planes, v, kernel_size=3, padding=1), nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 138 | else: 139 | layers += [nn.Conv2d(in_planes, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)] 140 | in_planes = v 141 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 142 | else: 143 | raise NotImplementedError 144 | 145 | return nn.Sequential(*layers) 146 | 147 | 148 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 149 | layers: List[nn.Module] = [] 150 | in_channels = 3 151 | for v in cfg: 152 | if v == 'M': 153 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 154 | else: 155 | v = cast(int, v) 156 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 157 | if batch_norm: 158 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 159 | else: 160 | layers += [conv2d, nn.ReLU(inplace=True)] 161 | in_channels = v 162 | return nn.Sequential(*layers) 163 | 164 | 165 | cfgs: Dict[str, List[Union[str, int]]] = { 166 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 167 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 168 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 169 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 170 | } 171 | 172 | 173 | def _vgg(arch: str, cfg: str, batch_norm: bool, progress: bool, **kwargs: Any) -> VGG: 174 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), cfg=cfg, arch=arch, **kwargs) 175 | return model 176 | 177 | def vgg13(progress: bool = True, **kwargs: Any) -> VGG: 178 | return _vgg('vgg13', 'B', False, progress, **kwargs) 179 | 180 | def vgg13_bn(progress: bool = True, **kwargs: Any) -> VGG: 181 | return _vgg('vgg13_bn', 'B', True, progress, **kwargs) 182 | 183 | def vgg16(progress: bool = True, **kwargs: Any) -> VGG: 184 | return _vgg('vgg16', 'D', False, progress, **kwargs) 185 | 186 | def vgg16_bn(progress: bool = True, **kwargs: Any) -> VGG: 187 | return _vgg('vgg16_bn', 'D', True, progress, **kwargs) 188 | 189 | model_dict = {'vgg13': vgg13, 190 | 'vgg13_bn': vgg13_bn, 191 | 'vgg16': vgg16, 192 | 'vgg16_bn': vgg16_bn 193 | } 194 | 195 | class ConVGG(nn.Module): 196 | def __init__(self, name='vgg13_bn', head='mlp', feat_dim=128, selfcon_pos=[False,False,False,False], selfcon_arch='vgg', selfcon_size='same', dataset=''): 197 | super(ConVGG, self).__init__() 198 | model_fun = model_dict[name] 199 | dim_in = 512 200 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset) 201 | if head == 'linear': 202 | self.head = nn.Linear(dim_in, feat_dim) 203 | 204 | self.sub_heads = [] 205 | for pos in selfcon_pos: 206 | if pos: 207 | self.sub_heads.append(nn.Linear(dim_in, feat_dim)) 208 | elif head == 'mlp': 209 | self.head = nn.Sequential( 210 | nn.Linear(dim_in, dim_in), 211 | nn.ReLU(inplace=True), 212 | nn.Linear(dim_in, feat_dim) 213 | ) 214 | 215 | heads = [] 216 | for pos in selfcon_pos: 217 | if pos: 218 | heads.append(nn.Sequential( 219 | nn.Linear(dim_in, dim_in), 220 | nn.ReLU(inplace=True), 221 | nn.Linear(dim_in, feat_dim) 222 | )) 223 | self.sub_heads = nn.ModuleList(heads) 224 | else: 225 | raise NotImplementedError( 226 | 'head not supported: 
{}'.format(head)) 227 | 228 | def forward(self, x): 229 | sub_feat, feat = self.encoder(x) 230 | 231 | sh_feat = [] 232 | for sf, sub_head in zip(sub_feat, self.sub_heads): 233 | sh_feat.append(F.normalize(sub_head(sf), dim=1)) 234 | 235 | feat = F.normalize(self.head(feat), dim=1) 236 | return sh_feat, feat 237 | 238 | 239 | class CEVGG(nn.Module): 240 | """encoder + classifier""" 241 | def __init__(self, name='vgg13_bn', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False,False,False,False], selfcon_arch='vgg', selfcon_size='same', dataset=''): 242 | super(CEVGG, self).__init__() 243 | self.method = method 244 | 245 | model_fun = model_dict[name] 246 | dim_in = 512 247 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset) 248 | 249 | logit_fcs, feat_fcs = [], [] 250 | for pos in selfcon_pos: 251 | if pos: 252 | logit_fcs.append(nn.Sequential(nn.Linear(dim_in, dim_in), 253 | nn.ReLU(inplace=True), 254 | nn.Dropout(), 255 | nn.Linear(dim_in, num_classes) 256 | )) 257 | feat_fcs.append(nn.Linear(dim_in, dim_out)) 258 | 259 | self.logit_fc = nn.ModuleList(logit_fcs) 260 | self.l_fc = nn.Sequential(nn.Linear(dim_in, dim_in), 261 | nn.ReLU(inplace=True), 262 | nn.Dropout(), 263 | nn.Linear(dim_in, num_classes) 264 | ) 265 | 266 | if method not in ['ce', 'subnet_ce', 'kd']: 267 | self.feat_fc = nn.ModuleList(feat_fcs) 268 | self.f_fc = nn.Linear(dim_in, dim_out) 269 | for param in self.f_fc.parameters(): 270 | param.requires_grad = False 271 | 272 | def forward(self, x): 273 | sub_feat, feat = self.encoder(x) 274 | 275 | feats, logits = [], [] 276 | 277 | for idx, sh_feat in enumerate(sub_feat): 278 | logits.append(self.logit_fc[idx](sh_feat)) 279 | if self.method not in ['ce', 'subnet_ce', 'kd']: 280 | out = self.feat_fc[idx](sh_feat) 281 | feats.append(F.normalize(out, dim=1)) 282 | 283 | if self.method not in ['ce', 'subnet_ce', 'kd']: 284 | return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)] 285 | else: 286 | return [logits, self.l_fc(feat)] 287 | 288 | 289 | class LinearClassifier_VGG(nn.Module): 290 | """Linear classifier""" 291 | def __init__(self, name='vgg13_bn', num_classes=10): 292 | super(LinearClassifier_VGG, self).__init__() 293 | feat_dim = 512 294 | self.fc1 = nn.Linear(feat_dim, feat_dim) 295 | self.relu = nn.ReLU(inplace=True) 296 | self.dropout = nn.Dropout() 297 | self.fc2 = nn.Linear(feat_dim, num_classes) 298 | 299 | def forward(self, features): 300 | features = self.dropout(self.relu(self.fc1(features))) 301 | return self.fc2(features) 302 | -------------------------------------------------------------------------------- /networks/resnet_big.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | ImageNet-Style ResNet 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 5 | Adapted from: https://github.com/bearpaw/pytorch-classification 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .sub_network import * 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1, is_last=False): 18 | super(BasicBlock, self).__init__() 19 | self.is_last = is_last 20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion * planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion * planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out += self.shortcut(x) 36 | preact = out 37 | out = F.relu(out) 38 | if self.is_last: 39 | return out, preact 40 | else: 41 | return out 42 | 43 | 44 | class Bottleneck(nn.Module): 45 | expansion = 4 46 | 47 | def __init__(self, in_planes, planes, stride=1, is_last=False): 48 | super(Bottleneck, self).__init__() 49 | self.is_last = is_last 50 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn2 = nn.BatchNorm2d(planes) 54 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 55 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 56 | 57 | self.shortcut = nn.Sequential() 58 | if stride != 1 or in_planes != self.expansion * planes: 59 | self.shortcut = nn.Sequential( 60 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion * planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = F.relu(self.bn2(self.conv2(out))) 67 | out = self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | preact = out 70 | out = F.relu(out) 71 | if self.is_last: 72 | return out, preact 73 | else: 74 | return out 75 | 76 | 77 | class ResNet(nn.Module): 78 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False, selfcon_pos=[False,False,False], selfcon_arch='resnet', selfcon_size='same', dataset=''): 79 | super(ResNet, self).__init__() 80 | self.in_planes = 64 81 | self.block = block 82 | self.num_blocks = num_blocks 83 | self.in_channel = in_channel 84 | self.dataset = dataset 85 | 86 | self.large = False if dataset in ['cifar10', 'cifar100', 'tinyimagenet'] else True 87 | if not self.large: 88 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False) 89 | else: 90 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 91 | 92 | self.bn1 = nn.BatchNorm2d(64) 93 | if self.large: 94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 95 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 96 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 98 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 99 | self.avgpool = 
nn.AdaptiveAvgPool2d((1, 1)) 100 | 101 | self.selfcon_pos = selfcon_pos 102 | self.selfcon_arch = selfcon_arch 103 | self.selfcon_size = selfcon_size 104 | self.selfcon_layer = nn.ModuleList([self._make_sub_layer(idx, pos) for idx, pos in enumerate(selfcon_pos)]) 105 | 106 | for k, m in self.named_modules(): 107 | if isinstance(m, nn.Conv2d): 108 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 109 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 110 | nn.init.constant_(m.weight, 1) 111 | nn.init.constant_(m.bias, 0) 112 | 113 | # Zero-initialize the last BN in each residual branch, 114 | # so that the residual branch starts with zeros, and each residual block behaves 115 | # like an identity. This improves the model by 0.2~0.3% according to: 116 | # https://arxiv.org/abs/1706.02677 117 | if zero_init_residual: 118 | for m in self.modules(): 119 | if isinstance(m, Bottleneck): 120 | nn.init.constant_(m.bn3.weight, 0) 121 | elif isinstance(m, BasicBlock): 122 | nn.init.constant_(m.bn2.weight, 0) 123 | 124 | def _make_layer(self, block, planes, num_blocks, stride): 125 | strides = [stride] + [1] * (num_blocks - 1) 126 | layers = [] 127 | for i in range(num_blocks): 128 | stride = strides[i] 129 | layers.append(block(self.in_planes, planes, stride)) 130 | self.in_planes = planes * block.expansion 131 | return nn.Sequential(*layers) 132 | 133 | def _make_sub_layer(self, idx, pos): 134 | channels = [64, 128, 256, 512] 135 | strides = [1, 2, 2, 2] 136 | if self.selfcon_size == 'same': 137 | num_blocks = self.num_blocks 138 | elif self.selfcon_size == 'small': 139 | num_blocks = [int(n/2) for n in self.num_blocks] 140 | elif self.selfcon_size == 'large': 141 | num_blocks = [int(n*2) for n in self.num_blocks] 142 | elif self.selfcon_size == 'fc': 143 | pass 144 | else: 145 | raise NotImplementedError 146 | 147 | if not pos: 148 | return None 149 | else: 150 | if self.selfcon_size == 'fc': 151 | return nn.Linear(channels[idx] * self.block.expansion, channels[-1] * self.block.expansion) 152 | else: 153 | if self.selfcon_arch == 'resnet': 154 | # selfcon layers do not share any parameters with the backbone 155 | layers = [] 156 | for i in range(idx+1, 4): 157 | in_planes = channels[i-1] * self.block.expansion 158 | layers.append(resnet_sub_layer(self.block, in_planes, channels[i], num_blocks[i], strides[i])) 159 | elif self.selfcon_arch == 'vgg': 160 | raise NotImplementedError 161 | elif self.selfcon_arch == 'efficientnet': 162 | raise NotImplementedError 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | sub_out = [] 168 | 169 | x = F.relu(self.bn1(self.conv1(x))) 170 | if self.large: 171 | x = self.maxpool(x) 172 | 173 | x = self.layer1(x) 174 | if self.selfcon_layer[0]: 175 | if self.selfcon_size != 'fc': 176 | out = self.selfcon_layer[0](x) 177 | out = torch.flatten(self.avgpool(out), 1) 178 | else: 179 | out = torch.flatten(self.avgpool(x), 1) 180 | out = self.selfcon_layer[0](out) 181 | sub_out.append(out) 182 | 183 | x = self.layer2(x) 184 | if self.selfcon_layer[1]: 185 | if self.selfcon_size != 'fc': 186 | out = self.selfcon_layer[1](x) 187 | out = torch.flatten(self.avgpool(out), 1) 188 | else: 189 | out = torch.flatten(self.avgpool(x), 1) 190 | out = self.selfcon_layer[1](out) 191 | sub_out.append(out) 192 | 193 | x = self.layer3(x) 194 | if self.selfcon_layer[2]: 195 | if self.selfcon_size != 'fc': 196 | out = self.selfcon_layer[2](x) 197 | out = torch.flatten(self.avgpool(out), 1) 198 | else: 199 | out = torch.flatten(self.avgpool(x), 1) 200 | out = 
self.selfcon_layer[2](out) 201 | sub_out.append(out) 202 | 203 | out = self.layer4(x) 204 | out = self.avgpool(out) 205 | out = torch.flatten(out, 1) 206 | 207 | return sub_out, out 208 | 209 | 210 | def resnet18(**kwargs): 211 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 212 | 213 | 214 | def resnet34(**kwargs): 215 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 216 | 217 | 218 | def resnet50(**kwargs): 219 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 220 | 221 | 222 | def resnet101(**kwargs): 223 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 224 | 225 | 226 | model_dict = { 227 | 'resnet18': [resnet18, 512], 228 | 'resnet34': [resnet34, 512], 229 | 'resnet50': [resnet50, 2048], 230 | 'resnet101': [resnet101, 2048], 231 | } 232 | 233 | 234 | class LinearBatchNorm(nn.Module): 235 | """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose""" 236 | def __init__(self, dim, affine=True): 237 | super(LinearBatchNorm, self).__init__() 238 | self.dim = dim 239 | self.bn = nn.BatchNorm2d(dim, affine=affine) 240 | 241 | def forward(self, x): 242 | x = x.view(-1, self.dim, 1, 1) 243 | x = self.bn(x) 244 | x = x.view(-1, self.dim) 245 | return x 246 | 247 | 248 | class ConResNet(nn.Module): 249 | """backbone + projection head""" 250 | def __init__(self, name='resnet50', head='mlp', feat_dim=128, selfcon_pos=[False,False,False], selfcon_arch='resnet', selfcon_size='same', dataset=''): 251 | super(ConResNet, self).__init__() 252 | model_fun, dim_in = model_dict[name] 253 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset) 254 | if head == 'linear': 255 | self.head = nn.Linear(dim_in, feat_dim) 256 | # wrap in ModuleList so the sub-head parameters are registered (and moved by .cuda()) 257 | sub_heads = [] 258 | for pos in selfcon_pos: 259 | if pos: sub_heads.append(nn.Linear(dim_in, feat_dim)) 260 | self.sub_heads = nn.ModuleList(sub_heads) 261 | elif head == 'mlp': 262 | self.head = nn.Sequential( 263 | nn.Linear(dim_in, dim_in), 264 | nn.ReLU(inplace=True), 265 | nn.Linear(dim_in, feat_dim) 266 | ) 267 | 268 | heads = [] 269 | for pos in selfcon_pos: 270 | if pos: 271 | heads.append(nn.Sequential( 272 | nn.Linear(dim_in, dim_in), 273 | nn.ReLU(inplace=True), 274 | nn.Linear(dim_in, feat_dim) 275 | )) 276 | self.sub_heads = nn.ModuleList(heads) 277 | else: 278 | raise NotImplementedError( 279 | 'head not supported: {}'.format(head)) 280 | 281 | def forward(self, x): 282 | sub_feat, feat = self.encoder(x) 283 | 284 | sh_feat = [] 285 | for sf, sub_head in zip(sub_feat, self.sub_heads): 286 | sh_feat.append(F.normalize(sub_head(sf), dim=1)) 287 | 288 | feat = F.normalize(self.head(feat), dim=1) 289 | return sh_feat, feat 290 | 291 | 292 | class CEResNet(nn.Module): 293 | """encoder + classifier""" 294 | def __init__(self, name='resnet50', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False,False,False], selfcon_arch='resnet', selfcon_size='same', dataset=''): 295 | super(CEResNet, self).__init__() 296 | self.method = method 297 | 298 | model_fun, dim_in = model_dict[name] 299 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset) 300 | 301 | logit_fcs, feat_fcs = [], [] 302 | for pos in selfcon_pos: 303 | if pos: 304 | logit_fcs.append(nn.Linear(dim_in, num_classes)) 305 | feat_fcs.append(nn.Linear(dim_in, dim_out)) 306 | 307 | self.logit_fc = nn.ModuleList(logit_fcs) 308 | self.l_fc = nn.Linear(dim_in, num_classes) 309 | 310 | if method not in ['ce', 'subnet_ce', 'kd']: 311 | self.feat_fc = nn.ModuleList(feat_fcs) 312 | self.f_fc = 
nn.Linear(dim_in, dim_out) 313 | for param in self.f_fc.parameters(): 314 | param.requires_grad = False 315 | 316 | def forward(self, x): 317 | sub_feat, feat = self.encoder(x) 318 | 319 | feats, logits = [], [] 320 | 321 | for idx, sh_feat in enumerate(sub_feat): 322 | logits.append(self.logit_fc[idx](sh_feat)) 323 | if self.method not in ['ce', 'subnet_ce', 'kd']: 324 | out = self.feat_fc[idx](sh_feat) 325 | feats.append(F.normalize(out, dim=1)) 326 | 327 | if self.method not in ['ce', 'subnet_ce', 'kd']: 328 | return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)] 329 | else: 330 | return [logits, self.l_fc(feat)] 331 | 332 | 333 | class LinearClassifier(nn.Module): 334 | """Linear classifier""" 335 | def __init__(self, name='resnet50', num_classes=10): 336 | super(LinearClassifier, self).__init__() 337 | _, feat_dim = model_dict[name] 338 | self.fc = nn.Linear(feat_dim, num_classes) 339 | 340 | def forward(self, features): 341 | return self.fc(features) 342 | -------------------------------------------------------------------------------- /main_represent.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import math 8 | import copy 9 | import random 10 | import builtins 11 | import numpy as np 12 | 13 | import torch 14 | import torch.backends.cudnn as cudnn 15 | from torchvision import transforms, datasets 16 | from RandAugment import RandAugment 17 | 18 | from losses import ConLoss 19 | from utils.util import * 20 | from utils.tinyimagenet import TinyImageNet 21 | from utils.imagenet import ImageNetSubset 22 | from networks.resnet_big import ConResNet 23 | from networks.vgg_big import ConVGG 24 | from networks.wrn_big import ConWRN 25 | from networks.efficient_big import ConEfficientNet 26 | 27 | 28 | def parse_option(): 29 | parser = argparse.ArgumentParser('argument for training') 30 | 31 | parser.add_argument('--exp_name', type=str, default='') 32 | parser.add_argument('--seed', type=int, default=0) 33 | parser.add_argument('--print_freq', type=int, default=10) 34 | parser.add_argument('--save_freq', type=int, default=0) 35 | parser.add_argument('--save_dir', type=str, default='./save/representation') 36 | parser.add_argument('--resume', help='path of model checkpoint to resume', type=str, 37 | default='') 38 | 39 | # dataset 40 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet', 'imagenet', 'imagenet100']) 41 | parser.add_argument('--data_folder', type=str, default='datasets/') 42 | parser.add_argument('--batch_size', type=int, default=256) 43 | parser.add_argument('--num_workers', type=int, default=16) 44 | 45 | # model 46 | parser.add_argument('--model', type=str, default='resnet50') 47 | parser.add_argument('--selfcon_pos', type=str, default='[False,False,False]', 48 | help='where to augment the paths') 49 | parser.add_argument('--selfcon_arch', type=str, default='resnet', 50 | choices=['resnet', 'vgg', 'efficientnet', 'wrn'], help='which architecture to form a sub-network') 51 | parser.add_argument('--selfcon_size', type=str, default='same', 52 | choices=['fc', 'same', 'small'], help='argument for num_blocks of a sub-network') 53 | parser.add_argument('--feat_dim', type=int, default=128, 54 | help='feature dimension for mlp') 55 | 56 | # optimization 57 | parser.add_argument('--epochs', type=int, default=1000) 58 | 
parser.add_argument('--learning_rate', type=float, default=0.05) 59 | parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900') 60 | parser.add_argument('--lr_decay_rate', type=float, default=0.1) 61 | parser.add_argument('--weight_decay', type=float, default=1e-4) 62 | parser.add_argument('--momentum', type=float, default=0.9) 63 | parser.add_argument('--precision', action='store_true', 64 | help='whether to use 16 bit precision or not') 65 | parser.add_argument('--cosine', action='store_true', 66 | help='using cosine annealing') 67 | parser.add_argument('--warm', action='store_true', 68 | help='warm-up for large batch training') 69 | parser.add_argument('--temp', type=float, default=0.07, 70 | help='temperature for loss function') 71 | 72 | # important arguments 73 | parser.add_argument('--method', type=str, 74 | choices=['Con', 'SupCon', 'SelfCon'], help='choose method') 75 | parser.add_argument('--multiview', action='store_true', 76 | help='use multiview batch or not') 77 | parser.add_argument('--label', action='store_false', 78 | help='whether to use label information or not') 79 | parser.add_argument('--alpha', type=float, default=0.0, 80 | help='weight for selfcon with multiview loss function') 81 | 82 | # other arguments 83 | parser.add_argument('--randaug', action='store_true', 84 | help='whether to add randaugment or not') 85 | parser.add_argument('--weakaug', action='store_true', 86 | help='whether to use weak augmentation or not') 87 | 88 | opt = parser.parse_args() 89 | 90 | if opt.model.startswith('vgg'): 91 | if opt.selfcon_pos == '[False,False,False]': 92 | opt.selfcon_pos = '[False,False,False,False]' 93 | opt.selfcon_arch = 'vgg' 94 | elif opt.model.startswith('wrn'): 95 | if opt.selfcon_pos == '[False,False,False]': 96 | opt.selfcon_pos = '[False,False]' 97 | opt.selfcon_arch = 'wrn' 98 | 99 | # set the path according to the environment 100 | opt.model_path = '%s/%s/%s_models' % (opt.save_dir, opt.method, opt.dataset) 101 | 102 | if opt.dataset == 'cifar10': 103 | opt.n_cls = 10 104 | elif opt.dataset == 'cifar100': 105 | opt.n_cls = 100 106 | elif opt.dataset == 'tinyimagenet': 107 | opt.n_cls = 200 108 | elif opt.dataset == 'imagenet': 109 | opt.n_cls = 1000 110 | elif opt.dataset == 'imagenet100': 111 | opt.n_cls = 100 112 | else: 113 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 114 | 115 | iterations = opt.lr_decay_epochs.split(',') 116 | opt.lr_decay_epochs = list([]) 117 | for it in iterations: 118 | opt.lr_decay_epochs.append(int(it)) 119 | opt.model_name = '{}_{}_{}_lr_{}_multiview_{}_label_{}_decay_{}_bsz_{}_temp_{}_seed_{}'.\ 120 | format(opt.method, opt.dataset, opt.model, opt.learning_rate, 121 | opt.multiview, opt.label, opt.weight_decay, opt.batch_size, 122 | opt.temp, opt.seed) 123 | 124 | # warm-up for large-batch training, 125 | if opt.batch_size >= 1024: 126 | opt.warm = True 127 | if opt.warm: 128 | opt.warmup_from = 0.01 129 | opt.warm_epochs = 10 130 | if opt.cosine: 131 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 132 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 133 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 134 | else: 135 | opt.warmup_to = opt.learning_rate 136 | 137 | if opt.cosine: 138 | opt.model_name = '{}_cosine'.format(opt.model_name) 139 | if opt.warm: 140 | opt.model_name = '{}_warm'.format(opt.model_name) 141 | if opt.exp_name: 142 | opt.model_name = '{}_{}'.format(opt.model_name, opt.exp_name) 143 | 144 | opt.save_folder = os.path.join(opt.model_path, 
opt.model_name) 145 | if not os.path.isdir(opt.save_folder): 146 | os.makedirs(opt.save_folder) 147 | 148 | return opt 149 | 150 | 151 | def set_loader(opt): 152 | # construct data loader 153 | if opt.dataset == 'cifar10': 154 | mean = (0.4914, 0.4822, 0.4465) 155 | std = (0.2023, 0.1994, 0.2010) 156 | size = 32 157 | elif opt.dataset == 'cifar100': 158 | mean = (0.5071, 0.4867, 0.4408) 159 | std = (0.2675, 0.2565, 0.2761) 160 | size = 32 161 | elif opt.dataset == 'tinyimagenet': 162 | mean = (0.485, 0.456, 0.406) 163 | std = (0.229, 0.224, 0.225) 164 | size = 64 165 | elif opt.dataset == 'imagenet' or opt.dataset == 'imagenet100': 166 | mean = (0.485, 0.456, 0.406) 167 | std = (0.229, 0.224, 0.225) 168 | size = 224 169 | else: 170 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 171 | 172 | normalize = transforms.Normalize(mean=mean, std=std) 173 | transform = [transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), transforms.RandomHorizontalFlip()] 174 | if not opt.weakaug: 175 | transform += [transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), 176 | transforms.RandomGrayscale(p=0.2)] 177 | 178 | transform += [transforms.ToTensor(), normalize] 179 | train_transform = transforms.Compose(transform) 180 | 181 | if opt.randaug: 182 | train_transform.transforms.insert(0, RandAugment(2, 9)) 183 | if opt.multiview: 184 | train_transform = TwoCropTransform(train_transform) 185 | 186 | if opt.dataset == 'cifar10': 187 | train_dataset = datasets.CIFAR10(root=opt.data_folder, 188 | transform=train_transform, 189 | download=True) 190 | elif opt.dataset == 'cifar100': 191 | train_dataset = datasets.CIFAR100(root=opt.data_folder, 192 | transform=train_transform, 193 | download=True) 194 | elif opt.dataset == 'tinyimagenet': 195 | train_dataset = TinyImageNet(root=opt.data_folder, 196 | transform=train_transform, 197 | download=True) 198 | elif opt.dataset == 'imagenet': 199 | traindir = os.path.join(opt.data_folder, 'train') 200 | train_dataset = datasets.ImageFolder(root=traindir, 201 | transform=train_transform) 202 | elif opt.dataset == 'imagenet100': 203 | traindir = os.path.join(opt.data_folder, 'train') 204 | train_dataset = ImageNetSubset('./utils/imagenet100.txt', 205 | root=traindir, 206 | transform=train_transform) 207 | else: 208 | raise ValueError(opt.dataset) 209 | 210 | train_loader = torch.utils.data.DataLoader( 211 | train_dataset, batch_size=opt.batch_size, shuffle=True, 212 | num_workers=opt.num_workers, pin_memory=True, sampler=None) 213 | 214 | return train_loader 215 | 216 | 217 | def set_model(opt): 218 | model_kwargs = {'name': opt.model, 219 | 'dataset': opt.dataset, 220 | 'selfcon_pos': eval(opt.selfcon_pos), 221 | 'selfcon_arch': opt.selfcon_arch, 222 | 'selfcon_size': opt.selfcon_size 223 | } 224 | if opt.model.startswith('resnet'): 225 | model = ConResNet(**model_kwargs) 226 | elif opt.model.startswith('vgg'): 227 | model = ConVGG(**model_kwargs) 228 | elif opt.model.startswith('wrn'): 229 | model = ConWRN(**model_kwargs) 230 | elif opt.model.startswith('eff'): 231 | model = ConEfficientNet(**model_kwargs) 232 | 233 | criterion = ConLoss(temperature=opt.temp) 234 | 235 | if torch.cuda.is_available(): 236 | if torch.cuda.device_count() > 1: 237 | model.encoder = torch.nn.DataParallel(model.encoder) 238 | model = model.cuda() 239 | criterion = criterion.cuda() 240 | cudnn.benchmark = True 241 | 242 | return model, criterion, opt 243 | 244 | 245 | def _train(images, labels, model, criterion, epoch, bsz, opt): 246 | # compute loss 
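# features is a pair: features[0] is a list with one normalized sub-network
# projection per position enabled in --selfcon_pos, and features[1] is the
# normalized projection of the backbone output. With --multiview, the two
# augmented views were concatenated above, so every projection carries
# 2 * bsz rows; the torch.split calls below then recover one [bsz, feat_dim]
# tensor per view, e.g. two [512, 128] tensors for bsz=512 and the default
# --feat_dim=128.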
247 | features = model(images) 248 | if opt.method == 'Con': 249 | f1, f2 = torch.split(features[1], [bsz, bsz], dim=0) 250 | elif opt.method == 'SupCon': 251 | if opt.multiview: 252 | f1, f2 = torch.split(features[1], [bsz, bsz], dim=0) 253 | else: # opt.method == 'SelfCon' 254 | f1, f2 = features 255 | 256 | if opt.method == 'SupCon': 257 | # SupCon 258 | if opt.multiview: 259 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 260 | loss = criterion(features, labels) 261 | # SupCon-S 262 | else: 263 | features = features[1].unsqueeze(1) 264 | loss = criterion(features, labels, supcon_s=True) 265 | 266 | elif opt.method == 'Con': 267 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 268 | loss = criterion(features) 269 | elif opt.method == 'SelfCon': 270 | loss = torch.tensor([0.0]).cuda() 271 | # SelfCon 272 | if not opt.multiview: 273 | if not opt.alpha: 274 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1) 275 | # SelfCon-SU 276 | if not opt.label: 277 | loss += criterion(features) 278 | # SelfCon 279 | else: 280 | loss += criterion(features, labels) 281 | else: 282 | features = f2.unsqueeze(1) 283 | if opt.label: 284 | loss += criterion(features, labels, supcon_s=True) 285 | 286 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1) 287 | # SelfCon-SU* 288 | if not opt.label: 289 | loss += opt.alpha * criterion(features, selfcon_s_FG=True) 290 | # SelfCon-S* 291 | else: 292 | loss += opt.alpha * criterion(features, labels, selfcon_s_FG=True) 293 | # SelfCon-M 294 | else: 295 | if not opt.alpha: 296 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1) 297 | labels_repeat = torch.cat([labels, labels], dim=0) 298 | # SelfCon-MU 299 | if not opt.label: 300 | loss += criterion(features) 301 | # SelfCon-M 302 | else: 303 | loss += criterion(features, labels_repeat) 304 | else: 305 | f2_1, f2_2 = torch.split(f2, [bsz, bsz], dim=0) 306 | features = torch.cat([f2_1.unsqueeze(1), f2_2.unsqueeze(1)], dim=1) 307 | # contrastive loss between F (backbone) 308 | if not opt.label: 309 | loss += criterion(features) 310 | else: 311 | loss += criterion(features, labels) 312 | 313 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1) 314 | # SelfCon-MU* 315 | if not opt.label: 316 | loss += opt.alpha * criterion(features, selfcon_m_FG=True) 317 | # SelfCon-M* 318 | else: 319 | loss += opt.alpha * criterion(features, labels, selfcon_m_FG=True) 320 | else: 321 | raise ValueError('contrastive method not supported: {}'. 
322 | format(opt.method)) 323 | 324 | return loss 325 | 326 | 327 | def train(train_loader, model, criterion, optimizer, epoch, opt): 328 | """one epoch training""" 329 | model.train() 330 | if opt.precision: 331 | scaler = torch.cuda.amp.GradScaler() 332 | 333 | batch_time = AverageMeter() 334 | data_time = AverageMeter() 335 | losses = AverageMeter() 336 | 337 | end = time.time() 338 | for idx, (images, labels) in enumerate(train_loader): 339 | data_time.update(time.time() - end) 340 | 341 | bsz = labels.shape[0] 342 | 343 | if opt.multiview: 344 | images = torch.cat([images[0], images[1]], dim=0) 345 | if torch.cuda.is_available(): 346 | images = images.cuda(non_blocking=True) 347 | labels = labels.cuda(non_blocking=True) 348 | 349 | # warm-up learning rate 350 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 351 | 352 | if opt.precision: 353 | with torch.cuda.amp.autocast(): 354 | loss = _train(images, labels, model, criterion, epoch, bsz, opt) 355 | else: 356 | loss = _train(images, labels, model, criterion, epoch, bsz, opt) 357 | 358 | # update metric 359 | losses.update(loss.item(), bsz) 360 | 361 | # SGD 362 | optimizer.zero_grad() 363 | if not opt.precision: 364 | loss.backward() 365 | optimizer.step() 366 | else: 367 | scaler.scale(loss).backward() 368 | scaler.step(optimizer) 369 | scaler.update() 370 | 371 | # measure elapsed time 372 | batch_time.update(time.time() - end) 373 | end = time.time() 374 | 375 | # print info 376 | if (idx + 1) % opt.print_freq == 0: 377 | print('Train: [{0}][{1}/{2}]\t' 378 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 379 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 380 | 'loss {loss.val:.3f} ({loss.avg:.3f})'.format( 381 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 382 | data_time=data_time, loss=losses)) 383 | sys.stdout.flush() 384 | 385 | return losses.avg 386 | 387 | 388 | def main(): 389 | opt = parse_option() 390 | 391 | np.random.seed(opt.seed) 392 | random.seed(opt.seed) 393 | torch.manual_seed(opt.seed) 394 | torch.cuda.manual_seed(opt.seed) 395 | # cudnn.deterministic = True 396 | 397 | # build model and criterion 398 | model, criterion, opt = set_model(opt) 399 | 400 | # build data loader 401 | train_loader = set_loader(opt) 402 | 403 | # build optimizer 404 | optimizer = set_optimizer(opt, model) 405 | 406 | if opt.resume: 407 | if os.path.isfile(opt.resume): 408 | print("=> loading checkpoint '{}'".format(opt.resume)) 409 | checkpoint = torch.load(opt.resume) 410 | opt.start_epoch = checkpoint['epoch'] 411 | model.load_state_dict(checkpoint['model']) 412 | optimizer.load_state_dict(checkpoint['optimizer']) 413 | opt.start_epoch += 1 414 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch'])) 415 | else: 416 | print("=> no checkpoint found at '{}'".format(opt.resume)) 417 | else: 418 | opt.start_epoch = 1 419 | 420 | # training routine 421 | for epoch in range(opt.start_epoch, opt.epochs + 1): 422 | adjust_learning_rate(opt, optimizer, epoch) 423 | 424 | # train for one epoch 425 | time1 = time.time() 426 | loss = train(train_loader, model, criterion, optimizer, epoch, opt) 427 | time2 = time.time() 428 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 429 | 430 | if opt.save_freq: 431 | if epoch % opt.save_freq == 0: 432 | save_file = os.path.join( 433 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 434 | save_model(model, optimizer, opt, epoch, save_file) 435 | 436 | # save the last model 437 | save_file = 
os.path.join( 438 | opt.save_folder, 'last.pth') 439 | save_model(model, optimizer, opt, epoch, save_file) 440 | 441 | 442 | if __name__ == '__main__': 443 | main() 444 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import warnings 7 | import time 8 | import math 9 | import random 10 | import builtins 11 | import numpy as np 12 | 13 | import torch 14 | import torch.backends.cudnn as cudnn 15 | import torch.multiprocessing as mp 16 | import torch.distributed as dist 17 | from torchvision import transforms, datasets 18 | 19 | from utils.util import * 20 | from utils.tinyimagenet import TinyImageNet 21 | from utils.imagenet import ImageNetSubset 22 | from networks.resnet_big import ConResNet, LinearClassifier 23 | from networks.vgg_big import ConVGG, LinearClassifier_VGG 24 | from networks.wrn_big import ConWRN, LinearClassifier_WRN 25 | from networks.efficient_big import ConEfficientNet, LinearClassifier_EFF 26 | 27 | 28 | def parse_option(): 29 | parser = argparse.ArgumentParser('argument for training') 30 | 31 | parser.add_argument('--exp_name', type=str, default='') 32 | parser.add_argument('--seed', type=int, default=0) 33 | parser.add_argument('--print_freq', type=int, default=10) 34 | parser.add_argument('--save_dir', type=str, default='./save/representation') 35 | parser.add_argument('--ckpt', type=str, default='', 36 | help='path for pre-trained model') 37 | parser.add_argument('--subnet', action='store_true', 38 | help='measure the accuracy of sub-network or not') 39 | 40 | # dataset 41 | parser.add_argument('--dataset', type=str, default='imagenet', choices=['cifar10', 'cifar100', 'tinyimagenet', 'imagenet', 'imagenet100']) 42 | parser.add_argument('--data_folder', type=str, default='datasets/') 43 | parser.add_argument('--batch_size', type=int, default=256) 44 | parser.add_argument('--num_workers', type=int, default=16) 45 | 46 | # model 47 | parser.add_argument('--model', type=str, default='resnet50') 48 | parser.add_argument('--selfcon_pos', type=str, default='[False,False,False]', 49 | help='where to augment the paths') 50 | parser.add_argument('--selfcon_arch', type=str, default='resnet', 51 | choices=['resnet', 'vgg', 'efficientnet', 'wrn'], help='which architecture to form a sub-network') 52 | parser.add_argument('--selfcon_size', type=str, default='same', 53 | choices=['fc', 'same', 'small'], help='argument for num_blocks of a sub-network') 54 | 55 | # optimization 56 | parser.add_argument('--epochs', type=int, default=100) 57 | parser.add_argument('--learning_rate', type=float, default=0.1) 58 | parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90') 59 | parser.add_argument('--lr_decay_rate', type=float, default=0.2) 60 | parser.add_argument('--weight_decay', type=float, default=0) 61 | parser.add_argument('--momentum', type=float, default=0.9) 62 | parser.add_argument('--cosine', action='store_true', 63 | help='using cosine annealing') 64 | parser.add_argument('--warm', action='store_true', 65 | help='warm-up for large batch training') 66 | 67 | opt = parser.parse_args() 68 | 69 | iterations = opt.lr_decay_epochs.split(',') 70 | opt.lr_decay_epochs = list([]) 71 | for it in iterations: 72 | opt.lr_decay_epochs.append(int(it)) 73 | 74 | opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\ 75 | format(opt.dataset, opt.model, 
opt.learning_rate, opt.weight_decay, 76 | opt.batch_size) 77 | 78 | if opt.cosine: 79 | opt.model_name = '{}_cosine'.format(opt.model_name) 80 | 81 | # warm-up for large-batch training, 82 | if opt.warm: 83 | opt.model_name = '{}_warm'.format(opt.model_name) 84 | opt.warmup_from = 0.01 85 | opt.warm_epochs = 10 86 | if opt.cosine: 87 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 88 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 89 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 90 | else: 91 | opt.warmup_to = opt.learning_rate 92 | 93 | if opt.dataset == 'cifar10': 94 | opt.n_cls = 10 95 | elif opt.dataset == 'cifar100': 96 | opt.n_cls = 100 97 | elif opt.dataset == 'tinyimagenet': 98 | opt.n_cls = 200 99 | elif opt.dataset == 'imagenet': 100 | opt.n_cls = 1000 101 | elif opt.dataset == 'imagenet100': 102 | opt.n_cls = 100 103 | else: 104 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 105 | 106 | return opt 107 | 108 | 109 | def set_loader(opt): 110 | # construct data loader 111 | if opt.dataset == 'cifar10': 112 | mean = (0.4914, 0.4822, 0.4465) 113 | std = (0.2023, 0.1994, 0.2010) 114 | size = 32 115 | elif opt.dataset == 'cifar100': 116 | mean = (0.5071, 0.4867, 0.4408) 117 | std = (0.2675, 0.2565, 0.2761) 118 | size = 32 119 | elif opt.dataset == 'tinyimagenet': 120 | mean = (0.485, 0.456, 0.406) 121 | std = (0.229, 0.224, 0.225) 122 | size = 64 123 | elif opt.dataset == 'imagenet': 124 | mean = (0.485, 0.456, 0.406) 125 | std = (0.229, 0.224, 0.225) 126 | size = 224 127 | elif opt.dataset == 'imagenet100': 128 | mean = (0.485, 0.456, 0.406) 129 | std = (0.229, 0.224, 0.225) 130 | size = 224 131 | else: 132 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 133 | normalize = transforms.Normalize(mean=mean, std=std) 134 | 135 | train_transform = transforms.Compose([ 136 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 137 | transforms.RandomHorizontalFlip(), 138 | transforms.ToTensor(), 139 | normalize, 140 | ]) 141 | 142 | # if opt.randaug: 143 | # train_transform.transforms.insert(0, RandAugment(2, 9)) 144 | 145 | if opt.dataset not in ['imagenet', 'imagenet100']: 146 | val_transform = transforms.Compose([ 147 | transforms.ToTensor(), 148 | normalize, 149 | ]) 150 | else: 151 | val_transform = transforms.Compose([transforms.Resize(256), 152 | transforms.CenterCrop(224), 153 | transforms.ToTensor(), 154 | normalize]) 155 | 156 | if opt.dataset == 'cifar10': 157 | train_dataset = datasets.CIFAR10(root=opt.data_folder, 158 | transform=train_transform, 159 | download=True) 160 | val_dataset = datasets.CIFAR10(root=opt.data_folder, 161 | train=False, 162 | transform=val_transform) 163 | elif opt.dataset == 'cifar100': 164 | train_dataset = datasets.CIFAR100(root=opt.data_folder, 165 | transform=train_transform, 166 | download=True) 167 | val_dataset = datasets.CIFAR100(root=opt.data_folder, 168 | train=False, 169 | transform=val_transform) 170 | elif opt.dataset == 'tinyimagenet': 171 | train_dataset = TinyImageNet(root=opt.data_folder, 172 | transform=train_transform, 173 | download=True) 174 | val_dataset = TinyImageNet(root=opt.data_folder, 175 | train=False, 176 | transform=val_transform) 177 | elif opt.dataset == 'imagenet': 178 | traindir = os.path.join(opt.data_folder, 'train') 179 | train_dataset = datasets.ImageFolder(root=traindir, 180 | transform=train_transform) 181 | 182 | valdir = os.path.join(opt.data_folder, 'val') 183 | val_dataset = datasets.ImageFolder(root=valdir, 184 | 
transform=val_transform) 185 | elif opt.dataset == 'imagenet100': 186 | traindir = os.path.join(opt.data_folder, 'train') 187 | train_dataset = ImageNetSubset('./utils/imagenet100.txt', 188 | root=traindir, 189 | transform=train_transform) 190 | 191 | valdir = os.path.join(opt.data_folder, 'val') 192 | val_dataset = ImageNetSubset('./utils/imagenet100.txt', 193 | root=valdir, 194 | transform=val_transform) 195 | else: 196 | raise ValueError(opt.dataset) 197 | 198 | train_sampler = None 199 | train_loader = torch.utils.data.DataLoader( 200 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None), 201 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler) 202 | val_loader = torch.utils.data.DataLoader( 203 | val_dataset, batch_size=512, shuffle=False, 204 | num_workers=8, pin_memory=True) 205 | 206 | return train_loader, val_loader, train_sampler 207 | 208 | def set_model(opt): 209 | model_kwargs = {'name': opt.model, 210 | 'dataset': opt.dataset, 211 | 'selfcon_pos': eval(opt.selfcon_pos), 212 | 'selfcon_arch': opt.selfcon_arch, 213 | 'selfcon_size': opt.selfcon_size 214 | } 215 | 216 | if opt.model.startswith('resnet'): 217 | model = ConResNet(**model_kwargs) 218 | classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls) 219 | if opt.subnet: 220 | sub_classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls) 221 | 222 | elif opt.model.startswith('vgg'): 223 | model = ConVGG(**model_kwargs) 224 | classifier = LinearClassifier_VGG(name=opt.model, num_classes=opt.n_cls) 225 | if opt.subnet: 226 | sub_classifier = LinearClassifier_VGG(name=opt.model, num_classes=opt.n_cls) 227 | 228 | elif opt.model.startswith('wrn'): 229 | model = ConWRN(**model_kwargs) 230 | classifier = LinearClassifier_WRN(name=opt.model, num_classes=opt.n_cls) 231 | if opt.subnet: 232 | sub_classifier = LinearClassifier_WRN(name=opt.model, num_classes=opt.n_cls) 233 | 234 | elif opt.model.startswith('eff'): 235 | model = ConEfficientNet(**model_kwargs) 236 | classifier = LinearClassifier_EFF(name=opt.model, num_classes=opt.n_cls) 237 | if opt.subnet: 238 | sub_classifier = LinearClassifier_EFF(name=opt.model, num_classes=opt.n_cls) 239 | 240 | criterion = torch.nn.CrossEntropyLoss() 241 | if opt.ckpt: 242 | ckpt = torch.load(opt.ckpt, map_location='cpu') 243 | state_dict = ckpt['model'] 244 | 245 | if torch.cuda.is_available(): 246 | if torch.cuda.device_count() > 1: 247 | model.encoder = torch.nn.DataParallel(model.encoder) 248 | else: 249 | if opt.ckpt: 250 | new_state_dict = {} 251 | for k, v in state_dict.items(): 252 | k = k.replace("module.", "") 253 | new_state_dict[k] = v 254 | state_dict = new_state_dict 255 | 256 | model.cuda() 257 | classifier = classifier.cuda() 258 | if opt.subnet: 259 | sub_classifier = sub_classifier.cuda() 260 | criterion = criterion.cuda() 261 | cudnn.benchmark = True 262 | 263 | if opt.ckpt: 264 | state_dict = {k.replace("downsample", "shortcut"): v for k, v in state_dict.items()} 265 | model.load_state_dict(state_dict, strict=False) 266 | 267 | if not opt.subnet: 268 | sub_classifier = None 269 | return model, classifier, sub_classifier, criterion, opt 270 | 271 | 272 | def train(train_loader, model, classifier, criterion, optimizer, epoch, opt, subnet=False): 273 | """one epoch training""" 274 | model.eval() 275 | classifier.train() 276 | 277 | batch_time = AverageMeter() 278 | data_time = AverageMeter() 279 | losses = AverageMeter() 280 | top1 = AverageMeter() 281 | 282 | end = time.time() 283 | for idx, (images, labels) in 
enumerate(train_loader): 284 | data_time.update(time.time() - end) 285 | 286 | images = images.cuda(non_blocking=True) 287 | labels = labels.cuda(non_blocking=True) 288 | bsz = labels.shape[0] 289 | 290 | # warm-up learning rate 291 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 292 | 293 | # compute loss 294 | with torch.no_grad(): 295 | features = model.encoder(images) 296 | features = features[-1] if not subnet else features[0][-1] 297 | output = classifier(features.detach()) 298 | loss = criterion(output, labels) 299 | 300 | # update metric 301 | losses.update(loss.item(), bsz) 302 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 303 | top1.update(acc1[0], bsz) 304 | 305 | # SGD 306 | optimizer.zero_grad() 307 | loss.backward() 308 | optimizer.step() 309 | 310 | # measure elapsed time 311 | batch_time.update(time.time() - end) 312 | end = time.time() 313 | 314 | # print info 315 | if (idx + 1) % opt.print_freq == 0: 316 | print('Train: [{0}][{1}/{2}]\t' 317 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 318 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 319 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 320 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 321 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 322 | data_time=data_time, loss=losses, top1=top1)) 323 | sys.stdout.flush() 324 | 325 | return losses.avg, top1.avg 326 | 327 | 328 | def validate(val_loader, model, classifier, sub_classifier, criterion, opt, best_acc): 329 | def __update_metric(output, labels, top1, top5, bsz): 330 | acc1, acc5 = accuracy(output, labels, topk=(1, 5)) 331 | top1.update(acc1[0], bsz) 332 | top5.update(acc5[0], bsz) 333 | 334 | return top1, top5 335 | 336 | def __best_acc(val_acc1, val_acc5, best_acc, key='backbone'): 337 | if val_acc1.item() > best_acc[key][0]: 338 | best_acc[key][0] = val_acc1.item() 339 | best_acc[key][1] = val_acc5.item() 340 | 341 | return best_acc 342 | 343 | """validation""" 344 | model.eval() 345 | classifier.eval() 346 | if sub_classifier: 347 | sub_classifier.eval() 348 | 349 | batch_time = AverageMeter() 350 | losses = AverageMeter() 351 | top1, top5 = AverageMeter(), AverageMeter() 352 | top1_sub, top5_sub = AverageMeter(), AverageMeter() 353 | top1_ens, top5_ens = AverageMeter(), AverageMeter() 354 | 355 | with torch.no_grad(): 356 | end = time.time() 357 | for idx, (images, labels) in enumerate(val_loader): 358 | images = images.float().cuda() 359 | labels = labels.cuda() 360 | bsz = labels.shape[0] 361 | 362 | # forward 363 | features = model.encoder(images) 364 | output = classifier(features[-1]) 365 | loss = criterion(output, labels) 366 | 367 | # for only one subnetwork 368 | if opt.subnet: 369 | sub_output = sub_classifier(features[0][-1]) 370 | ensemble_output = (output + sub_output) / 2 371 | 372 | # update metric 373 | losses.update(loss.item(), bsz) 374 | top1, top5 = __update_metric(output, labels, top1, top5, bsz) 375 | if opt.subnet: 376 | top1_sub, top5_sub = __update_metric(sub_output, labels, top1_sub, top5_sub, bsz) 377 | top1_ens, top5_ens = __update_metric(ensemble_output, labels, top1_ens, top5_ens, bsz) 378 | 379 | # measure elapsed time 380 | batch_time.update(time.time() - end) 381 | end = time.time() 382 | 383 | if idx % opt.print_freq == 0: 384 | print('Test: [{0}/{1}]\t' 385 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 386 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 387 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 388 | idx, len(val_loader), batch_time=batch_time, 389 | loss=losses, 
top1=top1)) 390 | 391 | print(' * Acc@1 {top1.avg:.2f}, Acc@5 {top5.avg:.2f}'.format(top1=top1, top5=top5)) 392 | best_acc = __best_acc(top1.avg, top5.avg, best_acc) 393 | 394 | if opt.subnet: 395 | print(' * Acc@1 {top1.avg:.2f}, Acc@5 {top5.avg:.2f}'.format(top1=top1_sub, top5=top5_sub)) 396 | best_acc = __best_acc(top1_sub.avg, top5_sub.avg, best_acc, key='sub') 397 | 398 | print(' * Acc@1 {top1.avg:.2f}, Acc@5 {top5.avg:.2f}'.format(top1=top1_ens, top5=top5_ens)) 399 | best_acc = __best_acc(top1_ens.avg, top5_ens.avg, best_acc, key='ensemble') 400 | return best_acc 401 | 402 | 403 | def main(): 404 | opt = parse_option() 405 | 406 | # fix seed 407 | np.random.seed(opt.seed) 408 | random.seed(opt.seed) 409 | torch.manual_seed(opt.seed) 410 | torch.cuda.manual_seed(opt.seed) 411 | cudnn.deterministic = True 412 | 413 | best_acc = {'backbone': [0, 0, 0], 414 | 'sub': [0, 0, 0], 415 | 'ensemble': [0, 0]} 416 | 417 | # build model and criterion 418 | model, classifier, sub_classifier, criterion, opt = set_model(opt) 419 | 420 | # build data loader 421 | train_loader, val_loader, train_sampler = set_loader(opt) 422 | 423 | # build optimizer 424 | optimizer = set_optimizer(opt, classifier) 425 | sub_optimizer = set_optimizer(opt, sub_classifier) if opt.subnet else None 426 | 427 | # training routine 428 | for epoch in range(1, opt.epochs + 1): 429 | adjust_learning_rate(opt, optimizer, epoch) 430 | if opt.subnet: 431 | adjust_learning_rate(opt, sub_optimizer, epoch) 432 | 433 | # train for one epoch 434 | time1 = time.time() 435 | loss, acc = train(train_loader, model, classifier, criterion, 436 | optimizer, epoch, opt) 437 | time2 = time.time() 438 | print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format( 439 | epoch, time2 - time1, acc)) 440 | best_acc['backbone'][2] = acc.item() 441 | 442 | if opt.subnet: 443 | _, sub_acc = train(train_loader, model, sub_classifier, criterion, 444 | sub_optimizer, epoch, opt, subnet=True) 445 | print('Train epoch {}, accuracy:{:.2f}'.format( 446 | epoch, sub_acc)) 447 | best_acc['sub'][2] = sub_acc.item() 448 | 449 | # eval for one epoch 450 | best_acc = validate(val_loader, model, classifier, sub_classifier, criterion, opt, best_acc) 451 | 452 | update_json(opt.ckpt + '_%s' % opt.exp_name if opt.exp_name else opt.ckpt, best_acc, path='%s/results.json' % (opt.save_dir)) 453 | 454 | # for robustness experiments 455 | method = 'supcon' 456 | if not os.path.isdir('./robustness/ckpt'): 457 | os.makedirs('./robustness/ckpt') 458 | torch.save(model.state_dict(), './robustness/ckpt/{}_encoder.pth'.format(method)) 459 | torch.save(classifier.state_dict(), './robustness/ckpt/{}_classifier.pth'.format(method)) 460 | 461 | if __name__ == '__main__': 462 | main() 463 | -------------------------------------------------------------------------------- /main_ce.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import argparse 6 | import time 7 | import math 8 | import random 9 | import builtins 10 | import numpy as np 11 | import warnings 12 | warnings.filterwarnings(action='ignore') 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | from torchvision import transforms, datasets 17 | from torch.autograd import Variable 18 | 19 | from networks.resnet_big import CEResNet 20 | from networks.vgg_big import CEVGG 21 | from networks.wrn_big import CEWRN 22 | from networks.efficient_big import CEEffNet 23 | from losses import * 24 | from 
utils.util import * 25 | from utils.tinyimagenet import TinyImageNet 26 | from utils.imagenet import ImageNetSubset 27 | 28 | 29 | def parse_option(): 30 | parser = argparse.ArgumentParser('argument for training') 31 | 32 | parser.add_argument('--exp_name', type=str, default='') 33 | parser.add_argument('--seed', type=int, default=0) 34 | parser.add_argument('--print_freq', type=int, default=10) 35 | parser.add_argument('--resume', help='path of model checkpoint to resume', type=str, 36 | default='') 37 | 38 | # dataset 39 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet', 'imagenet', 'imagenet100']) 40 | parser.add_argument('--data_folder', type=str, default='datasets/') 41 | parser.add_argument('--batch_size', type=int, default=256) 42 | parser.add_argument('--num_workers', type=int, default=16) 43 | 44 | # model 45 | parser.add_argument('--model', type=str, default='resnet50') 46 | parser.add_argument('--selfcon_pos', type=str, default='[False,False,False]', 47 | help='where to augment the paths') 48 | parser.add_argument('--selfcon_arch', type=str, default='resnet', 49 | choices=['resnet', 'vgg', 'efficientnet', 'wrn'], help='which architecture to form a sub-network') 50 | parser.add_argument('--selfcon_size', type=str, default='same', 51 | choices=['fc', 'same', 'small'], help='argument for num_blocks of a sub-network') 52 | parser.add_argument('--dim_out', default=128, type=int, 53 | help='feat dimension for CEResNet') 54 | 55 | # optimization 56 | parser.add_argument('--epochs', type=int, default=500) 57 | parser.add_argument('--learning_rate', type=float, default=0.2) 58 | parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450') 59 | parser.add_argument('--lr_decay_rate', type=float, default=0.1) 60 | parser.add_argument('--weight_decay', type=float, default=1e-4) 61 | parser.add_argument('--momentum', type=float, default=0.9) 62 | parser.add_argument('--cosine', action='store_true', 63 | help='using cosine annealing') 64 | parser.add_argument('--warm', action='store_true', 65 | help='warm-up for large batch training') 66 | 67 | # important arguments 68 | parser.add_argument('--method', type=str, 69 | choices=['ce', 'subnet_ce', 'kd', 'selfcon'], help='choose method') 70 | parser.add_argument('--alpha', type=float, default=0., help='weight balance for subnet CE') 71 | parser.add_argument('--beta', type=float, default=0., help='weight balance for KD') 72 | parser.add_argument('--gamma', type=float, default=0., help='weight balance for other losses') 73 | parser.add_argument('--temperature', type=float, default=3.0, help='temperature for KD loss function') 74 | 75 | opt = parser.parse_args() 76 | 77 | if opt.model.startswith('vgg'): 78 | if opt.selfcon_pos == '[False,False,False]': 79 | opt.selfcon_pos = '[False,False,False,False]' 80 | opt.selfcon_arch = 'vgg' 81 | if opt.model.startswith('eff'): 82 | if opt.selfcon_pos == '[False,False,False]': 83 | opt.selfcon_pos = '[False]' 84 | opt.selfcon_arch = 'eff' 85 | 86 | # set the path according to the environment 87 | opt.model_path = './save/distill/%s/%s_models' % (opt.method, opt.dataset) 88 | 89 | iterations = opt.lr_decay_epochs.split(',') 90 | opt.lr_decay_epochs = list([]) 91 | for it in iterations: 92 | opt.lr_decay_epochs.append(int(it)) 93 | 94 | opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_seed_{}'.\ 95 | format(opt.method, opt.dataset, opt.model, opt.learning_rate, 96 | opt.weight_decay, opt.batch_size, opt.seed) 97 | 98 | if opt.cosine: 99 | 
opt.model_name = '{}_cosine'.format(opt.model_name) 100 | if opt.exp_name: 101 | opt.model_name = '{}_{}'.format(opt.model_name, opt.exp_name) 102 | 103 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 104 | if not os.path.isdir(opt.save_folder): 105 | os.makedirs(opt.save_folder) 106 | 107 | if opt.dataset == 'cifar10': 108 | opt.n_cls = 10 109 | opt.n_data = 50000 110 | elif opt.dataset == 'cifar100': 111 | opt.n_cls = 100 112 | opt.n_data = 50000 113 | elif opt.dataset == 'tinyimagenet': 114 | opt.n_cls = 200 115 | opt.n_data = 100000 116 | elif opt.dataset == 'imagenet': 117 | opt.n_cls = 1000 118 | opt.n_data = 1200000 119 | elif opt.dataset == 'imagenet100': 120 | opt.n_cls = 100 121 | opt.n_data = 120000 122 | else: 123 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 124 | 125 | if opt.method == 'ce': 126 | opt.alpha, opt.beta, opt.gamma = 0, 0, 0 127 | elif opt.method == 'subnet_ce': 128 | opt.alpha, opt.beta, opt.gamma = 1.0, 0, 0 129 | elif opt.method == 'kd': 130 | opt.alpha, opt.beta, opt.gamma = 0.5, 0.5, 0 131 | elif opt.method == 'selfcon': 132 | opt.alpha, opt.beta, opt.gamma = 1.0, 0, 0.8 133 | 134 | return opt 135 | 136 | 137 | def set_loader(opt): 138 | # construct data loader 139 | if opt.dataset == 'cifar10': 140 | mean = (0.4914, 0.4822, 0.4465) 141 | std = (0.2023, 0.1994, 0.2010) 142 | size = 32 143 | elif opt.dataset == 'cifar100': 144 | mean = (0.5071, 0.4867, 0.4408) 145 | std = (0.2675, 0.2565, 0.2761) 146 | size = 32 147 | elif opt.dataset == 'tinyimagenet': 148 | mean = (0.485, 0.456, 0.406) 149 | std = (0.229, 0.224, 0.225) 150 | size = 64 151 | elif opt.dataset == 'imagenet' or opt.dataset == 'imagenet100': 152 | mean = (0.485, 0.456, 0.406) 153 | std = (0.229, 0.224, 0.225) 154 | size = 224 155 | else: 156 | raise ValueError('dataset not supported: {}'.format(opt.dataset)) 157 | normalize = transforms.Normalize(mean=mean, std=std) 158 | 159 | train_transform = transforms.Compose([ 160 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), 161 | transforms.RandomHorizontalFlip(), 162 | transforms.ToTensor(), 163 | normalize, 164 | ]) 165 | 166 | if opt.dataset not in ['imagenet', 'imagenet100']: 167 | val_transform = transforms.Compose([ 168 | transforms.ToTensor(), 169 | normalize, 170 | ]) 171 | else: 172 | val_transform = transforms.Compose([transforms.Resize(256), 173 | transforms.CenterCrop(224), 174 | transforms.ToTensor(), 175 | normalize]) 176 | 177 | if opt.dataset == 'cifar10': 178 | train_dataset = datasets.CIFAR10(root=opt.data_folder, 179 | transform=train_transform, 180 | download=True) 181 | val_dataset = datasets.CIFAR10(root=opt.data_folder, 182 | train=False, 183 | transform=val_transform) 184 | elif opt.dataset == 'cifar100': 185 | train_dataset = datasets.CIFAR100(root=opt.data_folder, 186 | transform=train_transform, 187 | download=True) 188 | val_dataset = datasets.CIFAR100(root=opt.data_folder, 189 | train=False, 190 | transform=val_transform) 191 | elif opt.dataset == 'tinyimagenet': 192 | train_dataset = TinyImageNet(root=opt.data_folder, 193 | transform=train_transform, 194 | download=True) 195 | val_dataset = TinyImageNet(root=opt.data_folder, 196 | train=False, 197 | transform=val_transform) 198 | elif opt.dataset == 'imagenet': 199 | traindir = os.path.join(opt.data_folder, 'train') 200 | valdir = os.path.join(opt.data_folder, 'val') 201 | train_dataset = datasets.ImageFolder(root=traindir, transform=train_transform) 202 | val_dataset = datasets.ImageFolder(root=valdir, 
transform=val_transform) 203 | elif opt.dataset == 'imagenet100': 204 | traindir = os.path.join(opt.data_folder, 'train') 205 | valdir = os.path.join(opt.data_folder, 'val') 206 | 207 | train_dataset = ImageNetSubset('./utils/imagenet100.txt', 208 | root=traindir, 209 | transform=train_transform) 210 | val_dataset = ImageNetSubset('./utils/imagenet100.txt', 211 | root=valdir, 212 | transform=val_transform) 213 | else: 214 | raise ValueError(opt.dataset) 215 | 216 | train_loader = torch.utils.data.DataLoader( 217 | train_dataset, batch_size=opt.batch_size, shuffle=True, 218 | num_workers=opt.num_workers, pin_memory=True, sampler=None) 219 | val_loader = torch.utils.data.DataLoader( 220 | val_dataset, batch_size=512, shuffle=False, 221 | num_workers=8, pin_memory=True) 222 | 223 | return train_loader, val_loader 224 | 225 | 226 | def set_model(opt): 227 | model_kwargs = {'name': opt.model, 228 | 'method': opt.method, 229 | 'num_classes': opt.n_cls, 230 | 'dim_out': opt.dim_out, 231 | 'dataset': opt.dataset, 232 | 'selfcon_pos': eval(opt.selfcon_pos), 233 | 'selfcon_arch': opt.selfcon_arch, 234 | 'selfcon_size': opt.selfcon_size 235 | } 236 | 237 | if opt.model.startswith('resnet'): 238 | model = CEResNet(**model_kwargs) 239 | elif opt.model.startswith('vgg'): 240 | model = CEVGG(**model_kwargs) 241 | elif opt.model.startswith('wrn'): 242 | model = CEWRN(**model_kwargs) 243 | elif opt.model.startswith('eff'): 244 | model = CEEffNet(**model_kwargs) 245 | 246 | criterion = nn.ModuleList([]) 247 | criterion.append(torch.nn.CrossEntropyLoss()) 248 | criterion.append(KLLoss(opt.temperature)) 249 | 250 | # Note that the student and teacher feature shapes are the same 251 | if opt.method in ['ce', 'subnet_ce', 'kd']: 252 | criterion.append(None) 253 | elif opt.method == 'selfcon': 254 | criterion.append(ConLoss(temperature=opt.temperature)) 255 | else: 256 | raise NotImplementedError 257 | 258 | if torch.cuda.is_available(): 259 | if torch.cuda.device_count() > 1: 260 | model = torch.nn.DataParallel(model) 261 | model = model.cuda() 262 | criterion = criterion.cuda() 263 | cudnn.benchmark = True 264 | 265 | return model, criterion, opt 266 | 267 | 268 | def train(train_loader, model, criterion, optimizer, epoch, opt): 269 | """one epoch training""" 270 | model.train() 271 | 272 | batch_time = AverageMeter() 273 | data_time = AverageMeter() 274 | losses = AverageMeter() 275 | top1 = AverageMeter() 276 | top1_s = AverageMeter() 277 | 278 | only_backbone = eval(opt.selfcon_pos) in [[False], [False,False], [False,False,False], [False,False,False,False]] 279 | 280 | end = time.time() 281 | for idx, inputs in enumerate(train_loader): 282 | images, labels = inputs 283 | 284 | data_time.update(time.time() - end) 285 | 286 | images = images.cuda(non_blocking=True) 287 | labels = labels.cuda(non_blocking=True) 288 | bsz = labels.shape[0] 289 | 290 | # warm-up learning rate 291 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer) 292 | 293 | # compute loss 294 | if opt.method not in ['ce', 'subnet_ce', 'kd']: 295 | feats, logits = model(images) 296 | else: 297 | logits = model(images) 298 | 299 | loss = criterion[0](logits[-1], labels) 300 | 301 | for sub_logit in logits[0]: 302 | loss += opt.alpha * criterion[0](sub_logit, labels) 303 | loss += opt.beta * criterion[1](sub_logit, logits[-1]) 304 | if criterion[2] is not None: 305 | for feat_s in feats[0]: # do not shadow `idx`, which indexes the dataloader loop 306 | # the MLP head of the backbone is kept at its random initialization 307 | features = torch.cat([feat_s.unsqueeze(1), 
feats[-1].unsqueeze(1)], dim=1) 308 | loss += opt.gamma * criterion[2](features, labels) 309 | 310 | # update metric 311 | losses.update(loss.item(), bsz) 312 | acc1, _ = accuracy(logits[-1], labels, topk=(1, 5)) 313 | top1.update(acc1[0], bsz) 314 | if not only_backbone: 315 | acc1_s, _ = accuracy(logits[0][0], labels, topk=(1, 5)) 316 | top1_s.update(acc1_s[0], bsz) 317 | else: 318 | top1_s.update(torch.tensor(0.0).to(acc1[0].device), bsz) 319 | 320 | # SGD 321 | optimizer.zero_grad() 322 | loss.backward() 323 | optimizer.step() 324 | 325 | # measure elapsed time 326 | batch_time.update(time.time() - end) 327 | end = time.time() 328 | 329 | # print info 330 | if (idx + 1) % opt.print_freq == 0: 331 | print('Train: [{0}][{1}/{2}]\t' 332 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 333 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' 334 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t' 335 | 'Acc@1 {top1.avg:.3f} {top1_s.avg:.3f}'.format( 336 | epoch, idx + 1, len(train_loader), batch_time=batch_time, 337 | data_time=data_time, loss=losses, top1=top1, top1_s=top1_s)) 338 | sys.stdout.flush() 339 | 340 | return losses.avg, top1.avg 341 | 342 | 343 | def validate(val_loader, model, criterion, opt): 344 | """validation""" 345 | model.eval() 346 | 347 | batch_time = AverageMeter() 348 | losses = AverageMeter() 349 | top1_b = AverageMeter() 350 | top5_b = AverageMeter() 351 | top1_s = AverageMeter() 352 | top5_s = AverageMeter() 353 | 354 | only_backbone = True if eval(opt.selfcon_pos) in [[False], [False,False], [False,False,False], [False,False,False,False]] else False 355 | 356 | with torch.no_grad(): 357 | end = time.time() 358 | for idx, (images, labels) in enumerate(val_loader): 359 | images = images.float().cuda() 360 | labels = labels.cuda() 361 | bsz = labels.shape[0] 362 | 363 | # forward 364 | if opt.method not in ['ce', 'subnet_ce', 'kd']: 365 | _, logits = model(images) 366 | else: 367 | logits = model(images) 368 | 369 | loss = criterion[0](logits[-1], labels) 370 | 371 | # update metric 372 | losses.update(loss.item(), bsz) 373 | acc1, acc5 = accuracy(logits[-1], labels, topk=(1, 5)) 374 | top1_b.update(acc1[0], bsz) 375 | top5_b.update(acc5[0], bsz) 376 | if only_backbone: 377 | top1_s.update(torch.tensor(0.0).to(acc1[0].device), bsz) 378 | top5_s.update(torch.tensor(0.0).to(acc5[0].device), bsz) 379 | else: 380 | # only for the first sub-network (actually we use 1 sub-network) 381 | acc1_s, acc5_s = accuracy(logits[0][0], labels, topk=(1, 5)) 382 | top1_s.update(acc1_s[0], bsz) 383 | top5_s.update(acc5_s[0], bsz) 384 | 385 | # measure elapsed time 386 | batch_time.update(time.time() - end) 387 | end = time.time() 388 | 389 | if idx % opt.print_freq == 0: 390 | print('Test: [{0}/{1}]\t' 391 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 392 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 393 | 'Acc@1 ({top1_b.avg:.3f}) ({top1_s.avg:.3f})'.format( 394 | idx, len(val_loader), batch_time=batch_time, 395 | loss=losses, top1_b=top1_b, top1_s=top1_s)) 396 | 397 | print(' * Acc@1 {top1_b.avg:.3f} {top1_s.avg:.3f}'.format(top1_b=top1_b, top1_s=top1_s)) 398 | return losses.avg, top1_b.avg, top5_b.avg, top1_s.avg, top5_s.avg 399 | 400 | 401 | def main(): 402 | opt = parse_option() 403 | 404 | # fix seed 405 | np.random.seed(opt.seed) 406 | random.seed(opt.seed) 407 | torch.manual_seed(opt.seed) 408 | torch.cuda.manual_seed(opt.seed) 409 | cudnn.deterministic = True 410 | 411 | # build model and criterion 412 | model, criterion, opt = set_model(opt) 413 | 414 | # build data loader 
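# set_loader applies the dataset-specific normalization defined above and
# returns (train_loader, val_loader); note that the validation loader runs
# with a fixed batch size of 512 regardless of --batch_size.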
415 | train_loader, val_loader = set_loader(opt) 416 | 417 | # build optimizer 418 | optimizer = set_optimizer(opt, model) 419 | 420 | if opt.resume: 421 | if os.path.isfile(opt.resume): 422 | print("=> loading checkpoint '{}'".format(opt.resume)) 423 | checkpoint = torch.load(opt.resume) 424 | opt.start_epoch = checkpoint['epoch'] + 1 425 | model.load_state_dict(checkpoint['model']) 426 | optimizer.load_state_dict(checkpoint['optimizer']) 427 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch'])) 428 | else: 429 | print("=> no checkpoint found at '{}'".format(opt.resume)) 430 | else: 431 | opt.start_epoch = 1 432 | 433 | # warm-up for large-batch training, 434 | if opt.batch_size >= 1024: 435 | opt.warm = True 436 | if opt.warm: 437 | opt.model_name = '{}_warm'.format(opt.model_name) 438 | opt.warmup_from = 0.01 439 | opt.warm_epochs = 10 440 | if opt.cosine: 441 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3) 442 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * ( 443 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2 444 | else: 445 | opt.warmup_to = opt.learning_rate 446 | 447 | # training routine 448 | best_acc1 = 0 449 | for epoch in range(opt.start_epoch, opt.epochs + 1): 450 | adjust_learning_rate(opt, optimizer, epoch) 451 | 452 | # train for one epoch 453 | time1 = time.time() 454 | loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt) 455 | time2 = time.time() 456 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 457 | 458 | # evaluation 459 | loss, val_acc1, val_acc5, val_acc1_s, val_acc5_s = validate(val_loader, model, criterion, opt) 460 | 461 | if val_acc1.item() > best_acc1: 462 | best_acc1 = val_acc1 463 | best_acc5 = val_acc5 464 | best_acc1_s = val_acc1_s 465 | best_acc5_s = val_acc5_s 466 | best_model = model.state_dict() 467 | 468 | # save the last model 469 | save_file = os.path.join( 470 | opt.save_folder, 'last.pth') 471 | save_model(model, optimizer, opt, epoch, save_file) 472 | 473 | # save the best model 474 | # Note that accuracy in results.json is different from the saved best model 475 | # because of multiprocessing distributed setting 476 | model.load_state_dict(best_model) 477 | save_file = os.path.join( 478 | opt.save_folder, 'best.pth') 479 | save_model(model, optimizer, opt, opt.epochs, save_file) 480 | 481 | update_json_list(opt.save_folder, [best_acc1.item(), best_acc5.item(), best_acc1_s.item(), best_acc5_s.item(), train_acc.item()], path='./save/distill/results.json') 482 | 483 | 484 | if __name__ == '__main__': 485 | main() 486 | -------------------------------------------------------------------------------- /networks/efficient_big.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | import collections 4 | from functools import partial 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | # Parameters for the entire model (stem, all blocks, and head) 11 | GlobalParams = collections.namedtuple('GlobalParams', [ 12 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 13 | 'num_classes', 'width_coefficient', 'depth_coefficient', 14 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) 15 | 16 | # Parameters for an individual model block 17 | BlockArgs = collections.namedtuple('BlockArgs', [ 18 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 19 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) 20 | 21 | # 
21 | # Change namedtuple defaults
22 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
23 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
24 | 
25 | 
26 | class Swish(nn.Module):
27 |     def forward(self, x):
28 |         return x * torch.sigmoid(x)
29 | 
30 | 
31 | def round_filters(filters, global_params):
32 |     """ Calculate and round number of filters based on the width multiplier. """
33 |     multiplier = global_params.width_coefficient
34 |     if not multiplier:
35 |         return filters
36 |     divisor = global_params.depth_divisor
37 |     min_depth = global_params.min_depth
38 |     filters *= multiplier
39 |     min_depth = min_depth or divisor
40 |     new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
41 |     if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
42 |         new_filters += divisor
43 |     return int(new_filters)
44 | 
45 | 
46 | def round_repeats(repeats, global_params):
47 |     """ Round number of block repeats based on the depth multiplier. """
48 |     multiplier = global_params.depth_coefficient
49 |     if not multiplier:
50 |         return repeats
51 |     return int(math.ceil(multiplier * repeats))
52 | 
53 | 
54 | def drop_connect(inputs, p, training):
55 |     """ Drop connect: randomly drop whole samples in the batch with probability p (training only). """
56 |     if not training: return inputs
57 |     batch_size = inputs.shape[0]
58 |     keep_prob = 1 - p
59 |     random_tensor = keep_prob
60 |     random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
61 |     binary_tensor = torch.floor(random_tensor)  # 1 with probability keep_prob, else 0
62 |     output = inputs / keep_prob * binary_tensor  # rescale kept samples so the expected value is unchanged
63 |     return output
64 | 
65 | 
66 | def get_same_padding_conv2d(image_size=None):
67 |     """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
68 |         Static padding is necessary for ONNX exporting of models. """
69 |     if image_size is None:
70 |         return Conv2dDynamicSamePadding
71 |     else:
72 |         return partial(Conv2dStaticSamePadding, image_size=image_size)
73 | 
74 | 
75 | def get_width_and_height_from_size(x):
76 |     """ Obtains width and height from an int, list, or tuple. """
77 |     if isinstance(x, int): return x, x
78 |     if isinstance(x, list) or isinstance(x, tuple): return x
79 |     else: raise TypeError()
80 | 
81 | 
82 | def calculate_output_image_size(input_image_size, stride):
83 |     """ Calculates the output image size when using Conv2dSamePadding with a stride.
84 |         Necessary for static padding. Thanks to mannatsingh for pointing this out. """
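    # Worked example (illustration only): 'same' padding shrinks the spatial size
    # by ceil-division, so a 32x32 map at stride 2 -> [16, 16], and 33x33 -> [17, 17].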
""" 85 | if input_image_size is None: return None 86 | image_height, image_width = get_width_and_height_from_size(input_image_size) 87 | stride = stride if isinstance(stride, int) else stride[0] 88 | image_height = int(math.ceil(image_height / stride)) 89 | image_width = int(math.ceil(image_width / stride)) 90 | return [image_height, image_width] 91 | 92 | 93 | class Conv2dDynamicSamePadding(nn.Conv2d): 94 | """ 2D Convolutions like TensorFlow, for a dynamic image size """ 95 | 96 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 97 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 98 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 99 | 100 | def forward(self, x): 101 | ih, iw = x.size()[-2:] 102 | kh, kw = self.weight.size()[-2:] 103 | sh, sw = self.stride 104 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 105 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 106 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 107 | if pad_h > 0 or pad_w > 0: 108 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 109 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 110 | 111 | 112 | class Conv2dStaticSamePadding(nn.Conv2d): 113 | """ 2D Convolutions like TensorFlow, for a fixed image size""" 114 | 115 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): 116 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 117 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 118 | 119 | # Calculate padding based on image size and save it 120 | assert image_size is not None 121 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size 122 | kh, kw = self.weight.size()[-2:] 123 | sh, sw = self.stride 124 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 125 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 126 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 127 | if pad_h > 0 or pad_w > 0: 128 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) 129 | else: 130 | self.static_padding = Identity() 131 | 132 | def forward(self, x): 133 | x = self.static_padding(x) 134 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 135 | return x 136 | 137 | 138 | class Identity(nn.Module): 139 | def __init__(self, ): 140 | super(Identity, self).__init__() 141 | 142 | def forward(self, input): 143 | return input 144 | 145 | 146 | class BlockDecoder(object): 147 | """ Block Decoder for readability, straight from the official TensorFlow repository """ 148 | 149 | def _decode_block_string(block_string): 150 | """ Gets a block through a string notation of arguments. 
""" 151 | assert isinstance(block_string, str) 152 | 153 | ops = block_string.split('_') 154 | options = {} 155 | for op in ops: 156 | splits = re.split(r'(\d.*)', op) 157 | if len(splits) >= 2: 158 | key, value = splits[:2] 159 | options[key] = value 160 | 161 | # Check stride 162 | assert (('s' in options and len(options['s']) == 1) or 163 | (len(options['s']) == 2 and options['s'][0] == options['s'][1])) 164 | 165 | return BlockArgs( 166 | kernel_size=int(options['k']), 167 | num_repeat=int(options['r']), 168 | input_filters=int(options['i']), 169 | output_filters=int(options['o']), 170 | expand_ratio=int(options['e']), 171 | id_skip=('noskip' not in block_string), 172 | se_ratio=float(options['se']) if 'se' in options else None, 173 | stride=[int(options['s'][0])]) 174 | 175 | def decode(string_list): 176 | """ 177 | Decodes a list of string notations to specify blocks inside the network. 178 | :param string_list: a list of strings, each string is a notation of block 179 | :return: a list of BlockArgs namedtuples of block args 180 | """ 181 | assert isinstance(string_list, list) 182 | blocks_args = [] 183 | for block_string in string_list: 184 | blocks_args.append(BlockDecoder._decode_block_string(block_string)) 185 | return blocks_args 186 | 187 | 188 | class MBConvBlock(nn.Module): 189 | """ 190 | Mobile Inverted Residual Bottleneck Block 191 | Args: 192 | block_args (namedtuple): BlockArgs, see above 193 | global_params (namedtuple): GlobalParam, see above 194 | Attributes: 195 | has_se (bool): Whether the block contains a Squeeze and Excitation layer. 196 | """ 197 | 198 | def __init__(self, block_args, global_params, image_size=None, drop_connect_rate=0.2): 199 | super().__init__() 200 | self._block_args = block_args 201 | self._bn_mom = 1 - global_params.batch_norm_momentum 202 | self._bn_eps = global_params.batch_norm_epsilon 203 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 204 | self.id_skip = block_args.id_skip # skip connection and drop connect 205 | self.drop_connect_rate = drop_connect_rate 206 | 207 | # Expansion phase 208 | inp = self._block_args.input_filters # number of input channels 209 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 210 | if self._block_args.expand_ratio != 1: 211 | Conv2d = get_same_padding_conv2d(image_size=image_size) 212 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 213 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 214 | # image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing 215 | 216 | # Depthwise convolution phase 217 | k = self._block_args.kernel_size 218 | s = self._block_args.stride 219 | Conv2d = get_same_padding_conv2d(image_size=image_size) 220 | self._depthwise_conv = Conv2d( 221 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 222 | kernel_size=k, stride=s, bias=False) 223 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 224 | image_size = calculate_output_image_size(image_size, s) 225 | 226 | # Squeeze and Excitation layer, if desired 227 | if self.has_se: 228 | Conv2d = get_same_padding_conv2d(image_size=(1,1)) 229 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 230 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 231 | self._se_expand = 
231 |             self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
232 | 
233 |         # Output phase
234 |         final_oup = self._block_args.output_filters
235 |         Conv2d = get_same_padding_conv2d(image_size=image_size)
236 |         self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
237 |         self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
238 |         self._swish = Swish()
239 | 
240 |     def forward(self, inputs):
241 |         """
242 |         :param inputs: input tensor
243 |         :note: the drop connect rate is self.drop_connect_rate, fixed at construction
244 |         :return: output of block
245 |         """
246 |         # Expansion and Depthwise Convolution
247 |         x = inputs
248 |         if self._block_args.expand_ratio != 1:
249 |             x = self._swish(self._bn0(self._expand_conv(inputs)))
250 |         x = self._swish(self._bn1(self._depthwise_conv(x)))
251 | 
252 |         # Squeeze and Excitation
253 |         if self.has_se:
254 |             x_squeezed = F.adaptive_avg_pool2d(x, 1)
255 |             x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
256 |             x = torch.sigmoid(x_squeezed) * x
257 | 
258 |         x = self._bn2(self._project_conv(x))
259 | 
260 |         # Skip connection and drop connect
261 |         input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
262 |         if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:  # stride is the int 1 only for repeated blocks; in this config, first blocks change filter counts and never skip
263 |             if self.drop_connect_rate:
264 |                 x = drop_connect(x, p=self.drop_connect_rate, training=self.training)
265 |             x = x + inputs  # skip connection
266 |         return x
267 | 
268 | 
269 | class EfficientNet(nn.Module):
270 |     """
271 |     An EfficientNet backbone configured for small inputs, with an optional SelfCon sub-network exit.
272 |     Args:
273 |         selfcon_pos (list): flags for attaching the SelfCon sub-network exit
274 |             (only the first entry is used by this architecture)
275 |     Example:
276 |         model = EfficientNet(selfcon_pos=[True])
277 |     """
278 | 
279 |     def __init__(self, selfcon_pos=[False]):
280 |         super().__init__()
281 |         blocks_args = [
282 |             'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
283 |             'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
284 |             'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
285 |             'r1_k3_s11_e6_i192_o320_se0.25',
286 |         ]
287 | 
288 |         blocks_args[1] = 'r2_k3_s11_e6_i16_o24_se0.25'  # second stage at stride 1, keeping resolution for small (CIFAR-sized) inputs
289 |         blocks_args = BlockDecoder.decode(blocks_args)
290 | 
291 |         params = {'b0': (1.0, 1.0, 32, 0.2), 'b1': (1.0, 1.1, 34, 0.2), 'b2': (1.1, 1.2, 38, 0.3)}  # (width, depth, image size, dropout)
292 |         w, d, s, p = params['b0']
293 | 
294 |         global_params = GlobalParams(
295 |             batch_norm_momentum=0.99,
296 |             batch_norm_epsilon=1e-3,
297 |             dropout_rate=p,
298 |             drop_connect_rate=0.2,
299 |             width_coefficient=w,
300 |             depth_coefficient=d,
301 |             depth_divisor=8,
302 |             min_depth=None,
303 |             image_size=s,
304 |         )
305 | 
306 |         assert isinstance(blocks_args, list), 'blocks_args should be a list'
307 |         assert len(blocks_args) > 0, 'block args must be greater than 0'
308 |         self._global_params = global_params
309 |         self._blocks_args = blocks_args
310 | 
311 |         # Batch norm parameters
312 |         bn_mom = 1 - self._global_params.batch_norm_momentum
313 |         bn_eps = self._global_params.batch_norm_epsilon
314 | 
315 |         # Get stem static or dynamic convolution depending on image size
316 |         image_size = global_params.image_size
317 |         Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
318 | 
319 |         # Stem
320 |         in_channels = 3  # rgb
321 |         out_channels = round_filters(32, self._global_params)  # number of output channels
322 |         stride = 1  # stride-1 stem (the standard ImageNet stem uses stride 2), suiting 32x32 inputs
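        # Worked example (illustration only): round_filters snaps channel counts to
        # multiples of depth_divisor=8; e.g. 112 filters under width 1.1 become
        # int(123.2 + 4) // 8 * 8 = 120. Here width is 1.0, so 32 stays 32.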
323 | 
324 |         self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, bias=False)
325 |         self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
326 | 
327 |         # Build blocks
328 |         block_layers = []
329 |         drop_connect_rate = self._global_params.drop_connect_rate
330 |         num = 0
331 |         for b in self._blocks_args:
332 |             num += b.num_repeat
333 | 
334 |         index = 0
335 |         for block_args in self._blocks_args:
336 |             layers = []
337 |             # Update block input and output filters based on depth multiplier.
338 |             block_args = block_args._replace(
339 |                 input_filters=round_filters(block_args.input_filters, self._global_params),
340 |                 output_filters=round_filters(block_args.output_filters, self._global_params),
341 |                 num_repeat=round_repeats(block_args.num_repeat, self._global_params)
342 |             )
343 | 
344 |             # The first block needs to take care of stride and filter size increase.
345 |             layers.append(MBConvBlock(block_args, self._global_params, image_size=image_size, drop_connect_rate=drop_connect_rate * index / num))  # drop connect rate scaled linearly with depth
346 |             index += 1
347 |             image_size = calculate_output_image_size(image_size, block_args.stride)
348 |             if block_args.num_repeat > 1:
349 |                 block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
350 |             for _ in range(block_args.num_repeat - 1):
351 |                 layers.append(MBConvBlock(block_args, self._global_params, image_size=image_size, drop_connect_rate=drop_connect_rate * index / num))
352 |                 index += 1
353 | 
354 |             block_layers.append(nn.Sequential(*layers))
355 |         self.block_layers = nn.ModuleList(block_layers)
356 | 
357 |         # Head
358 |         in_channels = block_args.output_filters  # output of final block
359 |         out_channels = round_filters(512, self._global_params)
360 |         self.final_channels = out_channels
361 | 
362 |         Conv2d = get_same_padding_conv2d(image_size=image_size)
363 |         self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
364 |         self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
365 | 
366 |         # Final linear layer
367 |         self._avg_pooling = nn.AdaptiveAvgPool2d(1)
368 |         self._dropout = nn.Dropout(self._global_params.dropout_rate)
369 |         self._swish = Swish()
370 | 
371 |         sub_conv = []
372 |         sub_conv.append(Conv2d(80, self.final_channels, kernel_size=1, bias=False))  # 80 = stage-4 output channels of this b0 config, where forward() branches
373 |         sub_conv.append(nn.BatchNorm2d(num_features=self.final_channels, momentum=bn_mom, eps=bn_eps))
374 |         sub_conv.append(Swish())
375 | 
376 |         self.selfcon_layer = self._make_sub_layer(selfcon_pos, nn.Sequential(*sub_conv))
377 | 
378 |     # simply test with nn.Linear
379 |     def _make_sub_layer(self, pos, sub_conv):
380 |         pos = pos[0]  # only the first position flag is used for this architecture
381 |         if not pos:
382 |             return None
383 |         else:
384 |             return nn.ModuleList([sub_conv, nn.Linear(self.final_channels, self.final_channels)])
385 | 
386 |     # Stem
387 |     def conv_stem(self, x):
388 |         x = self._swish(self._bn0(self._conv_stem(x)))
389 | 
390 |         return x
391 | 
392 |     def pool_linear(self, feat):
393 |         # Head
394 |         feat = self._swish(self._bn1(self._conv_head(feat)))
395 | 
396 |         # Pooling and final linear layer
397 |         feat = self._avg_pooling(feat)
398 |         features = feat.view(feat.size(0), -1)
399 |         features = self._dropout(features)
400 | 
401 |         return features
402 | 
403 |     def forward(self, x):
404 |         sub_out = []
405 | 
406 |         x = self.conv_stem(x)
407 | 
408 |         for i in range(4):  # stages 1-4, ending at the 80-channel features
409 |             x = self.block_layers[i](x)
410 | 
411 |         if self.selfcon_layer is not None:
412 |             out = self.selfcon_layer[0](x)
413 |             out = torch.flatten(self._avg_pooling(out), 1)
414 |             out = self._dropout(out)
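            # SelfCon sub-exit: the 80-channel stage-4 features are projected to
            # final_channels, pooled, and passed through the extra Linear below so
            # the sub-branch output matches the main branch's feature dimension.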
415 |             out = self.selfcon_layer[1](out)
416 |             sub_out.append(out)
417 | 
418 |         for i in range(4, len(self.block_layers)):
419 |             x = self.block_layers[i](x)
420 | 
421 |         features = self.pool_linear(x)
422 | 
423 |         return sub_out, features
424 | 
425 | 
426 | def efficientnet(**kwargs):
427 |     return EfficientNet(**kwargs)
428 | 
429 | 
430 | model_dict = {
431 |     'efficientnet': [efficientnet, 512]
432 | }
433 | 
434 | 
435 | class ConEfficientNet(nn.Module):
436 |     """backbone + projection head"""
437 |     def __init__(self, name='efficientnet', head='mlp', feat_dim=128, selfcon_pos=[False, False, False], selfcon_arch='resnet', selfcon_size='same', dataset=''):
438 |         super(ConEfficientNet, self).__init__()
439 |         model_fun, dim_in = model_dict[name]
440 |         self.encoder = model_fun(selfcon_pos=selfcon_pos)
441 |         if head == 'linear':
442 |             self.head = nn.Linear(dim_in, feat_dim)
443 | 
444 |             self.sub_heads = nn.ModuleList()  # ModuleList, not a plain list, so the sub-head parameters are registered with the module
445 |             for pos in selfcon_pos:
446 |                 if pos:
447 |                     self.sub_heads.append(nn.Linear(dim_in, feat_dim))
448 |         elif head == 'mlp':
449 |             self.head = nn.Sequential(
450 |                 nn.Linear(dim_in, dim_in),
451 |                 nn.ReLU(inplace=True),
452 |                 nn.Linear(dim_in, feat_dim)
453 |             )
454 | 
455 |             heads = []
456 |             for pos in selfcon_pos:
457 |                 if pos:
458 |                     heads.append(nn.Sequential(
459 |                         nn.Linear(dim_in, dim_in),
460 |                         nn.ReLU(inplace=True),
461 |                         nn.Linear(dim_in, feat_dim)
462 |                     ))
463 |             self.sub_heads = nn.ModuleList(heads)
464 |         else:
465 |             raise NotImplementedError(
466 |                 'head not supported: {}'.format(head))
467 | 
468 |     def forward(self, x):
469 |         sub_feat, feat = self.encoder(x)
470 | 
471 |         sh_feat = []
472 |         for sf, sub_head in zip(sub_feat, self.sub_heads):
473 |             sh_feat.append(F.normalize(sub_head(sf), dim=1))
474 | 
475 |         feat = F.normalize(self.head(feat), dim=1)
476 |         return sh_feat, feat
477 | 
478 | 
479 | class CEEffNet(nn.Module):
480 |     """encoder + classifier"""
481 |     def __init__(self, name='efficientnet', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False], selfcon_arch='resnet', selfcon_size='same', dataset=''):
482 |         super(CEEffNet, self).__init__()
483 |         self.method = method
484 | 
485 |         model_fun, dim_in = model_dict[name]
486 |         self.encoder = model_fun(selfcon_pos=selfcon_pos)
487 | 
488 |         logit_fcs, feat_fcs = [], []
489 |         for pos in selfcon_pos:
490 |             if pos:
491 |                 logit_fcs.append(nn.Linear(dim_in, num_classes))
492 |                 feat_fcs.append(nn.Linear(dim_in, dim_out))
493 | 
494 |         self.logit_fc = nn.ModuleList(logit_fcs)
495 |         self.l_fc = nn.Linear(dim_in, num_classes)
496 | 
497 |         if method not in ['ce', 'subnet_ce', 'kd']:
498 |             self.feat_fc = nn.ModuleList(feat_fcs)
499 |             self.f_fc = nn.Linear(dim_in, dim_out)
500 |             for param in self.f_fc.parameters():
501 |                 param.requires_grad = False
502 | 
503 |     def forward(self, x):
504 |         sub_feat, feat = self.encoder(x)
505 | 
506 |         feats, logits = [], []
507 | 
508 |         for idx, sh_feat in enumerate(sub_feat):
509 |             logits.append(self.logit_fc[idx](sh_feat))
510 |             if self.method not in ['ce', 'subnet_ce', 'kd']:
511 |                 out = self.feat_fc[idx](sh_feat)
512 |                 feats.append(F.normalize(out, dim=1))
513 | 
514 |         if self.method not in ['ce', 'subnet_ce', 'kd']:
515 |             return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)]
516 |         else:
517 |             return [logits, self.l_fc(feat)]
518 | 
519 | 
520 | class LinearClassifier_EFF(nn.Module):
521 |     """Linear classifier"""
522 |     def __init__(self, name='efficientnet', num_classes=100):
523 |         super(LinearClassifier_EFF, self).__init__()
524 |         _, feat_dim = model_dict[name]
525 |         self.fc = nn.Linear(feat_dim, num_classes)
526 | 
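    # Sketch (illustration only; the surrounding training loop is assumed):
    #   encoder = ConEfficientNet(head='mlp', selfcon_pos=[True])
    #   classifier = LinearClassifier_EFF(num_classes=100)
    #   _, feat = encoder.encoder(images)   # pooled 512-d backbone features
    #   logits = classifier(feat)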
527 |     def forward(self, features):
528 |         return self.fc(features)
--------------------------------------------------------------------------------