├── losses
│   ├── __init__.py
│   └── losses.py
├── networks
│   ├── sub_network
│   │   ├── __init__.py
│   │   └── resnet_layer.py
│   ├── memory_bank.py
│   ├── wrn_big.py
│   ├── vgg_big.py
│   ├── resnet_big.py
│   └── efficient_big.py
├── utils
│   ├── __init__.py
│   ├── imagenet100.txt
│   ├── imagenet.py
│   ├── tinyimagenet.py
│   └── util.py
├── requirements.txt
├── scripts
│   ├── 1stage_train.sh
│   ├── supcon_represent.sh
│   ├── selfcon_represent.sh
│   └── selfcon_represent_imagenet.sh
├── LICENSE
├── .gitignore
├── README.md
├── main_represent.py
├── main_linear.py
└── main_ce.py
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import *
--------------------------------------------------------------------------------
/networks/sub_network/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet_layer import *
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
2 | from .imagenet import *
3 | from .tinyimagenet import *
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.5
2 | torchvision==0.9.1
3 | torch==1.8.1
4 | apex==0.9.10dev
5 | tensorboard_logger==0.1.0
6 | git+https://github.com/ildoonet/pytorch-randaugment
7 | opencv-python
--------------------------------------------------------------------------------
/scripts/1stage_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | seed="0"
4 | data="cifar100"
5 | bsz="1024"
6 | method="ce"
7 | model="resnet18"
8 | lr="0.8"
9 |
10 | python main_ce.py \
11 | --seed $seed \
12 | --dataset $data \
13 | --batch_size $bsz \
14 | --method $method \
15 | --model $model \
16 | --learning_rate $lr \
17 | --epochs 500 \
18 | --cosine
19 |
--------------------------------------------------------------------------------
/networks/sub_network/resnet_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | __all__ = ['resnet_sub_layer', 'wrn_sub_layer']
5 |
6 |
7 | def resnet_sub_layer(block, in_planes, planes, num_blocks, stride):
8 | strides = [stride] + [1] * (num_blocks - 1)
9 | layers = []
10 | for i in range(num_blocks):
11 | stride = strides[i]
12 | layers.append(block(in_planes, planes, stride))
13 | in_planes = planes * block.expansion
14 | return nn.Sequential(*layers)
15 |
16 | def wrn_sub_layer(block, in_planes, planes, num_blocks, dropout_rate, stride):
17 | strides = [stride] + [1]*(int(num_blocks)-1)
18 | layers = []
19 |
20 | for stride in strides:
21 | layers.append(block(in_planes, planes, dropout_rate, stride))
22 | in_planes = planes
23 |
24 | return nn.Sequential(*layers)
25 |
--------------------------------------------------------------------------------
/scripts/supcon_represent.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | seed="0"
4 | data="cifar100"
5 | method="SupCon"
6 | model="resnet18"
7 | bsz="1024"
8 | lr="0.5"
9 | label="True"
10 | multiview="True"
11 |
12 | python main_represent.py \
13 | --seed $seed \
14 | --method $method \
15 | --dataset $data \
16 | --model $model \
17 | --batch_size $bsz \
18 | --learning_rate $lr \
19 | --temp 0.1 \
20 | --epochs 1000 \
21 | --multiview \
22 | --cosine \
23 | --precision
24 |
25 | python main_linear.py --batch_size 512 \
26 | --dataset $data \
27 | --model $model \
28 | --learning_rate 3 \
29 | --weight_decay 0 \
30 | --epochs 100 \
31 | --lr_decay_epochs '60,80' \
32 | --lr_decay_rate 0.1 \
33 | --ckpt ./save/representation/${method}/${data}_models/${method}_${data}_${model}_lr_${lr}_multiview_${multiview}_label_${label}_decay_0.0001_bsz_${bsz}_temp_0.1_seed_${seed}_cosine_warm/last.pth
34 |
--------------------------------------------------------------------------------
/networks/memory_bank.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class MemoryBank(nn.Module):
6 | def __init__(self, dim, K, n_cls):
7 | super(MemoryBank, self).__init__()
8 |
9 | self.K = K
10 |
11 | self.register_buffer("queue", torch.randn(dim, K))
12 | self.register_buffer("q_label", torch.randint(n_cls, (1, K)))
13 | self.queue = nn.functional.normalize(self.queue, dim=0)
14 |
15 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
16 |
17 | @torch.no_grad()
18 | def _dequeue_and_enqueue(self, keys, labels):
19 | batch_size = keys.shape[0]
20 |
21 | ptr = int(self.queue_ptr)
22 | assert self.K % batch_size == 0 # for simplicity
23 |
24 | self.queue[:, ptr:ptr + batch_size] = keys.T
25 | self.q_label[:, ptr:ptr + batch_size] = labels.unsqueeze(1).T
26 | ptr = (ptr + batch_size) % self.K # move pointer
27 |
28 | self.queue_ptr[0] = ptr
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 SangminBae
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/selfcon_represent.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | seed="0"
4 | method="SelfCon"
5 | data="cifar100"
6 | model="resnet18"
7 | arch="resnet"
8 | size="fc"
9 | pos="[False,True,False]"
10 | bsz="1024"
11 | lr="0.5"
12 | label="True"
13 | multiview="False"
14 |
15 | python main_represent.py --exp_name "${arch}_${size}_${pos}" \
16 | --seed $seed \
17 | --method $method \
18 | --dataset $data \
19 | --model $model \
20 | --selfcon_pos $pos \
21 | --selfcon_arch $arch \
22 | --selfcon_size $size \
23 | --batch_size $bsz \
24 | --learning_rate $lr \
25 | --temp 0.1 \
26 | --epochs 1000 \
27 | --cosine \
28 | --precision
29 |
30 | python main_linear.py --batch_size 512 \
31 | --dataset $data \
32 | --model $model \
33 | --learning_rate 3 \
34 | --weight_decay 0 \
35 | --selfcon_pos $pos \
36 | --selfcon_arch $arch \
37 | --selfcon_size $size \
38 | --epochs 100 \
39 | --lr_decay_epochs '60,80' \
40 | --lr_decay_rate 0.1 \
41 | --subnet \
42 | --ckpt ./save/representation/${method}/${data}_models/${method}_${data}_${model}_lr_${lr}_multiview_${multiview}_label_${label}_decay_0.0001_bsz_${bsz}_temp_0.1_seed_${seed}_cosine_warm_${arch}_${size}_${pos}/last.pth
43 |
44 |
--------------------------------------------------------------------------------
/scripts/selfcon_represent_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | seed="0"
4 | method="SelfCon"
5 | data="imagenet"
6 | model="resnet34"
7 | arch="resnet"
8 | size="same"
9 | pos="[False,True,False]"
10 | bsz="2048"
11 | lr="0.25"
12 | label="True"
13 | multiview="False"
14 |
15 | python main_represent.py --exp_name "${arch}_${size}_${pos}" \
16 | --seed $seed \
17 | --method $method \
18 | --dataset $data \
19 | --data_folder './data/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/' \
20 | --model $model \
21 | --selfcon_pos $pos \
22 | --selfcon_arch $arch \
23 | --selfcon_size $size \
24 | --batch_size $bsz \
25 | --learning_rate $lr \
26 | --temp 0.1 \
27 | --epochs 800 \
28 | --cosine \
29 | --precision
30 |
31 | python main_linear.py --batch_size 512 \
32 | --dataset $data \
33 | --data_folder './data/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/' \
34 | --model $model \
35 | --learning_rate 5 \
36 | --weight_decay 0 \
37 | --selfcon_pos $pos \
38 | --selfcon_arch $arch \
39 | --selfcon_size $size \
40 | --epochs 40 \
41 | --lr_decay_epochs '20,30' \
42 | --lr_decay_rate 0.1 \
43 | --subnet \
44 | --ckpt ./save/representation/${method}/${data}_models/${method}_${data}_${model}_lr_${lr}_multiview_${multiview}_label_${label}_decay_0.0001_bsz_${bsz}_temp_0.1_seed_${seed}_cosine_warm_${arch}_${size}_${pos}/last.pth
45 |
46 |
--------------------------------------------------------------------------------
/utils/imagenet100.txt:
--------------------------------------------------------------------------------
1 | n02869837
2 | n01749939
3 | n02488291
4 | n02107142
5 | n13037406
6 | n02091831
7 | n04517823
8 | n04589890
9 | n03062245
10 | n01773797
11 | n01735189
12 | n07831146
13 | n07753275
14 | n03085013
15 | n04485082
16 | n02105505
17 | n01983481
18 | n02788148
19 | n03530642
20 | n04435653
21 | n02086910
22 | n02859443
23 | n13040303
24 | n03594734
25 | n02085620
26 | n02099849
27 | n01558993
28 | n04493381
29 | n02109047
30 | n04111531
31 | n02877765
32 | n04429376
33 | n02009229
34 | n01978455
35 | n02106550
36 | n01820546
37 | n01692333
38 | n07714571
39 | n02974003
40 | n02114855
41 | n03785016
42 | n03764736
43 | n03775546
44 | n02087046
45 | n07836838
46 | n04099969
47 | n04592741
48 | n03891251
49 | n02701002
50 | n03379051
51 | n02259212
52 | n07715103
53 | n03947888
54 | n04026417
55 | n02326432
56 | n03637318
57 | n01980166
58 | n02113799
59 | n02086240
60 | n03903868
61 | n02483362
62 | n04127249
63 | n02089973
64 | n03017168
65 | n02093428
66 | n02804414
67 | n02396427
68 | n04418357
69 | n02172182
70 | n01729322
71 | n02113978
72 | n03787032
73 | n02089867
74 | n02119022
75 | n03777754
76 | n04238763
77 | n02231487
78 | n03032252
79 | n02138441
80 | n02104029
81 | n03837869
82 | n03494278
83 | n04136333
84 | n03794056
85 | n03492542
86 | n02018207
87 | n04067472
88 | n03930630
89 | n03584829
90 | n02123045
91 | n04229816
92 | n02100583
93 | n03642806
94 | n04336792
95 | n03259280
96 | n02116738
97 | n02108089
98 | n03424325
99 | n01855672
100 | n02090622
--------------------------------------------------------------------------------
/utils/imagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke, Simon Vandenhende
3 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
4 | """
5 | import os
6 | import torch
7 | import torchvision.datasets as datasets
8 | import torch.utils.data as data
9 | from PIL import Image
10 | from torchvision import transforms as tf
11 | from glob import glob
12 |
13 |
14 | class ImageNetSubset(data.Dataset):
15 | def __init__(self, subset_file, root='', split='train',
16 | transform=None):
17 | super(ImageNetSubset, self).__init__()
18 |
19 | self.root = root
20 | self.transform = transform
21 | self.split = split
22 |
23 | # Read the subset of classes to include (class index follows file order)
24 | with open(subset_file, 'r') as f:
25 | result = f.read().splitlines()
26 | subdirs = []
27 | for line in result:
28 | subdirs.append(line)
29 |
30 | # Gather the files (sorted)
31 | imgs = []
32 | for i, subdir in enumerate(subdirs):
33 |
34 | files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG')))
35 | for f in files:
36 | imgs.append((f, i))
37 | self.imgs = imgs
38 |
39 | # Resize
40 | self.resize = tf.Resize(256)
41 |
42 | def get_image(self, index):
43 | path, target = self.imgs[index]
44 | with open(path, 'rb') as f:
45 | img = Image.open(f).convert('RGB')
46 | img = self.resize(img)
47 | return img
48 |
49 | def __len__(self):
50 | return len(self.imgs)
51 |
52 | def __getitem__(self, index):
53 | path, target = self.imgs[index]
54 | with open(path, 'rb') as f:
55 | img = Image.open(f).convert('RGB')
56 | im_size = img.size
57 | img = self.resize(img)
58 |
59 | if self.transform is not None:
60 | img = self.transform(img)
61 |
62 | return img, target
63 |
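64 | # Usage sketch (illustrative paths): build the ImageNet-100 train split from
65 | # the class list in utils/imagenet100.txt and a user-specified ImageNet root.
66 | # dataset = ImageNetSubset('./utils/imagenet100.txt',
67 | #                          root='./data/ILSVRC2015/ILSVRC2015/Data/CLS-LOC/train')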
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # folder
132 | results/
133 | save/
134 | run/
135 |
136 | # model checkpoint
137 | *.pth
138 | output/
139 | log/
--------------------------------------------------------------------------------
/utils/tinyimagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 |
6 | EXTENSION = 'JPEG'
7 | NUM_IMAGES_PER_CLASS = 500
8 | CLASS_LIST_FILE = 'wnids.txt'
9 | VAL_ANNOTATION_FILE = 'val_annotations.txt'
10 |
11 | __all__ = ['TinyImageNet']
12 |
13 |
14 | class TinyImageNet(Dataset):
15 | """Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`.
16 | Parameters
17 | ----------
18 | root: string
19 | Root directory including `train`, `test` and `val` subdirectories.
20 | train: bool
21 | If True, load the `train` split; otherwise load the `val` split
22 | (the constructor takes a `train` flag rather than a `split` string).
23 | transform: torchvision.transforms
24 | A (series) of valid transformation(s).
25 | in_memory: bool
26 | Set to True if there is enough memory (about 5G) and you want to minimize disk IO overhead.
27 | """
28 | def __init__(self, root, train=True, transform=None, target_transform=None, in_memory=False, download=False):
29 | self.root = os.path.expanduser(root)
30 | self.train = train
31 | self.split = 'train' if train else 'val'
32 | self.transform = transform
33 | self.target_transform = target_transform
34 | self.in_memory = in_memory
35 | self.split_dir = os.path.join(root, self.split)
36 | self.image_paths = sorted(glob.iglob(os.path.join(self.split_dir, '**', '*.%s' % EXTENSION), recursive=True))
37 | self.labels = {} # fname - label number mapping
38 | self.images = [] # used for in-memory processing
39 |
40 | # build class label - number mapping
41 | with open(os.path.join(self.root, CLASS_LIST_FILE), 'r') as fp:
42 | self.label_texts = sorted([text.strip() for text in fp.readlines()])
43 | self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}
44 |
45 | if self.split == 'train':
46 | for label_text, i in self.label_text_to_number.items():
47 | for cnt in range(NUM_IMAGES_PER_CLASS):
48 | self.labels['%s_%d.%s' % (label_text, cnt, EXTENSION)] = i
49 | elif self.split == 'val':
50 | with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), 'r') as fp:
51 | for line in fp.readlines():
52 | terms = line.split('\t')
53 | file_name, label_text = terms[0], terms[1]
54 | self.labels[file_name] = self.label_text_to_number[label_text]
55 |
56 | # read all images into torch tensor in memory to minimize disk IO overhead
57 | if self.in_memory:
58 | self.images = [self.read_image(path) for path in self.image_paths]
59 |
60 | def __len__(self):
61 | return len(self.image_paths)
62 |
63 | def __getitem__(self, index):
64 | file_path = self.image_paths[index]
65 |
66 | if self.in_memory:
67 | img = self.images[index]
68 | else:
69 | img = self.read_image(file_path)
70 |
71 | if self.split == 'test':
72 | return img
73 | else:
74 | # file_name = file_path.split('/')[-1]
75 | return img, self.labels[os.path.basename(file_path)]
76 |
77 | def __repr__(self):
78 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
79 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
80 | tmp = self.split
81 | fmt_str += ' Split: {}\n'.format(tmp)
82 | fmt_str += ' Root Location: {}\n'.format(self.root)
83 | tmp = ' Transforms (if any): '
84 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
85 | tmp = ' Target Transforms (if any): '
86 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
87 | return fmt_str
88 |
89 | def read_image(self, path):
90 | img = Image.open(path).convert('RGB')
91 | return self.transform(img) if self.transform else img
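92 |
93 | # Usage sketch (illustrative path): assumes tiny-imagenet-200 is extracted under ./data.
94 | # train_set = TinyImageNet('./data/tiny-imagenet-200', train=True)
95 | # img, label = train_set[0]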
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import json
5 | import pickle
6 | import math
7 | import numpy as np
8 | import torch
9 | import torch.optim as optim
10 |
11 | __all__ = ['AverageMeter', 'TwoCropTransform', 'adjust_learning_rate', 'warmup_learning_rate', 'accuracy', 'class_accuracy', 'set_optimizer', 'save_model', 'update_json', 'update_json_list']
12 |
13 |
14 | class TwoCropTransform:
15 | """Create two crops of the same image"""
16 | def __init__(self, transform):
17 | self.transform = transform
18 |
19 | def __call__(self, x):
20 | return [self.transform(x), self.transform(x)]
21 |
22 |
23 | class AverageMeter(object):
24 | """Computes and stores the average and current value"""
25 | def __init__(self):
26 | self.reset()
27 |
28 | def reset(self):
29 | self.val = 0
30 | self.avg = 0
31 | self.sum = 0
32 | self.count = 0
33 |
34 | def update(self, val, n=1):
35 | self.val = val
36 | self.sum += val * n
37 | self.count += n
38 | self.avg = self.sum / self.count
39 |
40 |
41 | def accuracy(output, target, topk=(1,)):
42 | """Computes the accuracy over the k top predictions for the specified values of k"""
43 | with torch.no_grad():
44 | maxk = max(topk)
45 | batch_size = target.size(0)
46 |
47 | _, pred = output.topk(maxk, 1, True, True)
48 | pred = pred.t()
49 | correct = pred.eq(target.view(1, -1).expand_as(pred))
50 |
51 | res = []
52 | for k in topk:
53 | #correct_k = correct[:k].reshape(-1, k).float().sum(1).sum(0, keepdim=True)
54 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
55 | res.append(correct_k.mul_(100.0 / batch_size))
56 | return res
57 |
58 |
59 | def class_accuracy(output, target, cls, topk=(1,)):
60 | """Computes the accuracy over the k top predictions for the specified values of k"""
61 | with torch.no_grad():
62 | maxk = max(topk)
63 |
64 | output = output[target == cls]
65 | target = target[target == cls]
66 |
67 | batch_size = target.size(0)
68 |
69 | _, pred = output.topk(maxk, 1, True, True)
70 | pred = pred.t()
71 | correct = pred.eq(target.view(1, -1).expand_as(pred))
72 |
73 | res = []
74 | for k in topk:
75 | #correct_k = correct[:k].reshape(-1, k).float().sum(1).sum(0, keepdim=True)
76 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
77 | res.append(correct_k.mul_(100.0 / batch_size))
78 | return res, batch_size
79 |
80 |
81 | def adjust_learning_rate(args, optimizer, epoch):
82 | lr = args.learning_rate
83 | if args.cosine:
84 | eta_min = lr * (args.lr_decay_rate ** 3)
85 | lr = eta_min + (lr - eta_min) * (
86 | 1 + math.cos(math.pi * epoch / args.epochs)) / 2
87 | else:
88 | steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
89 | if steps > 0:
90 | lr = lr * (args.lr_decay_rate ** steps)
91 | for param_group in optimizer.param_groups:
92 | param_group['lr'] = lr
93 |
94 |
95 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
96 | if args.warm and epoch <= args.warm_epochs:
97 | p = (batch_id + (epoch - 1) * total_batches) / \
98 | (args.warm_epochs * total_batches)
99 | lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
100 |
101 | for param_group in optimizer.param_groups:
102 | param_group['lr'] = lr
103 |
104 |
105 | def set_optimizer(opt, model, optimizer='sgd'):
106 | if optimizer == 'sgd':
107 | optimizer = optim.SGD(model.parameters(),
108 | lr=opt.learning_rate,
109 | momentum=opt.momentum,
110 | weight_decay=opt.weight_decay)
111 | elif optimizer == 'adam':
112 | optimizer = optim.Adam(model.parameters(),
113 | lr=opt.learning_rate)
114 | return optimizer
115 |
116 |
117 | def save_model(model, optimizer, opt, epoch, save_file):
118 | print('==> Saving...')
119 | state = {
120 | 'opt': opt,
121 | 'model': model.state_dict(),
122 | 'optimizer': optimizer.state_dict(),
123 | 'epoch': epoch,
124 | }
125 | torch.save(state, save_file)
126 | del state
127 |
128 |
129 | def update_json(exp_name, acc={}, path='./save/results.json'):
130 | for k, v in acc.items():
131 | acc[k] = [round(a, 2) for a in v]
132 | if not os.path.exists(path):
133 | with open(path, 'w') as f:
134 | json.dump({}, f)
135 |
136 | with open(path, 'r', encoding="UTF-8") as f:
137 | result_dict = json.load(f)
138 | result_dict[exp_name] = acc
139 |
140 | with open(path, 'w') as f:
141 | json.dump(result_dict, f)
142 |
143 | print('best accuracy: {}'.format(acc))
144 | print('results updated to %s' % path)
145 |
146 |
147 | def update_json_list(exp_name, acc=[0., 0.], path='./save/results.json'):
148 | acc = [round(a, 2) for a in acc]
149 | if not os.path.exists(path):
150 | with open(path, 'w') as f:
151 | json.dump({}, f)
152 |
153 | with open(path, 'r', encoding="UTF-8") as f:
154 | result_dict = json.load(f)
155 | result_dict[exp_name] = acc
156 |
157 | with open(path, 'w') as f:
158 | json.dump(result_dict, f)
159 |
160 | print('best accuracy: {}'.format(acc))
161 | print('results updated to %s' % path)
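162 |
163 |
164 | # Usage sketch (illustrative): TwoCropTransform wraps any torchvision transform
165 | # to produce the two views used by a multi-viewed contrastive batch.
166 | # from torchvision import transforms
167 | # two_crop = TwoCropTransform(transforms.RandomResizedCrop(32))
168 | # view1, view2 = two_crop(pil_image)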
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | refer to
3 | 1) Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf
4 | 2) SimCLR: https://arxiv.org/pdf/2002.05709.pdf
5 | """
6 | from __future__ import print_function
7 |
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import numpy as np
13 |
14 | eps = 1e-7
15 |
16 |
17 | class ConLoss(nn.Module):
18 | """Self-Contrastive Learning: https://arxiv.org/abs/2106.15499."""
19 | def __init__(self, temperature=0.07, contrast_mode='all', base_temperature=0.07):
20 | super(ConLoss, self).__init__()
21 | self.temperature = temperature
22 | self.contrast_mode = contrast_mode
23 | self.base_temperature = base_temperature
24 |
25 | def forward(self, features, labels=None, mask=None, supcon_s=False, selfcon_s_FG=False, selfcon_m_FG=False):
26 | """
27 | Args:
28 | features: hidden vector of shape [bsz, n_views, ...].
29 | labels: ground truth of shape [bsz].
30 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
31 | has the same class as sample i. Can be asymmetric.
32 | supcon_s: boolean for using a single-viewed batch.
33 | selfcon_s_FG: in a single-viewed batch, exclude the contrastive terms whose anchor comes from F (the backbone); only sub-network (G) outputs serve as anchors.
34 | selfcon_m_FG: the same F-anchor exclusion for a multi-viewed batch (the batch dimension stacks the two views, so the effective batch size is halved).
35 | Returns:
36 | A loss scalar.
37 | """
38 | device = features.device
39 |
40 | if len(features.shape) < 3:
41 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
42 | 'at least 3 dimensions are required')
43 | if len(features.shape) > 3:
44 | features = features.view(features.shape[0], features.shape[1], -1)
45 |
46 | batch_size = features.shape[0] if not selfcon_m_FG else int(features.shape[0]/2)
47 |
48 | if labels is not None and mask is not None:
49 | raise ValueError('Cannot define both `labels` and `mask`')
50 | elif labels is None and mask is None:
51 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
52 | elif labels is not None:
53 | labels = labels.contiguous().view(-1, 1)
54 | if labels.shape[0] != batch_size:
55 | raise ValueError('Num of labels does not match num of features')
56 | mask = torch.eq(labels, labels.T).float().to(device)
57 | else:
58 | mask = mask.float().to(device)
59 |
60 | if not selfcon_s_FG and not selfcon_m_FG:
61 | contrast_count = features.shape[1]
62 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
63 | if self.contrast_mode == 'one':
64 | anchor_feature = features[:, 0]
65 | anchor_count = 1
66 | elif self.contrast_mode == 'all':
67 | anchor_feature = contrast_feature
68 | anchor_count = contrast_count
69 | else:
70 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
71 | elif selfcon_s_FG:
72 | contrast_count = features.shape[1]
73 | anchor_count = features.shape[1]-1
74 |
75 | anchor_feature, contrast_feature = torch.cat(torch.unbind(features, dim=1)[:-1], dim=0), torch.unbind(features, dim=1)[-1]
76 | contrast_feature = torch.cat([anchor_feature, contrast_feature], dim=0)
77 | elif selfcon_m_FG:
78 | contrast_count = int(features.shape[1] * 2)
79 | anchor_count = (features.shape[1]-1)*2
80 |
81 | anchor_feature, contrast_feature = torch.cat(torch.unbind(features, dim=1)[:-1], dim=0), torch.unbind(features, dim=1)[-1]
82 | contrast_feature = torch.cat([anchor_feature, contrast_feature], dim=0)
83 |
84 | # compute logits
85 | anchor_dot_contrast = torch.div(
86 | torch.matmul(anchor_feature, contrast_feature.T),
87 | self.temperature)
88 |
89 | # for numerical stability
90 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
91 | logits = anchor_dot_contrast - logits_max.detach()
92 |
93 | # tile mask
94 | mask = mask.repeat(anchor_count, contrast_count)
95 |
96 | # mask-out self-contrast cases
97 | logits_mask = torch.scatter(
98 | torch.ones_like(mask),
99 | 1,
100 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
101 | 0
102 | )
103 |
104 | mask = mask * logits_mask
105 | if supcon_s:
106 | idx = mask.sum(1) != 0
107 | mask = mask[idx, :]
108 | logits_mask = logits_mask[idx, :]
109 | logits = logits[idx, :]
110 | batch_size = idx.sum()
111 |
112 | # compute log_prob
113 | exp_logits = torch.exp(logits) * logits_mask
114 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
115 |
116 | # compute mean of log-likelihood over positive
117 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
118 |
119 | # loss
120 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
121 | loss = loss.view(anchor_count, batch_size).mean()
122 |
123 | return loss
124 |
125 |
126 | class KLLoss(nn.Module):
127 | """Distilling the Knowledge in a Neural Network"""
128 | def __init__(self, T=3.0):
129 | super(KLLoss, self).__init__()
130 | self.T = T
131 |
132 | def forward(self, logit_s, logit_t):
133 | p_s = F.log_softmax(logit_s/self.T, dim=1)
134 | p_t = F.softmax(logit_t.clone().detach()/self.T, dim=1)
135 | loss = -pow(self.T, 2)*(p_s * p_t).sum(dim=1).mean()
136 |
137 | return loss
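138 |
139 |
140 | # Usage sketch (illustrative): ConLoss expects L2-normalized features of shape
141 | # [bsz, n_views, feat_dim] and, for the supervised case, labels of shape [bsz].
142 | # criterion = ConLoss(temperature=0.1)
143 | # features = F.normalize(torch.randn(8, 2, 128), dim=2)
144 | # loss = criterion(features, labels=torch.randint(10, (8,)))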
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Contrastive Learning: Single-viewed Supervised Contrastive Framework using Sub-network
2 |
3 |
4 |
5 |
6 |
7 | This repository contains the official PyTorch implementation of the following paper:
8 |
9 | > **Self-Contrastive Learning: Single-viewed Supervised Contrastive Framework using Sub-network** by
10 | > Sangmin Bae*, Sungnyun Kim*, Jongwoo Ko, Gihun Lee, Seungjong Noh, Se-Young Yun, [AAAI 2023](https://aaai.org/Conferences/AAAI-23/).
11 | >
12 | > **Paper**: https://arxiv.org/abs/2106.15499
13 | > **Video**: https://www.youtube.com/watch?v=VNv3LXzqX_4
14 | >
15 | > **Abstract:** *Contrastive loss has significantly improved performance in supervised classification tasks by using a multi-viewed framework that leverages augmentation and label information. The augmentation enables contrast with another view of a single image but enlarges training time and memory usage. To exploit the strength of multi-views while avoiding the high computation cost, we introduce a multi-exit architecture that outputs multiple features of a single image in a single-viewed framework. To this end, we propose Self-Contrastive (SelfCon) learning, which self-contrasts within multiple outputs from the different levels of a single network. The multi-exit architecture efficiently replaces multi-augmented images and leverages various information from different layers of a network. We demonstrate that SelfCon learning improves the classification performance of the encoder network, and empirically analyze its advantages in terms of the single-view and the sub-network. Furthermore, we provide theoretical evidence of the performance increase based on the mutual information bound. For ImageNet classification on ResNet-50, SelfCon improves accuracy by +0.6% with 59% memory and 48% time of Supervised Contrastive learning, and a simple ensemble of multi-exit outputs boosts performance up to +1.5%.*
16 |
17 | ## Table of Contents
18 |
19 | * [Installation](#installation)
20 | * [Usage](#usage)
21 | * [Parameters for Pretraining](#parameters-for-pretraining)
22 | * [Experimental Results](#experimental-results)
23 | * [License](#license)
24 | * [Contact](#contact)
25 |
26 | ## Installation
27 | We experimented with eight RTX 3090 GPUs and CUDA 11.3.
28 | Please check the requirements below and install the packages from `requirements.txt`.
29 |
30 | ```bash
31 | $ pip install --upgrade pip
32 | $ pip install -r requirements.txt
33 | ```
34 |
35 | ## Usage
36 | To pretrain a SelfCon model, run `main_represent.py` as in the following example.
37 |
38 | ```bash
39 | # Pretraining on [Dataset: CIFAR-100, Architecture: ResNet-18]
40 | python main_represent.py --exp_name "resnet_fc_[False,True,False]" \
41 | --seed 2022 \
42 | --method SelfCon \
43 | --dataset cifar100 \
44 | --model resnet18 \
45 | --selfcon_pos "[False,True,False]" \
46 | --selfcon_arch "resnet" \
47 | --selfcon_size "fc" \
48 | --batch_size 1024 \
49 | --learning_rate 0.5 \
50 | --temp 0.1 \
51 | --epochs 1000 \
52 | --cosine \
53 | --precision
54 | ```
55 |
56 | For linear evaluation, run `main_linear.py` with an appropriate `${SAVE_CKPT}`.
57 | For the above example, `${SAVE_CKPT}` is `./save/representation/SelfCon/cifar100_models/SelfCon_cifar100_resnet18_lr_0.5_multiview_False_label_True_decay_0.0001_bsz_1024_temp_0.1_seed_2022_cosine_warm_resnet_fc_[False,True,False]/last.pth`.
58 |
59 | ```bash
60 | # Finetuning on [Dataset: CIFAR-100, Architecture: ResNet-18]
61 | python main_linear.py --batch_size 512 \
62 | --dataset cifar100 \
63 | --model resnet18 \
64 | --learning_rate 3 \
65 | --weight_decay 0 \
66 | --selfcon_pos "[False,True,False]" \
67 | --selfcon_arch "resnet" \
68 | --selfcon_size "fc" \
69 | --epochs 100 \
70 | --lr_decay_epochs '60,80' \
71 | --lr_decay_rate 0.1 \
72 | --subnet \
73 | --ckpt ${SAVE_CKPT}
74 | ```
75 |
76 | Also, refer to `./scripts/` for SupCon pretraining and 1-stage training examples.
77 | For ImageNet experiments, change `--dataset` to `imagenet`, specify `--data_folder`, and set the hyperparameters as denoted in the paper (see `scripts/selfcon_represent_imagenet.sh` for a complete example).
78 |
79 | ### Parameters for Pretraining
80 | | Parameter | Description |
81 | | ----------------------------- | ---------------------------------------- |
82 | | `model` | The model architecture. Default: `resnet50`. |
83 | | `dataset` | Dataset to use. Options: `cifar10`, `cifar100`, `tinyimagenet`, `imagenet100`, `imagenet`. |
84 | | `method` | Pretraining method. Options: `Con`, `SupCon`, `SelfCon`. |
85 | | `learning_rate` | Learning rate for pretraining. Default: `0.5` for a batch size of 1024. |
86 | | `temp` | Temperature of contrastive loss function. Default: `0.07`. |
87 | | `precision` | Whether to use mixed precision. Default: `False`. |
88 | | `cosine` | Whether to use cosine annealing scheduling. Default: `False`. |
89 | | `selfcon_pos` | Positions at which to attach the sub-network(s). Default: `[False,True,False]` for ResNet architectures. |
90 | | `selfcon_arch` | Sub-network architecture. Options: `resnet`, `vgg`, `efficientnet`, `wrn`. Default: `resnet`. |
91 | | `selfcon_size` | Number of blocks in the sub-network. Options: `fc`, `small`, `same`. Default: `same`. |
92 | | `multiview` | Whether to use a multi-viewed batch. Default: `False`. |
93 | | `label` | Whether to use label information in the contrastive loss. Default: `False`. |
94 |
95 |
96 | ### Experimental Results
97 | See our paper for more details and extensive analyses.
98 | Here are some of our main results.
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | ## Citing This Work
108 |
109 | If you find this repo useful for your research, please consider citing our paper:
110 | ```
111 | @article{bae2021self,
112 | title={Self-Contrastive Learning: Single-viewed Supervised Contrastive Framework using Sub-network},
113 | author={Bae, Sangmin and Kim, Sungnyun and Ko, Jongwoo and Lee, Gihun and Noh, Seungjong and Yun, Se-Young},
114 | journal={arXiv preprint arXiv:2106.15499},
115 | year={2021}
116 | }
117 | ```
118 |
119 | ## License
120 | Distributed under the MIT License.
121 |
122 | ## Contact
123 | * Sangmin Bae: bsmn0223@kaist.ac.kr
124 | * Sungnyun Kim: ksn4397@kaist.ac.kr
125 |
--------------------------------------------------------------------------------
/networks/wrn_big.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | from .sub_network import *
8 |
9 | import sys
10 | import numpy as np
11 |
12 | def conv3x3(in_planes, out_planes, stride=1):
13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
14 |
15 | def conv7x7(in_planes, out_planes, stride=2):
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=7, stride=stride, padding=3, bias=True)
17 |
18 | def conv_init(m):
19 | classname = m.__class__.__name__
20 | if classname.find('Conv') != -1:
21 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
22 | init.constant_(m.bias, 0)
23 | elif classname.find('BatchNorm') != -1:
24 | init.constant_(m.weight, 1)
25 | init.constant_(m.bias, 0)
26 |
27 | class wide_basic(nn.Module):
28 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
29 | super(wide_basic, self).__init__()
30 | self.bn1 = nn.BatchNorm2d(in_planes)
31 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
32 | self.dropout = nn.Dropout(p=dropout_rate)
33 | self.bn2 = nn.BatchNorm2d(planes)
34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
35 |
36 | self.shortcut = nn.Sequential()
37 | if stride != 1 or in_planes != planes:
38 | self.shortcut = nn.Sequential(
39 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
40 | )
41 |
42 | def forward(self, x):
43 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
44 | out = self.conv2(F.relu(self.bn2(out)))
45 | out += self.shortcut(x)
46 |
47 | return out
48 |
49 | class Wide_ResNet(nn.Module):
50 | def __init__(self, depth, widen_factor, dropout_rate, selfcon_pos=[False,False], selfcon_arch='wrn', selfcon_size='same', dataset=''):
51 | super(Wide_ResNet, self).__init__()
52 | self.in_planes = 16
53 |
54 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
55 | n = (depth-4)//6
56 | k = widen_factor
57 | self.dropout_rate = dropout_rate
58 | self.num_blocks = n
59 |
60 | print('| Wide-Resnet %dx%d' %(depth, k))
61 | nStages = [16, 16*k, 32*k, 64*k]
62 | self.nStages = nStages
63 |
64 | if dataset in ['imagenet', 'imagenet100']:
65 | self.conv1 = conv7x7(3,nStages[0])
66 | else:
67 | self.conv1 = conv3x3(3,nStages[0])
68 | if 'imagenet' in dataset:
69 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
70 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
71 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
72 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
73 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
74 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
75 |
76 | self.selfcon_pos = selfcon_pos
77 | self.selfcon_arch = selfcon_arch
78 | self.selfcon_size = selfcon_size
79 | self.selfcon_layer = nn.ModuleList([self._make_sub_layer(idx, pos) for idx, pos in enumerate(selfcon_pos)])
80 | self.dataset = dataset
81 |
82 | for m in self.modules():
83 | conv_init(m)
84 |
85 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
86 | strides = [stride] + [1]*(int(num_blocks)-1)
87 | layers = []
88 |
89 | for stride in strides:
90 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
91 | self.in_planes = planes
92 |
93 | return nn.Sequential(*layers)
94 |
95 | def _make_sub_layer(self, idx, pos):
96 | channels = [128, 256, 512]  # stage widths for widen_factor=8, i.e., nStages[1:]
97 | strides = [1, 2, 2]
98 | num_blocks = [self.num_blocks]*3
99 | if self.selfcon_size == 'same':
100 | num_blocks = num_blocks
101 | elif self.selfcon_size == 'small':
102 | num_blocks = [int((n+1)/2) for n in num_blocks]
103 | elif self.selfcon_size == 'large':
104 | num_blocks = [int(n*2) for n in num_blocks]
105 | elif self.selfcon_size == 'fc':
106 | pass
107 | else:
108 | raise NotImplementedError
109 |
110 | if not pos:
111 | return None
112 | else:
113 | if self.selfcon_size == 'fc':
114 | return nn.Linear(channels[idx], channels[-1])
115 | else:
116 | if self.selfcon_arch == 'resnet':
117 | raise NotImplementedError
118 | elif self.selfcon_arch == 'vgg':
119 | raise NotImplementedError
120 | elif self.selfcon_arch == 'efficientnet':
121 | raise NotImplementedError
122 | elif self.selfcon_arch == 'wrn':
123 | layers = []
124 | for i in range(idx+1, 3):
125 | in_planes = channels[i-1]
126 | layers.append(wrn_sub_layer(wide_basic, in_planes, channels[i], num_blocks[i], self.dropout_rate, strides[i]))
127 |
128 | return nn.Sequential(*layers)
129 |
130 | def forward(self, x):
131 | sub_out = []
132 |
133 | x = self.conv1(x)
134 | # maxpool -> last map before avgpool is 4x4
135 | if 'imagenet' in self.dataset:
136 | x = self.maxpool(x)
137 |
138 | x = self.layer1(x)
139 | if self.selfcon_layer[0]:
140 | if self.selfcon_size != 'fc':
141 | out = self.selfcon_layer[0](x)
142 | out = torch.flatten(self.avgpool(out), 1)
143 | else:
144 | out = torch.flatten(self.avgpool(x), 1)
145 | out = self.selfcon_layer[0](out)
146 | sub_out.append(out)
147 |
148 | x = self.layer2(x)
149 | if self.selfcon_layer[1]:
150 | if self.selfcon_size != 'fc':
151 | out = self.selfcon_layer[1](x)
152 | out = torch.flatten(self.avgpool(out), 1)
153 | else:
154 | out = torch.flatten(self.avgpool(x), 1)
155 | out = self.selfcon_layer[1](out)
156 | sub_out.append(out)
157 |
158 | x = self.layer3(x)
159 | x = F.relu(self.bn1(x))
160 | x = self.avgpool(x)
161 |
162 | x = torch.flatten(x, 1)
163 |
164 | # out = self.linear(out)
165 |
166 | return sub_out, x
167 |
168 |
169 | def wrn_16_8(**kwargs):
170 | return Wide_ResNet(16, 8, 0.3, **kwargs)
171 |
172 | model_dict = {
173 | 'wrn_16_8': [wrn_16_8, 512],
174 | }
175 |
176 | class ConWRN(nn.Module):
177 | """backbone + projection head"""
178 | def __init__(self, name='wrn_16_8', head='mlp', feat_dim=128, selfcon_pos=[False,False], selfcon_arch='wrn', selfcon_size='same', dataset=''):
179 | super(ConWRN, self).__init__()
180 | model_fun, dim_in = model_dict[name]
181 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset)
182 | if head == 'linear':
183 | self.head = nn.Linear(dim_in, feat_dim)
184 |
185 | # register as a ModuleList so the sub-head parameters are tracked
186 | self.sub_heads = nn.ModuleList(
187 | [nn.Linear(dim_in, feat_dim) for pos in selfcon_pos if pos]
188 | )
189 | elif head == 'mlp':
190 | self.head = nn.Sequential(
191 | nn.Linear(dim_in, dim_in),
192 | nn.ReLU(inplace=True),
193 | nn.Linear(dim_in, feat_dim)
194 | )
195 |
196 | heads = []
197 | for pos in selfcon_pos:
198 | if pos:
199 | heads.append(nn.Sequential(
200 | nn.Linear(dim_in, dim_in),
201 | nn.ReLU(inplace=True),
202 | nn.Linear(dim_in, feat_dim)
203 | ))
204 | self.sub_heads = nn.ModuleList(heads)
205 | else:
206 | raise NotImplementedError(
207 | 'head not supported: {}'.format(head))
208 |
209 | def forward(self, x):
210 | sub_feat, feat = self.encoder(x)
211 |
212 | sh_feat = []
213 | for sf, sub_head in zip(sub_feat, self.sub_heads):
214 | sh_feat.append(F.normalize(sub_head(sf), dim=1))
215 |
216 | feat = F.normalize(self.head(feat), dim=1)
217 | return sh_feat, feat
218 |
219 |
220 | class CEWRN(nn.Module):
221 | """encoder + classifier"""
222 | def __init__(self, name='wrn_16_8', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False,False], selfcon_arch='wrn', selfcon_size='same', dataset=''):
223 | super(CEWRN, self).__init__()
224 | self.method = method
225 |
226 | model_fun, dim_in = model_dict[name]
227 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset)
228 |
229 | logit_fcs, feat_fcs = [], []
230 | for pos in selfcon_pos:
231 | if pos:
232 | logit_fcs.append(nn.Linear(dim_in, num_classes))
233 | feat_fcs.append(nn.Linear(dim_in, dim_out))
234 |
235 | self.logit_fc = nn.ModuleList(logit_fcs)
236 | self.l_fc = nn.Linear(dim_in, num_classes)
237 |
238 | if method not in ['ce', 'subnet_ce', 'kd']:
239 | self.feat_fc = nn.ModuleList(feat_fcs)
240 | self.f_fc = nn.Linear(dim_in, dim_out)
241 | for param in self.f_fc.parameters():
242 | param.requires_grad = False
243 |
244 | def forward(self, x):
245 | sub_feat, feat = self.encoder(x)
246 |
247 | feats, logits = [], []
248 |
249 | for idx, sh_feat in enumerate(sub_feat):
250 | logits.append(self.logit_fc[idx](sh_feat))
251 | if self.method not in ['ce', 'subnet_ce', 'kd']:
252 | out = self.feat_fc[idx](sh_feat)
253 | feats.append(F.normalize(out, dim=1))
254 |
255 | if self.method not in ['ce', 'subnet_ce', 'kd']:
256 | return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)]
257 | else:
258 | return [logits, self.l_fc(feat)]
259 |
260 |
261 | class LinearClassifier_WRN(nn.Module):
262 | """Linear classifier"""
263 | def __init__(self, name='wrn_16_8', num_classes=10):
264 | super(LinearClassifier_WRN, self).__init__()
265 | _, feat_dim = model_dict[name]
266 | self.fc = nn.Linear(feat_dim, num_classes)
267 |
268 | def forward(self, features):
269 | return self.fc(features)
270 |
271 |
272 | if __name__ == '__main__':
273 | net = Wide_ResNet(16, 8, 0.3)
274 | sub_out, y = net(torch.randn(1, 3, 32, 32))
275 |
276 | print(y.size())
277 |
--------------------------------------------------------------------------------
/networks/vgg_big.py:
--------------------------------------------------------------------------------
1 | '''
2 | VGG in PyTorch
3 | Adapted from: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from typing import Union, List, Dict, Any, cast
10 |
11 |
12 | class VGG(nn.Module):
13 |
14 | def __init__(
15 | self,
16 | features: nn.Module,
17 | cfg: str = 'D',
18 | arch: str = 'vgg16_bn',
19 | init_weights: bool = True,
20 | selfcon_pos: List[bool] = [False,False,False,False],
21 | selfcon_arch: str = 'vgg',
22 | selfcon_size: str = 'small',
23 | dataset: str = ''
24 | ) -> None:
25 | super(VGG, self).__init__()
26 | features_lst, modules_lst = [], []
27 | for module in features.modules():
28 | if isinstance(module, nn.Sequential):
29 | continue
30 | modules_lst.append(module)
31 | if isinstance(module, nn.MaxPool2d):
32 | features_lst.append(modules_lst)
33 | modules_lst = []
34 | self.block1 = nn.Sequential(*features_lst[0])
35 | self.block2 = nn.Sequential(*features_lst[1])
36 | self.block3 = nn.Sequential(*features_lst[2])
37 | self.block4 = nn.Sequential(*features_lst[3])
38 | self.block5 = nn.Sequential(*features_lst[4])
39 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
40 |
41 | self.arch = arch
42 | self.selfcon_pos = selfcon_pos
43 | self.selfcon_arch = selfcon_arch
44 | self.selfcon_size = selfcon_size
45 | self.dataset = dataset
46 | self.selfcon_layer = nn.ModuleList([self._make_sub_layer(idx, pos, cfg) for idx, pos in enumerate(selfcon_pos)])
47 |
48 | if init_weights:
49 | self._initialize_weights()
50 |
51 | def forward(self, x):
52 | sub_out = []
53 |
54 | x = self.block1(x)
55 | if self.selfcon_layer[0]:
56 | if self.selfcon_size != 'fc':
57 | out = self.selfcon_layer[0](x)
58 | out = torch.flatten(self.avgpool(out), 1)
59 | else:
60 | out = torch.flatten(self.avgpool(x), 1)
61 | out = self.selfcon_layer[0](out)
62 | sub_out.append(out)
63 |
64 | x = self.block2(x)
65 | if self.selfcon_layer[1]:
66 | if self.selfcon_size != 'fc':
67 | out = self.selfcon_layer[1](x)
68 | out = torch.flatten(self.avgpool(out), 1)
69 | else:
70 | out = torch.flatten(self.avgpool(x), 1)
71 | out = self.selfcon_layer[1](out)
72 | sub_out.append(out)
73 |
74 | x = self.block3(x)
75 | if self.selfcon_layer[2]:
76 | if self.selfcon_size != 'fc':
77 | out = self.selfcon_layer[2](x)
78 | out = torch.flatten(self.avgpool(out), 1)
79 | else:
80 | out = torch.flatten(self.avgpool(x), 1)
81 | out = self.selfcon_layer[2](out)
82 | sub_out.append(out)
83 |
84 | x = self.block4(x)
85 | if self.selfcon_layer[3]:
86 | if self.selfcon_size != 'fc':
87 | out = self.selfcon_layer[3](x)
88 | out = torch.flatten(self.avgpool(out), 1)
89 | else:
90 | out = torch.flatten(self.avgpool(x), 1)
91 | out = self.selfcon_layer[3](out)
92 | sub_out.append(out)
93 |
94 | x = self.block5(x)
95 | x = self.avgpool(x)
96 | x = torch.flatten(x, 1)
97 | return sub_out, x
98 |
99 | def _initialize_weights(self) -> None:
100 | for m in self.modules():
101 | if isinstance(m, nn.Conv2d):
102 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
103 | if m.bias is not None:
104 | nn.init.constant_(m.bias, 0)
105 | elif isinstance(m, nn.BatchNorm2d):
106 | nn.init.constant_(m.weight, 1)
107 | nn.init.constant_(m.bias, 0)
108 | elif isinstance(m, nn.Linear):
109 | nn.init.normal_(m.weight, 0, 0.01)
110 | nn.init.constant_(m.bias, 0)
111 |
112 | def _make_sub_layer(self, idx, pos, cfg):
113 | channels = [64, 128, 256, 512, 512]
114 |
115 | if not pos:
116 | return None
117 | else:
118 | if self.selfcon_arch == 'resnet':
119 | raise NotImplementedError
120 | elif self.selfcon_arch == 'vgg':
121 | if self.selfcon_size == 'fc':
122 | layers = [nn.Linear(channels[idx], channels[-1])]
123 | else:
124 | layers = []
125 | if self.selfcon_size == 'same':
126 | num_blocks = 3 if cfg == 'D' else 2
127 | elif self.selfcon_size == 'small':
128 | num_blocks = 1
129 | elif self.selfcon_size == 'large':
130 | raise NotImplementedError
131 |
132 | for i in range(idx+1, 5):
133 | in_planes = channels[i-1]
134 | v = channels[i]
135 | for b in range(num_blocks):
136 | if self.arch.endswith('_bn'):
137 | layers += [nn.Conv2d(in_planes, v, kernel_size=3, padding=1), nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
138 | else:
139 | layers += [nn.Conv2d(in_planes, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
140 | in_planes = v
141 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
142 | else:
143 | raise NotImplementedError
144 |
145 | return nn.Sequential(*layers)
146 |
147 |
148 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
149 | layers: List[nn.Module] = []
150 | in_channels = 3
151 | for v in cfg:
152 | if v == 'M':
153 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
154 | else:
155 | v = cast(int, v)
156 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
157 | if batch_norm:
158 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
159 | else:
160 | layers += [conv2d, nn.ReLU(inplace=True)]
161 | in_channels = v
162 | return nn.Sequential(*layers)
163 |
164 |
165 | cfgs: Dict[str, List[Union[str, int]]] = {
166 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
167 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
168 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
169 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
170 | }
171 |
172 |
173 | def _vgg(arch: str, cfg: str, batch_norm: bool, progress: bool, **kwargs: Any) -> VGG:
174 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), cfg=cfg, arch=arch, **kwargs)
175 | return model
176 |
177 | def vgg13(progress: bool = True, **kwargs: Any) -> VGG:
178 | return _vgg('vgg13', 'B', False, progress, **kwargs)
179 |
180 | def vgg13_bn(progress: bool = True, **kwargs: Any) -> VGG:
181 | return _vgg('vgg13_bn', 'B', True, progress, **kwargs)
182 |
183 | def vgg16(progress: bool = True, **kwargs: Any) -> VGG:
184 | return _vgg('vgg16', 'D', False, progress, **kwargs)
185 |
186 | def vgg16_bn(progress: bool = True, **kwargs: Any) -> VGG:
187 | return _vgg('vgg16_bn', 'D', True, progress, **kwargs)
188 |
189 | model_dict = {'vgg13': vgg13,
190 | 'vgg13_bn': vgg13_bn,
191 | 'vgg16': vgg16,
192 | 'vgg16_bn': vgg16_bn
193 | }
194 |
195 | class ConVGG(nn.Module):
196 | def __init__(self, name='vgg13_bn', head='mlp', feat_dim=128, selfcon_pos=[False,False,False,False], selfcon_arch='vgg', selfcon_size='same', dataset=''):
197 | super(ConVGG, self).__init__()
198 | model_fun = model_dict[name]
199 | dim_in = 512
200 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset)
201 | if head == 'linear':
202 | self.head = nn.Linear(dim_in, feat_dim)
203 |
204 | # register as a ModuleList so the sub-head parameters are tracked
205 | self.sub_heads = nn.ModuleList(
206 | [nn.Linear(dim_in, feat_dim) for pos in selfcon_pos if pos]
207 | )
208 | elif head == 'mlp':
209 | self.head = nn.Sequential(
210 | nn.Linear(dim_in, dim_in),
211 | nn.ReLU(inplace=True),
212 | nn.Linear(dim_in, feat_dim)
213 | )
214 |
215 | heads = []
216 | for pos in selfcon_pos:
217 | if pos:
218 | heads.append(nn.Sequential(
219 | nn.Linear(dim_in, dim_in),
220 | nn.ReLU(inplace=True),
221 | nn.Linear(dim_in, feat_dim)
222 | ))
223 | self.sub_heads = nn.ModuleList(heads)
224 | else:
225 | raise NotImplementedError(
226 | 'head not supported: {}'.format(head))
227 |
228 | def forward(self, x):
229 | sub_feat, feat = self.encoder(x)
230 |
231 | sh_feat = []
232 | for sf, sub_head in zip(sub_feat, self.sub_heads):
233 | sh_feat.append(F.normalize(sub_head(sf), dim=1))
234 |
235 | feat = F.normalize(self.head(feat), dim=1)
236 | return sh_feat, feat
237 |
238 |
239 | class CEVGG(nn.Module):
240 | """encoder + classifier"""
241 | def __init__(self, name='vgg13_bn', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False,False,False,False], selfcon_arch='vgg', selfcon_size='same', dataset=''):
242 | super(CEVGG, self).__init__()
243 | self.method = method
244 |
245 | model_fun = model_dict[name]
246 | dim_in = 512
247 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset)
248 |
249 | logit_fcs, feat_fcs = [], []
250 | for pos in selfcon_pos:
251 | if pos:
252 | logit_fcs.append(nn.Sequential(nn.Linear(dim_in, dim_in),
253 | nn.ReLU(inplace=True),
254 | nn.Dropout(),
255 | nn.Linear(dim_in, num_classes)
256 | ))
257 | feat_fcs.append(nn.Linear(dim_in, dim_out))
258 |
259 | self.logit_fc = nn.ModuleList(logit_fcs)
260 | self.l_fc = nn.Sequential(nn.Linear(dim_in, dim_in),
261 | nn.ReLU(inplace=True),
262 | nn.Dropout(),
263 | nn.Linear(dim_in, num_classes)
264 | )
265 |
266 | if method not in ['ce', 'subnet_ce', 'kd']:
267 | self.feat_fc = nn.ModuleList(feat_fcs)
268 | self.f_fc = nn.Linear(dim_in, dim_out)
269 | for param in self.f_fc.parameters():
270 | param.requires_grad = False
271 |
272 | def forward(self, x):
273 | sub_feat, feat = self.encoder(x)
274 |
275 | feats, logits = [], []
276 |
277 | for idx, sh_feat in enumerate(sub_feat):
278 | logits.append(self.logit_fc[idx](sh_feat))
279 | if self.method not in ['ce', 'subnet_ce', 'kd']:
280 | out = self.feat_fc[idx](sh_feat)
281 | feats.append(F.normalize(out, dim=1))
282 |
283 | if self.method not in ['ce', 'subnet_ce', 'kd']:
284 | return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)]
285 | else:
286 | return [logits, self.l_fc(feat)]
287 |
288 |
289 | class LinearClassifier_VGG(nn.Module):
290 | """Linear classifier"""
291 | def __init__(self, name='vgg13_bn', num_classes=10):
292 | super(LinearClassifier_VGG, self).__init__()
293 | feat_dim = 512
294 | self.fc1 = nn.Linear(feat_dim, feat_dim)
295 | self.relu = nn.ReLU(inplace=True)
296 | self.dropout = nn.Dropout()
297 | self.fc2 = nn.Linear(feat_dim, num_classes)
298 |
299 | def forward(self, features):
300 | features = self.dropout(self.relu(self.fc1(features)))
301 | return self.fc2(features)
302 |
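303 | # Usage sketch (illustrative): VGG-13-BN encoder with one sub-network branch.
304 | # model = ConVGG(name='vgg13_bn', selfcon_pos=[False,False,True,False], dataset='cifar100')
305 | # sub_feats, feat = model(torch.randn(2, 3, 32, 32))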
--------------------------------------------------------------------------------
/networks/resnet_big.py:
--------------------------------------------------------------------------------
1 | """ResNet in PyTorch.
2 | ImageNet-Style ResNet
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | Adapted from: https://github.com/bearpaw/pytorch-classification
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from .sub_network import *
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, in_planes, planes, stride=1, is_last=False):
18 | super(BasicBlock, self).__init__()
19 | self.is_last = is_last
20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 |
25 | self.shortcut = nn.Sequential()
26 | if stride != 1 or in_planes != self.expansion * planes:
27 | self.shortcut = nn.Sequential(
28 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
29 | nn.BatchNorm2d(self.expansion * planes)
30 | )
31 |
32 | def forward(self, x):
33 | out = F.relu(self.bn1(self.conv1(x)))
34 | out = self.bn2(self.conv2(out))
35 | out += self.shortcut(x)
36 | preact = out
37 | out = F.relu(out)
38 | if self.is_last:
39 | return out, preact
40 | else:
41 | return out
42 |
43 |
44 | class Bottleneck(nn.Module):
45 | expansion = 4
46 |
47 | def __init__(self, in_planes, planes, stride=1, is_last=False):
48 | super(Bottleneck, self).__init__()
49 | self.is_last = is_last
50 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
53 | self.bn2 = nn.BatchNorm2d(planes)
54 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
55 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
56 |
57 | self.shortcut = nn.Sequential()
58 | if stride != 1 or in_planes != self.expansion * planes:
59 | self.shortcut = nn.Sequential(
60 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
61 | nn.BatchNorm2d(self.expansion * planes)
62 | )
63 |
64 | def forward(self, x):
65 | out = F.relu(self.bn1(self.conv1(x)))
66 | out = F.relu(self.bn2(self.conv2(out)))
67 | out = self.bn3(self.conv3(out))
68 | out += self.shortcut(x)
69 | preact = out
70 | out = F.relu(out)
71 | if self.is_last:
72 | return out, preact
73 | else:
74 | return out
75 |
76 |
77 | class ResNet(nn.Module):
78 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False, selfcon_pos=[False,False,False], selfcon_arch='resnet', selfcon_size='same', dataset=''):
79 | super(ResNet, self).__init__()
80 | self.in_planes = 64
81 | self.block = block
82 | self.num_blocks = num_blocks
83 | self.in_channel = in_channel
84 | self.dataset = dataset
85 |
86 |         self.large = dataset not in ['cifar10', 'cifar100', 'tinyimagenet']
87 | if not self.large:
88 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
89 | else:
90 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
91 |
92 | self.bn1 = nn.BatchNorm2d(64)
93 | if self.large:
94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
95 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
96 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
98 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
99 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
100 |
101 | self.selfcon_pos = selfcon_pos
102 | self.selfcon_arch = selfcon_arch
103 | self.selfcon_size = selfcon_size
104 | self.selfcon_layer = nn.ModuleList([self._make_sub_layer(idx, pos) for idx, pos in enumerate(selfcon_pos)])
105 |
106 | for k, m in self.named_modules():
107 | if isinstance(m, nn.Conv2d):
108 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
109 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
110 | nn.init.constant_(m.weight, 1)
111 | nn.init.constant_(m.bias, 0)
112 |
113 | # Zero-initialize the last BN in each residual branch,
114 | # so that the residual branch starts with zeros, and each residual block behaves
115 | # like an identity. This improves the model by 0.2~0.3% according to:
116 | # https://arxiv.org/abs/1706.02677
117 | if zero_init_residual:
118 | for m in self.modules():
119 | if isinstance(m, Bottleneck):
120 | nn.init.constant_(m.bn3.weight, 0)
121 | elif isinstance(m, BasicBlock):
122 | nn.init.constant_(m.bn2.weight, 0)
123 |
124 | def _make_layer(self, block, planes, num_blocks, stride):
125 | strides = [stride] + [1] * (num_blocks - 1)
126 | layers = []
127 | for i in range(num_blocks):
128 | stride = strides[i]
129 | layers.append(block(self.in_planes, planes, stride))
130 | self.in_planes = planes * block.expansion
131 | return nn.Sequential(*layers)
132 |
133 | def _make_sub_layer(self, idx, pos):
134 | channels = [64, 128, 256, 512]
135 | strides = [1, 2, 2, 2]
136 | if self.selfcon_size == 'same':
137 | num_blocks = self.num_blocks
138 | elif self.selfcon_size == 'small':
139 | num_blocks = [int(n/2) for n in self.num_blocks]
140 | elif self.selfcon_size == 'large':
141 | num_blocks = [int(n*2) for n in self.num_blocks]
142 | elif self.selfcon_size == 'fc':
143 | pass
144 | else:
145 |             raise NotImplementedError('selfcon_size not supported: {}'.format(self.selfcon_size))
146 |
147 | if not pos:
148 | return None
149 | else:
150 | if self.selfcon_size == 'fc':
151 | return nn.Linear(channels[idx] * self.block.expansion, channels[-1] * self.block.expansion)
152 | else:
153 | if self.selfcon_arch == 'resnet':
154 |                     # selfcon layers do not share any parameters
155 | layers = []
156 | for i in range(idx+1, 4):
157 | in_planes = channels[i-1] * self.block.expansion
158 | layers.append(resnet_sub_layer(self.block, in_planes, channels[i], num_blocks[i], strides[i]))
159 |                 elif self.selfcon_arch == 'vgg':
160 |                     raise NotImplementedError
161 |                 elif self.selfcon_arch == 'efficientnet':
162 |                     raise NotImplementedError
163 |
164 | return nn.Sequential(*layers)
165 |
166 | def forward(self, x):
167 | sub_out = []
168 |
169 | x = F.relu(self.bn1(self.conv1(x)))
170 | if self.large:
171 | x = self.maxpool(x)
172 |
173 | x = self.layer1(x)
174 | if self.selfcon_layer[0]:
175 | if self.selfcon_size != 'fc':
176 | out = self.selfcon_layer[0](x)
177 | out = torch.flatten(self.avgpool(out), 1)
178 | else:
179 | out = torch.flatten(self.avgpool(x), 1)
180 | out = self.selfcon_layer[0](out)
181 | sub_out.append(out)
182 |
183 | x = self.layer2(x)
184 | if self.selfcon_layer[1]:
185 | if self.selfcon_size != 'fc':
186 | out = self.selfcon_layer[1](x)
187 | out = torch.flatten(self.avgpool(out), 1)
188 | else:
189 | out = torch.flatten(self.avgpool(x), 1)
190 | out = self.selfcon_layer[1](out)
191 | sub_out.append(out)
192 |
193 | x = self.layer3(x)
194 | if self.selfcon_layer[2]:
195 | if self.selfcon_size != 'fc':
196 | out = self.selfcon_layer[2](x)
197 | out = torch.flatten(self.avgpool(out), 1)
198 | else:
199 | out = torch.flatten(self.avgpool(x), 1)
200 | out = self.selfcon_layer[2](out)
201 | sub_out.append(out)
202 |
203 | out = self.layer4(x)
204 | out = self.avgpool(out)
205 | out = torch.flatten(out, 1)
206 |
207 | return sub_out, out
208 |
209 |
210 | def resnet18(**kwargs):
211 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
212 |
213 |
214 | def resnet34(**kwargs):
215 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
216 |
217 |
218 | def resnet50(**kwargs):
219 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
220 |
221 |
222 | def resnet101(**kwargs):
223 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
224 |
225 |
226 | model_dict = {
227 | 'resnet18': [resnet18, 512],
228 | 'resnet34': [resnet34, 512],
229 | 'resnet50': [resnet50, 2048],
230 | 'resnet101': [resnet101, 2048],
231 | }
232 |
233 |
234 | class LinearBatchNorm(nn.Module):
235 | """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""
236 | def __init__(self, dim, affine=True):
237 | super(LinearBatchNorm, self).__init__()
238 | self.dim = dim
239 | self.bn = nn.BatchNorm2d(dim, affine=affine)
240 |
241 | def forward(self, x):
242 | x = x.view(-1, self.dim, 1, 1)
243 | x = self.bn(x)
244 | x = x.view(-1, self.dim)
245 | return x
246 |
247 |
248 | class ConResNet(nn.Module):
249 | """backbone + projection head"""
250 | def __init__(self, name='resnet50', head='mlp', feat_dim=128, selfcon_pos=[False,False,False], selfcon_arch='resnet', selfcon_size='same', dataset=''):
251 | super(ConResNet, self).__init__()
252 | model_fun, dim_in = model_dict[name]
253 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset)
254 | if head == 'linear':
255 | self.head = nn.Linear(dim_in, feat_dim)
256 |
257 |             # wrap the sub-heads in a ModuleList so their parameters are registered
258 |             self.sub_heads = nn.ModuleList(
259 |                 [nn.Linear(dim_in, feat_dim) for pos in selfcon_pos if pos]
260 |             )
261 | elif head == 'mlp':
262 | self.head = nn.Sequential(
263 | nn.Linear(dim_in, dim_in),
264 | nn.ReLU(inplace=True),
265 | nn.Linear(dim_in, feat_dim)
266 | )
267 |
268 | heads = []
269 | for pos in selfcon_pos:
270 | if pos:
271 | heads.append(nn.Sequential(
272 | nn.Linear(dim_in, dim_in),
273 | nn.ReLU(inplace=True),
274 | nn.Linear(dim_in, feat_dim)
275 | ))
276 | self.sub_heads = nn.ModuleList(heads)
277 | else:
278 | raise NotImplementedError(
279 | 'head not supported: {}'.format(head))
280 |
281 | def forward(self, x):
282 | sub_feat, feat = self.encoder(x)
283 |
284 | sh_feat = []
285 | for sf, sub_head in zip(sub_feat, self.sub_heads):
286 | sh_feat.append(F.normalize(sub_head(sf), dim=1))
287 |
288 | feat = F.normalize(self.head(feat), dim=1)
289 | return sh_feat, feat
290 |
291 |
292 | class CEResNet(nn.Module):
293 | """encoder + classifier"""
294 | def __init__(self, name='resnet50', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False,False,False], selfcon_arch='resnet', selfcon_size='same', dataset=''):
295 | super(CEResNet, self).__init__()
296 | self.method = method
297 |
298 | model_fun, dim_in = model_dict[name]
299 | self.encoder = model_fun(selfcon_pos=selfcon_pos, selfcon_arch=selfcon_arch, selfcon_size=selfcon_size, dataset=dataset)
300 |
301 | logit_fcs, feat_fcs = [], []
302 | for pos in selfcon_pos:
303 | if pos:
304 | logit_fcs.append(nn.Linear(dim_in, num_classes))
305 | feat_fcs.append(nn.Linear(dim_in, dim_out))
306 |
307 | self.logit_fc = nn.ModuleList(logit_fcs)
308 | self.l_fc = nn.Linear(dim_in, num_classes)
309 |
310 | if method not in ['ce', 'subnet_ce', 'kd']:
311 | self.feat_fc = nn.ModuleList(feat_fcs)
312 | self.f_fc = nn.Linear(dim_in, dim_out)
313 | for param in self.f_fc.parameters():
314 | param.requires_grad = False
315 |
316 | def forward(self, x):
317 | sub_feat, feat = self.encoder(x)
318 |
319 | feats, logits = [], []
320 |
321 | for idx, sh_feat in enumerate(sub_feat):
322 | logits.append(self.logit_fc[idx](sh_feat))
323 | if self.method not in ['ce', 'subnet_ce', 'kd']:
324 | out = self.feat_fc[idx](sh_feat)
325 | feats.append(F.normalize(out, dim=1))
326 |
327 | if self.method not in ['ce', 'subnet_ce', 'kd']:
328 | return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)]
329 | else:
330 | return [logits, self.l_fc(feat)]
331 |
332 |
333 | class LinearClassifier(nn.Module):
334 | """Linear classifier"""
335 | def __init__(self, name='resnet50', num_classes=10):
336 | super(LinearClassifier, self).__init__()
337 | _, feat_dim = model_dict[name]
338 | self.fc = nn.Linear(feat_dim, num_classes)
339 |
340 | def forward(self, features):
341 | return self.fc(features)
342 |
--------------------------------------------------------------------------------
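
A similar sanity sketch for ConResNet, written for this note rather than taken from the repository; it assumes it is run from the repository root. It enables a single sub-network after layer2 of a CIFAR-style ResNet-18 and checks that the backbone and sub-network projections come out L2-normalized at feat_dim=128.

    import torch
    from networks.resnet_big import ConResNet

    model = ConResNet(name='resnet18', head='mlp',
                      selfcon_pos=[False, True, False], dataset='cifar100')
    x = torch.randn(4, 3, 32, 32)
    sub_feats, feat = model(x)
    assert feat.shape == (4, 128)          # backbone projection
    assert len(sub_feats) == 1             # one enabled position
    assert sub_feats[0].shape == (4, 128)  # sub-head projection
    assert torch.allclose(feat.norm(dim=1), torch.ones(4), atol=1e-5)
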
/main_represent.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import time
7 | import math
8 | import copy
9 | import random
10 | import builtins
11 | import numpy as np
12 |
13 | import torch
14 | import torch.backends.cudnn as cudnn
15 | from torchvision import transforms, datasets
16 | from RandAugment import RandAugment
17 |
18 | from losses import ConLoss
19 | from utils.util import *
20 | from utils.tinyimagenet import TinyImageNet
21 | from utils.imagenet import ImageNetSubset
22 | from networks.resnet_big import ConResNet
23 | from networks.vgg_big import ConVGG
24 | from networks.wrn_big import ConWRN
25 | from networks.efficient_big import ConEfficientNet
26 |
27 |
28 | def parse_option():
29 | parser = argparse.ArgumentParser('argument for training')
30 |
31 | parser.add_argument('--exp_name', type=str, default='')
32 | parser.add_argument('--seed', type=int, default=0)
33 | parser.add_argument('--print_freq', type=int, default=10)
34 | parser.add_argument('--save_freq', type=int, default=0)
35 | parser.add_argument('--save_dir', type=str, default='./save/representation')
36 | parser.add_argument('--resume', help='path of model checkpoint to resume', type=str,
37 | default='')
38 |
39 | # dataset
40 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet', 'imagenet', 'imagenet100'])
41 | parser.add_argument('--data_folder', type=str, default='datasets/')
42 | parser.add_argument('--batch_size', type=int, default=256)
43 | parser.add_argument('--num_workers', type=int, default=16)
44 |
45 | # model
46 | parser.add_argument('--model', type=str, default='resnet50')
47 | parser.add_argument('--selfcon_pos', type=str, default='[False,False,False]',
48 | help='where to augment the paths')
49 | parser.add_argument('--selfcon_arch', type=str, default='resnet',
50 | choices=['resnet', 'vgg', 'efficientnet', 'wrn'], help='which architecture to form a sub-network')
51 | parser.add_argument('--selfcon_size', type=str, default='same',
52 | choices=['fc', 'same', 'small'], help='argument for num_blocks of a sub-network')
53 | parser.add_argument('--feat_dim', type=int, default=128,
54 | help='feature dimension for mlp')
55 |
56 | # optimization
57 | parser.add_argument('--epochs', type=int, default=1000)
58 | parser.add_argument('--learning_rate', type=float, default=0.05)
59 | parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900')
60 | parser.add_argument('--lr_decay_rate', type=float, default=0.1)
61 | parser.add_argument('--weight_decay', type=float, default=1e-4)
62 | parser.add_argument('--momentum', type=float, default=0.9)
63 | parser.add_argument('--precision', action='store_true',
64 | help='whether to use 16 bit precision or not')
65 | parser.add_argument('--cosine', action='store_true',
66 | help='using cosine annealing')
67 | parser.add_argument('--warm', action='store_true',
68 | help='warm-up for large batch training')
69 | parser.add_argument('--temp', type=float, default=0.07,
70 | help='temperature for loss function')
71 |
72 | # important arguments
73 | parser.add_argument('--method', type=str,
74 | choices=['Con', 'SupCon', 'SelfCon'], help='choose method')
75 | parser.add_argument('--multiview', action='store_true',
76 | help='use multiview batch or not')
77 | parser.add_argument('--label', action='store_false',
78 | help='whether to use label information or not')
79 | parser.add_argument('--alpha', type=float, default=0.0,
80 | help='weight for selfcon with multiview loss function')
81 |
82 | # other arguments
83 | parser.add_argument('--randaug', action='store_true',
84 | help='whether to add randaugment or not')
85 | parser.add_argument('--weakaug', action='store_true',
86 | help='whether to use weak augmentation or not')
87 |
88 | opt = parser.parse_args()
89 |
90 | if opt.model.startswith('vgg'):
91 | if opt.selfcon_pos == '[False,False,False]':
92 | opt.selfcon_pos = '[False,False,False,False]'
93 | opt.selfcon_arch = 'vgg'
94 | elif opt.model.startswith('wrn'):
95 | if opt.selfcon_pos == '[False,False,False]':
96 | opt.selfcon_pos = '[False,False]'
97 | opt.selfcon_arch = 'wrn'
98 |
99 | # set the path according to the environment
100 | opt.model_path = '%s/%s/%s_models' % (opt.save_dir, opt.method, opt.dataset)
101 |
102 | if opt.dataset == 'cifar10':
103 | opt.n_cls = 10
104 | elif opt.dataset == 'cifar100':
105 | opt.n_cls = 100
106 | elif opt.dataset == 'tinyimagenet':
107 | opt.n_cls = 200
108 | elif opt.dataset == 'imagenet':
109 | opt.n_cls = 1000
110 | elif opt.dataset == 'imagenet100':
111 | opt.n_cls = 100
112 | else:
113 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
114 |
115 | iterations = opt.lr_decay_epochs.split(',')
116 | opt.lr_decay_epochs = list([])
117 | for it in iterations:
118 | opt.lr_decay_epochs.append(int(it))
119 | opt.model_name = '{}_{}_{}_lr_{}_multiview_{}_label_{}_decay_{}_bsz_{}_temp_{}_seed_{}'.\
120 | format(opt.method, opt.dataset, opt.model, opt.learning_rate,
121 | opt.multiview, opt.label, opt.weight_decay, opt.batch_size,
122 | opt.temp, opt.seed)
123 |
124 |     # warm-up for large-batch training
125 | if opt.batch_size >= 1024:
126 | opt.warm = True
127 | if opt.warm:
128 | opt.warmup_from = 0.01
129 | opt.warm_epochs = 10
130 | if opt.cosine:
131 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
132 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
133 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
134 | else:
135 | opt.warmup_to = opt.learning_rate
136 |
137 | if opt.cosine:
138 | opt.model_name = '{}_cosine'.format(opt.model_name)
139 | if opt.warm:
140 | opt.model_name = '{}_warm'.format(opt.model_name)
141 | if opt.exp_name:
142 | opt.model_name = '{}_{}'.format(opt.model_name, opt.exp_name)
143 |
144 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
145 | if not os.path.isdir(opt.save_folder):
146 | os.makedirs(opt.save_folder)
147 |
148 | return opt
149 |
150 |
151 | def set_loader(opt):
152 | # construct data loader
153 | if opt.dataset == 'cifar10':
154 | mean = (0.4914, 0.4822, 0.4465)
155 | std = (0.2023, 0.1994, 0.2010)
156 | size = 32
157 | elif opt.dataset == 'cifar100':
158 | mean = (0.5071, 0.4867, 0.4408)
159 | std = (0.2675, 0.2565, 0.2761)
160 | size = 32
161 | elif opt.dataset == 'tinyimagenet':
162 | mean = (0.485, 0.456, 0.406)
163 | std = (0.229, 0.224, 0.225)
164 | size = 64
165 | elif opt.dataset == 'imagenet' or opt.dataset == 'imagenet100':
166 | mean = (0.485, 0.456, 0.406)
167 | std = (0.229, 0.224, 0.225)
168 | size = 224
169 | else:
170 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
171 |
172 | normalize = transforms.Normalize(mean=mean, std=std)
173 | transform = [transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)), transforms.RandomHorizontalFlip()]
174 | if not opt.weakaug:
175 | transform += [transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
176 | transforms.RandomGrayscale(p=0.2)]
177 |
178 | transform += [transforms.ToTensor(), normalize]
179 | train_transform = transforms.Compose(transform)
180 |
181 | if opt.randaug:
182 | train_transform.transforms.insert(0, RandAugment(2, 9))
183 | if opt.multiview:
184 | train_transform = TwoCropTransform(train_transform)
185 |
186 | if opt.dataset == 'cifar10':
187 | train_dataset = datasets.CIFAR10(root=opt.data_folder,
188 | transform=train_transform,
189 | download=True)
190 | elif opt.dataset == 'cifar100':
191 | train_dataset = datasets.CIFAR100(root=opt.data_folder,
192 | transform=train_transform,
193 | download=True)
194 | elif opt.dataset == 'tinyimagenet':
195 | train_dataset = TinyImageNet(root=opt.data_folder,
196 | transform=train_transform,
197 | download=True)
198 | elif opt.dataset == 'imagenet':
199 | traindir = os.path.join(opt.data_folder, 'train')
200 | train_dataset = datasets.ImageFolder(root=traindir,
201 | transform=train_transform)
202 | elif opt.dataset == 'imagenet100':
203 | traindir = os.path.join(opt.data_folder, 'train')
204 | train_dataset = ImageNetSubset('./utils/imagenet100.txt',
205 | root=traindir,
206 | transform=train_transform)
207 | else:
208 | raise ValueError(opt.dataset)
209 |
210 | train_loader = torch.utils.data.DataLoader(
211 | train_dataset, batch_size=opt.batch_size, shuffle=True,
212 | num_workers=opt.num_workers, pin_memory=True, sampler=None)
213 |
214 | return train_loader
215 |
216 |
217 | def set_model(opt):
218 | model_kwargs = {'name': opt.model,
219 | 'dataset': opt.dataset,
220 | 'selfcon_pos': eval(opt.selfcon_pos),
221 | 'selfcon_arch': opt.selfcon_arch,
222 | 'selfcon_size': opt.selfcon_size
223 | }
224 | if opt.model.startswith('resnet'):
225 | model = ConResNet(**model_kwargs)
226 | elif opt.model.startswith('vgg'):
227 | model = ConVGG(**model_kwargs)
228 | elif opt.model.startswith('wrn'):
229 | model = ConWRN(**model_kwargs)
230 | elif opt.model.startswith('eff'):
231 | model = ConEfficientNet(**model_kwargs)
232 |
233 | criterion = ConLoss(temperature=opt.temp)
234 |
235 | if torch.cuda.is_available():
236 | if torch.cuda.device_count() > 1:
237 | model.encoder = torch.nn.DataParallel(model.encoder)
238 | model = model.cuda()
239 | criterion = criterion.cuda()
240 | cudnn.benchmark = True
241 |
242 | return model, criterion, opt
243 |
244 |
245 | def _train(images, labels, model, criterion, epoch, bsz, opt):
246 | # compute loss
247 | features = model(images)
248 | if opt.method == 'Con':
249 | f1, f2 = torch.split(features[1], [bsz, bsz], dim=0)
250 | elif opt.method == 'SupCon':
251 | if opt.multiview:
252 | f1, f2 = torch.split(features[1], [bsz, bsz], dim=0)
253 | else: # opt.method == 'SelfCon'
254 | f1, f2 = features
255 |
256 | if opt.method == 'SupCon':
257 | # SupCon
258 | if opt.multiview:
259 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
260 | loss = criterion(features, labels)
261 | # SupCon-S
262 | else:
263 | features = features[1].unsqueeze(1)
264 | loss = criterion(features, labels, supcon_s=True)
265 |
266 | elif opt.method == 'Con':
267 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
268 | loss = criterion(features)
269 | elif opt.method == 'SelfCon':
270 | loss = torch.tensor([0.0]).cuda()
271 | # SelfCon
272 | if not opt.multiview:
273 | if not opt.alpha:
274 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1)
275 | # SelfCon-SU
276 | if not opt.label:
277 | loss += criterion(features)
278 | # SelfCon
279 | else:
280 | loss += criterion(features, labels)
281 | else:
282 | features = f2.unsqueeze(1)
283 | if opt.label:
284 | loss += criterion(features, labels, supcon_s=True)
285 |
286 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1)
287 | # SelfCon-SU*
288 | if not opt.label:
289 | loss += opt.alpha * criterion(features, selfcon_s_FG=True)
290 | # SelfCon-S*
291 | else:
292 | loss += opt.alpha * criterion(features, labels, selfcon_s_FG=True)
293 | # SelfCon-M
294 | else:
295 | if not opt.alpha:
296 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1)
297 | labels_repeat = torch.cat([labels, labels], dim=0)
298 | # SelfCon-MU
299 | if not opt.label:
300 | loss += criterion(features)
301 | # SelfCon-M
302 | else:
303 | loss += criterion(features, labels_repeat)
304 | else:
305 | f2_1, f2_2 = torch.split(f2, [bsz, bsz], dim=0)
306 | features = torch.cat([f2_1.unsqueeze(1), f2_2.unsqueeze(1)], dim=1)
307 | # contrastive loss between F (backbone)
308 | if not opt.label:
309 | loss += criterion(features)
310 | else:
311 | loss += criterion(features, labels)
312 |
313 | features = torch.cat([f.unsqueeze(1) for f in f1] + [f2.unsqueeze(1)], dim=1)
314 | # SelfCon-MU*
315 | if not opt.label:
316 | loss += opt.alpha * criterion(features, selfcon_m_FG=True)
317 | # SelfCon-M*
318 | else:
319 | loss += opt.alpha * criterion(features, labels, selfcon_m_FG=True)
320 | else:
321 | raise ValueError('contrastive method not supported: {}'.
322 | format(opt.method))
323 |
324 | return loss
325 |
326 |
327 | def train(train_loader, model, criterion, optimizer, epoch, opt):
328 | """one epoch training"""
329 | model.train()
330 | if opt.precision:
331 | scaler = torch.cuda.amp.GradScaler()
332 |
333 | batch_time = AverageMeter()
334 | data_time = AverageMeter()
335 | losses = AverageMeter()
336 |
337 | end = time.time()
338 | for idx, (images, labels) in enumerate(train_loader):
339 | data_time.update(time.time() - end)
340 |
341 | bsz = labels.shape[0]
342 |
343 | if opt.multiview:
344 | images = torch.cat([images[0], images[1]], dim=0)
345 | if torch.cuda.is_available():
346 | images = images.cuda(non_blocking=True)
347 | labels = labels.cuda(non_blocking=True)
348 |
349 | # warm-up learning rate
350 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
351 |
352 | if opt.precision:
353 | with torch.cuda.amp.autocast():
354 | loss = _train(images, labels, model, criterion, epoch, bsz, opt)
355 | else:
356 | loss = _train(images, labels, model, criterion, epoch, bsz, opt)
357 |
358 | # update metric
359 | losses.update(loss.item(), bsz)
360 |
361 | # SGD
362 | optimizer.zero_grad()
363 | if not opt.precision:
364 | loss.backward()
365 | optimizer.step()
366 | else:
367 | scaler.scale(loss).backward()
368 | scaler.step(optimizer)
369 | scaler.update()
370 |
371 | # measure elapsed time
372 | batch_time.update(time.time() - end)
373 | end = time.time()
374 |
375 | # print info
376 | if (idx + 1) % opt.print_freq == 0:
377 | print('Train: [{0}][{1}/{2}]\t'
378 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
379 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
380 | 'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
381 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
382 | data_time=data_time, loss=losses))
383 | sys.stdout.flush()
384 |
385 | return losses.avg
386 |
387 |
388 | def main():
389 | opt = parse_option()
390 |
391 | np.random.seed(opt.seed)
392 | random.seed(opt.seed)
393 | torch.manual_seed(opt.seed)
394 | torch.cuda.manual_seed(opt.seed)
395 | # cudnn.deterministic = True
396 |
397 | # build model and criterion
398 | model, criterion, opt = set_model(opt)
399 |
400 | # build data loader
401 | train_loader = set_loader(opt)
402 |
403 | # build optimizer
404 | optimizer = set_optimizer(opt, model)
405 |
406 | if opt.resume:
407 | if os.path.isfile(opt.resume):
408 | print("=> loading checkpoint '{}'".format(opt.resume))
409 | checkpoint = torch.load(opt.resume)
410 | opt.start_epoch = checkpoint['epoch']
411 | model.load_state_dict(checkpoint['model'])
412 | optimizer.load_state_dict(checkpoint['optimizer'])
413 | opt.start_epoch += 1
414 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch']))
415 | else:
416 | print("=> no checkpoint found at '{}'".format(opt.resume))
417 | else:
418 | opt.start_epoch = 1
419 |
420 | # training routine
421 | for epoch in range(opt.start_epoch, opt.epochs + 1):
422 | adjust_learning_rate(opt, optimizer, epoch)
423 |
424 | # train for one epoch
425 | time1 = time.time()
426 | loss = train(train_loader, model, criterion, optimizer, epoch, opt)
427 | time2 = time.time()
428 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
429 |
430 | if opt.save_freq:
431 | if epoch % opt.save_freq == 0:
432 | save_file = os.path.join(
433 | opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
434 | save_model(model, optimizer, opt, epoch, save_file)
435 |
436 | # save the last model
437 | save_file = os.path.join(
438 | opt.save_folder, 'last.pth')
439 | save_model(model, optimizer, opt, epoch, save_file)
440 |
441 |
442 | if __name__ == '__main__':
443 | main()
444 |
--------------------------------------------------------------------------------
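
The branching inside _train() above selects among the SelfCon loss variants via three flags: --multiview, --label, and --alpha. The helper below is an illustrative summary written for this note (it does not exist in the repository); it mirrors that branch structure and names the variant each flag combination produces.

    def selfcon_variant(multiview: bool, label: bool, alpha: float) -> str:
        """Name the SelfCon loss variant _train() computes for these flags."""
        if not multiview:
            if not alpha:
                return 'SelfCon' if label else 'SelfCon-SU'
            return 'SelfCon-S*' if label else 'SelfCon-SU*'
        if not alpha:
            return 'SelfCon-M' if label else 'SelfCon-MU'
        return 'SelfCon-M*' if label else 'SelfCon-MU*'

    assert selfcon_variant(multiview=False, label=True, alpha=0.0) == 'SelfCon'
    assert selfcon_variant(multiview=True, label=False, alpha=1.0) == 'SelfCon-MU*'
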
/main_linear.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import warnings
7 | import time
8 | import math
9 | import random
10 | import builtins
11 | import numpy as np
12 |
13 | import torch
14 | import torch.backends.cudnn as cudnn
15 | import torch.multiprocessing as mp
16 | import torch.distributed as dist
17 | from torchvision import transforms, datasets
18 |
19 | from utils.util import *
20 | from utils.tinyimagenet import TinyImageNet
21 | from utils.imagenet import ImageNetSubset
22 | from networks.resnet_big import ConResNet, LinearClassifier
23 | from networks.vgg_big import ConVGG, LinearClassifier_VGG
24 | from networks.wrn_big import ConWRN, LinearClassifier_WRN
25 | from networks.efficient_big import ConEfficientNet, LinearClassifier_EFF
26 |
27 |
28 | def parse_option():
29 | parser = argparse.ArgumentParser('argument for training')
30 |
31 | parser.add_argument('--exp_name', type=str, default='')
32 | parser.add_argument('--seed', type=int, default=0)
33 | parser.add_argument('--print_freq', type=int, default=10)
34 | parser.add_argument('--save_dir', type=str, default='./save/representation')
35 | parser.add_argument('--ckpt', type=str, default='',
36 | help='path for pre-trained model')
37 | parser.add_argument('--subnet', action='store_true',
38 | help='measure the accuracy of sub-network or not')
39 |
40 | # dataset
41 | parser.add_argument('--dataset', type=str, default='imagenet', choices=['cifar10', 'cifar100', 'tinyimagenet', 'imagenet', 'imagenet100'])
42 | parser.add_argument('--data_folder', type=str, default='datasets/')
43 | parser.add_argument('--batch_size', type=int, default=256)
44 | parser.add_argument('--num_workers', type=int, default=16)
45 |
46 | # model
47 | parser.add_argument('--model', type=str, default='resnet50')
48 | parser.add_argument('--selfcon_pos', type=str, default='[False,False,False]',
49 | help='where to augment the paths')
50 | parser.add_argument('--selfcon_arch', type=str, default='resnet',
51 | choices=['resnet', 'vgg', 'efficientnet', 'wrn'], help='which architecture to form a sub-network')
52 | parser.add_argument('--selfcon_size', type=str, default='same',
53 | choices=['fc', 'same', 'small'], help='argument for num_blocks of a sub-network')
54 |
55 | # optimization
56 | parser.add_argument('--epochs', type=int, default=100)
57 | parser.add_argument('--learning_rate', type=float, default=0.1)
58 | parser.add_argument('--lr_decay_epochs', type=str, default='60,75,90')
59 | parser.add_argument('--lr_decay_rate', type=float, default=0.2)
60 | parser.add_argument('--weight_decay', type=float, default=0)
61 | parser.add_argument('--momentum', type=float, default=0.9)
62 | parser.add_argument('--cosine', action='store_true',
63 | help='using cosine annealing')
64 | parser.add_argument('--warm', action='store_true',
65 | help='warm-up for large batch training')
66 |
67 | opt = parser.parse_args()
68 |
69 | iterations = opt.lr_decay_epochs.split(',')
70 | opt.lr_decay_epochs = list([])
71 | for it in iterations:
72 | opt.lr_decay_epochs.append(int(it))
73 |
74 | opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}'.\
75 | format(opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
76 | opt.batch_size)
77 |
78 | if opt.cosine:
79 | opt.model_name = '{}_cosine'.format(opt.model_name)
80 |
81 |     # warm-up for large-batch training
82 | if opt.warm:
83 | opt.model_name = '{}_warm'.format(opt.model_name)
84 | opt.warmup_from = 0.01
85 | opt.warm_epochs = 10
86 | if opt.cosine:
87 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
88 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
89 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
90 | else:
91 | opt.warmup_to = opt.learning_rate
92 |
93 | if opt.dataset == 'cifar10':
94 | opt.n_cls = 10
95 | elif opt.dataset == 'cifar100':
96 | opt.n_cls = 100
97 | elif opt.dataset == 'tinyimagenet':
98 | opt.n_cls = 200
99 | elif opt.dataset == 'imagenet':
100 | opt.n_cls = 1000
101 | elif opt.dataset == 'imagenet100':
102 | opt.n_cls = 100
103 | else:
104 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
105 |
106 | return opt
107 |
108 |
109 | def set_loader(opt):
110 | # construct data loader
111 | if opt.dataset == 'cifar10':
112 | mean = (0.4914, 0.4822, 0.4465)
113 | std = (0.2023, 0.1994, 0.2010)
114 | size = 32
115 | elif opt.dataset == 'cifar100':
116 | mean = (0.5071, 0.4867, 0.4408)
117 | std = (0.2675, 0.2565, 0.2761)
118 | size = 32
119 | elif opt.dataset == 'tinyimagenet':
120 | mean = (0.485, 0.456, 0.406)
121 | std = (0.229, 0.224, 0.225)
122 | size = 64
123 | elif opt.dataset == 'imagenet':
124 | mean = (0.485, 0.456, 0.406)
125 | std = (0.229, 0.224, 0.225)
126 | size = 224
127 | elif opt.dataset == 'imagenet100':
128 | mean = (0.485, 0.456, 0.406)
129 | std = (0.229, 0.224, 0.225)
130 | size = 224
131 | else:
132 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
133 | normalize = transforms.Normalize(mean=mean, std=std)
134 |
135 | train_transform = transforms.Compose([
136 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
137 | transforms.RandomHorizontalFlip(),
138 | transforms.ToTensor(),
139 | normalize,
140 | ])
141 |
142 | # if opt.randaug:
143 | # train_transform.transforms.insert(0, RandAugment(2, 9))
144 |
145 | if opt.dataset not in ['imagenet', 'imagenet100']:
146 | val_transform = transforms.Compose([
147 | transforms.ToTensor(),
148 | normalize,
149 | ])
150 | else:
151 | val_transform = transforms.Compose([transforms.Resize(256),
152 | transforms.CenterCrop(224),
153 | transforms.ToTensor(),
154 | normalize])
155 |
156 | if opt.dataset == 'cifar10':
157 | train_dataset = datasets.CIFAR10(root=opt.data_folder,
158 | transform=train_transform,
159 | download=True)
160 | val_dataset = datasets.CIFAR10(root=opt.data_folder,
161 | train=False,
162 | transform=val_transform)
163 | elif opt.dataset == 'cifar100':
164 | train_dataset = datasets.CIFAR100(root=opt.data_folder,
165 | transform=train_transform,
166 | download=True)
167 | val_dataset = datasets.CIFAR100(root=opt.data_folder,
168 | train=False,
169 | transform=val_transform)
170 | elif opt.dataset == 'tinyimagenet':
171 | train_dataset = TinyImageNet(root=opt.data_folder,
172 | transform=train_transform,
173 | download=True)
174 | val_dataset = TinyImageNet(root=opt.data_folder,
175 | train=False,
176 | transform=val_transform)
177 | elif opt.dataset == 'imagenet':
178 | traindir = os.path.join(opt.data_folder, 'train')
179 | train_dataset = datasets.ImageFolder(root=traindir,
180 | transform=train_transform)
181 |
182 | valdir = os.path.join(opt.data_folder, 'val')
183 | val_dataset = datasets.ImageFolder(root=valdir,
184 | transform=val_transform)
185 | elif opt.dataset == 'imagenet100':
186 | traindir = os.path.join(opt.data_folder, 'train')
187 | train_dataset = ImageNetSubset('./utils/imagenet100.txt',
188 | root=traindir,
189 | transform=train_transform)
190 |
191 | valdir = os.path.join(opt.data_folder, 'val')
192 | val_dataset = ImageNetSubset('./utils/imagenet100.txt',
193 | root=valdir,
194 | transform=val_transform)
195 | else:
196 | raise ValueError(opt.dataset)
197 |
198 | train_sampler = None
199 | train_loader = torch.utils.data.DataLoader(
200 | train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
201 | num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
202 | val_loader = torch.utils.data.DataLoader(
203 | val_dataset, batch_size=512, shuffle=False,
204 | num_workers=8, pin_memory=True)
205 |
206 | return train_loader, val_loader, train_sampler
207 |
208 | def set_model(opt):
209 | model_kwargs = {'name': opt.model,
210 | 'dataset': opt.dataset,
211 | 'selfcon_pos': eval(opt.selfcon_pos),
212 | 'selfcon_arch': opt.selfcon_arch,
213 | 'selfcon_size': opt.selfcon_size
214 | }
215 |
216 | if opt.model.startswith('resnet'):
217 | model = ConResNet(**model_kwargs)
218 | classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)
219 | if opt.subnet:
220 | sub_classifier = LinearClassifier(name=opt.model, num_classes=opt.n_cls)
221 |
222 | elif opt.model.startswith('vgg'):
223 | model = ConVGG(**model_kwargs)
224 | classifier = LinearClassifier_VGG(name=opt.model, num_classes=opt.n_cls)
225 | if opt.subnet:
226 | sub_classifier = LinearClassifier_VGG(name=opt.model, num_classes=opt.n_cls)
227 |
228 | elif opt.model.startswith('wrn'):
229 | model = ConWRN(**model_kwargs)
230 | classifier = LinearClassifier_WRN(name=opt.model, num_classes=opt.n_cls)
231 | if opt.subnet:
232 | sub_classifier = LinearClassifier_WRN(name=opt.model, num_classes=opt.n_cls)
233 |
234 | elif opt.model.startswith('eff'):
235 | model = ConEfficientNet(**model_kwargs)
236 | classifier = LinearClassifier_EFF(name=opt.model, num_classes=opt.n_cls)
237 | if opt.subnet:
238 | sub_classifier = LinearClassifier_EFF(name=opt.model, num_classes=opt.n_cls)
239 |     else: raise ValueError('model not supported: {}'.format(opt.model))
240 | criterion = torch.nn.CrossEntropyLoss()
241 | if opt.ckpt:
242 | ckpt = torch.load(opt.ckpt, map_location='cpu')
243 | state_dict = ckpt['model']
244 |
245 | if torch.cuda.is_available():
246 | if torch.cuda.device_count() > 1:
247 | model.encoder = torch.nn.DataParallel(model.encoder)
248 | else:
249 | if opt.ckpt:
250 | new_state_dict = {}
251 | for k, v in state_dict.items():
252 | k = k.replace("module.", "")
253 | new_state_dict[k] = v
254 | state_dict = new_state_dict
255 |
256 | model.cuda()
257 | classifier = classifier.cuda()
258 | if opt.subnet:
259 | sub_classifier = sub_classifier.cuda()
260 | criterion = criterion.cuda()
261 | cudnn.benchmark = True
262 |
263 | if opt.ckpt:
264 | state_dict = {k.replace("downsample", "shortcut"): v for k, v in state_dict.items()}
265 | model.load_state_dict(state_dict, strict=False)
266 |
267 | if not opt.subnet:
268 | sub_classifier = None
269 | return model, classifier, sub_classifier, criterion, opt
270 |
271 |
272 | def train(train_loader, model, classifier, criterion, optimizer, epoch, opt, subnet=False):
273 | """one epoch training"""
274 | model.eval()
275 | classifier.train()
276 |
277 | batch_time = AverageMeter()
278 | data_time = AverageMeter()
279 | losses = AverageMeter()
280 | top1 = AverageMeter()
281 |
282 | end = time.time()
283 | for idx, (images, labels) in enumerate(train_loader):
284 | data_time.update(time.time() - end)
285 |
286 | images = images.cuda(non_blocking=True)
287 | labels = labels.cuda(non_blocking=True)
288 | bsz = labels.shape[0]
289 |
290 | # warm-up learning rate
291 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
292 |
293 | # compute loss
294 | with torch.no_grad():
295 | features = model.encoder(images)
296 | features = features[-1] if not subnet else features[0][-1]
297 | output = classifier(features.detach())
298 | loss = criterion(output, labels)
299 |
300 | # update metric
301 | losses.update(loss.item(), bsz)
302 | acc1, acc5 = accuracy(output, labels, topk=(1, 5))
303 | top1.update(acc1[0], bsz)
304 |
305 | # SGD
306 | optimizer.zero_grad()
307 | loss.backward()
308 | optimizer.step()
309 |
310 | # measure elapsed time
311 | batch_time.update(time.time() - end)
312 | end = time.time()
313 |
314 | # print info
315 | if (idx + 1) % opt.print_freq == 0:
316 | print('Train: [{0}][{1}/{2}]\t'
317 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
318 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
319 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'
320 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
321 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
322 | data_time=data_time, loss=losses, top1=top1))
323 | sys.stdout.flush()
324 |
325 | return losses.avg, top1.avg
326 |
327 |
328 | def validate(val_loader, model, classifier, sub_classifier, criterion, opt, best_acc):
329 | def __update_metric(output, labels, top1, top5, bsz):
330 | acc1, acc5 = accuracy(output, labels, topk=(1, 5))
331 | top1.update(acc1[0], bsz)
332 | top5.update(acc5[0], bsz)
333 |
334 | return top1, top5
335 |
336 | def __best_acc(val_acc1, val_acc5, best_acc, key='backbone'):
337 | if val_acc1.item() > best_acc[key][0]:
338 | best_acc[key][0] = val_acc1.item()
339 | best_acc[key][1] = val_acc5.item()
340 |
341 | return best_acc
342 |
343 |     # validation
344 | model.eval()
345 | classifier.eval()
346 | if sub_classifier:
347 | sub_classifier.eval()
348 |
349 | batch_time = AverageMeter()
350 | losses = AverageMeter()
351 | top1, top5 = AverageMeter(), AverageMeter()
352 | top1_sub, top5_sub = AverageMeter(), AverageMeter()
353 | top1_ens, top5_ens = AverageMeter(), AverageMeter()
354 |
355 | with torch.no_grad():
356 | end = time.time()
357 | for idx, (images, labels) in enumerate(val_loader):
358 | images = images.float().cuda()
359 | labels = labels.cuda()
360 | bsz = labels.shape[0]
361 |
362 | # forward
363 | features = model.encoder(images)
364 | output = classifier(features[-1])
365 | loss = criterion(output, labels)
366 |
367 | # for only one subnetwork
368 | if opt.subnet:
369 | sub_output = sub_classifier(features[0][-1])
370 | ensemble_output = (output + sub_output) / 2
371 |
372 | # update metric
373 | losses.update(loss.item(), bsz)
374 | top1, top5 = __update_metric(output, labels, top1, top5, bsz)
375 | if opt.subnet:
376 | top1_sub, top5_sub = __update_metric(sub_output, labels, top1_sub, top5_sub, bsz)
377 | top1_ens, top5_ens = __update_metric(ensemble_output, labels, top1_ens, top5_ens, bsz)
378 |
379 | # measure elapsed time
380 | batch_time.update(time.time() - end)
381 | end = time.time()
382 |
383 | if idx % opt.print_freq == 0:
384 | print('Test: [{0}/{1}]\t'
385 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
386 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
387 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
388 | idx, len(val_loader), batch_time=batch_time,
389 | loss=losses, top1=top1))
390 |
391 | print(' * Acc@1 {top1.avg:.2f}, Acc@5 {top5.avg:.2f}'.format(top1=top1, top5=top5))
392 | best_acc = __best_acc(top1.avg, top5.avg, best_acc)
393 |
394 | if opt.subnet:
395 |         print(' * Sub Acc@1 {top1.avg:.2f}, Acc@5 {top5.avg:.2f}'.format(top1=top1_sub, top5=top5_sub))
396 | best_acc = __best_acc(top1_sub.avg, top5_sub.avg, best_acc, key='sub')
397 |
398 |         print(' * Ens Acc@1 {top1.avg:.2f}, Acc@5 {top5.avg:.2f}'.format(top1=top1_ens, top5=top5_ens))
399 | best_acc = __best_acc(top1_ens.avg, top5_ens.avg, best_acc, key='ensemble')
400 | return best_acc
401 |
402 |
403 | def main():
404 | opt = parse_option()
405 |
406 | # fix seed
407 | np.random.seed(opt.seed)
408 | random.seed(opt.seed)
409 | torch.manual_seed(opt.seed)
410 | torch.cuda.manual_seed(opt.seed)
411 | cudnn.deterministic = True
412 |
413 | best_acc = {'backbone': [0, 0, 0],
414 | 'sub': [0, 0, 0],
415 | 'ensemble': [0, 0]}
416 |
417 | # build model and criterion
418 | model, classifier, sub_classifier, criterion, opt = set_model(opt)
419 |
420 | # build data loader
421 | train_loader, val_loader, train_sampler = set_loader(opt)
422 |
423 | # build optimizer
424 | optimizer = set_optimizer(opt, classifier)
425 | sub_optimizer = set_optimizer(opt, sub_classifier) if opt.subnet else None
426 |
427 | # training routine
428 | for epoch in range(1, opt.epochs + 1):
429 | adjust_learning_rate(opt, optimizer, epoch)
430 | if opt.subnet:
431 | adjust_learning_rate(opt, sub_optimizer, epoch)
432 |
433 | # train for one epoch
434 | time1 = time.time()
435 | loss, acc = train(train_loader, model, classifier, criterion,
436 | optimizer, epoch, opt)
437 | time2 = time.time()
438 | print('Train epoch {}, total time {:.2f}, accuracy:{:.2f}'.format(
439 | epoch, time2 - time1, acc))
440 | best_acc['backbone'][2] = acc.item()
441 |
442 | if opt.subnet:
443 | _, sub_acc = train(train_loader, model, sub_classifier, criterion,
444 | sub_optimizer, epoch, opt, subnet=True)
445 | print('Train epoch {}, accuracy:{:.2f}'.format(
446 | epoch, sub_acc))
447 | best_acc['sub'][2] = sub_acc.item()
448 |
449 | # eval for one epoch
450 | best_acc = validate(val_loader, model, classifier, sub_classifier, criterion, opt, best_acc)
451 |
452 |     update_json((opt.ckpt + '_%s' % opt.exp_name) if opt.exp_name else opt.ckpt, best_acc, path='%s/results.json' % (opt.save_dir))
453 |
454 | # for robustness experiments
455 | method = 'supcon'
456 | if not os.path.isdir('./robustness/ckpt'):
457 | os.makedirs('./robustness/ckpt')
458 | torch.save(model.state_dict(), './robustness/ckpt/{}_encoder.pth'.format(method))
459 | torch.save(classifier.state_dict(), './robustness/ckpt/{}_classifier.pth'.format(method))
460 |
461 | if __name__ == '__main__':
462 | main()
463 |
--------------------------------------------------------------------------------
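
set_model() above fixes checkpoint keys in two places: it strips DataParallel's 'module.' prefix when loading on a single GPU, and renames 'downsample' to 'shortcut' so torchvision-style ResNet keys match the classes in networks/resnet_big.py. A standalone sketch of that remapping, written for this note:

    def remap_keys(state_dict):
        """Mirror the key fixes done in main_linear.set_model (illustrative)."""
        out = {}
        for k, v in state_dict.items():
            out[k.replace('module.', '').replace('downsample', 'shortcut')] = v
        return out

    sd = {'module.encoder.layer1.0.downsample.0.weight': 'w'}
    assert list(remap_keys(sd)) == ['encoder.layer1.0.shortcut.0.weight']
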
/main_ce.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import time
7 | import math
8 | import random
9 | import builtins
10 | import numpy as np
11 | import warnings
12 | warnings.filterwarnings(action='ignore')
13 |
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | from torchvision import transforms, datasets
17 | from torch.autograd import Variable
18 |
19 | from networks.resnet_big import CEResNet
20 | from networks.vgg_big import CEVGG
21 | from networks.wrn_big import CEWRN
22 | from networks.efficient_big import CEEffNet
23 | from losses import *
24 | from utils.util import *
25 | from utils.tinyimagenet import TinyImageNet
26 | from utils.imagenet import ImageNetSubset
27 |
28 |
29 | def parse_option():
30 | parser = argparse.ArgumentParser('argument for training')
31 |
32 | parser.add_argument('--exp_name', type=str, default='')
33 | parser.add_argument('--seed', type=int, default=0)
34 | parser.add_argument('--print_freq', type=int, default=10)
35 | parser.add_argument('--resume', help='path of model checkpoint to resume', type=str,
36 | default='')
37 |
38 | # dataset
39 | parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'tinyimagenet', 'imagenet', 'imagenet100'])
40 | parser.add_argument('--data_folder', type=str, default='datasets/')
41 | parser.add_argument('--batch_size', type=int, default=256)
42 | parser.add_argument('--num_workers', type=int, default=16)
43 |
44 | # model
45 | parser.add_argument('--model', type=str, default='resnet50')
46 | parser.add_argument('--selfcon_pos', type=str, default='[False,False,False]',
47 | help='where to augment the paths')
48 | parser.add_argument('--selfcon_arch', type=str, default='resnet',
49 | choices=['resnet', 'vgg', 'efficientnet', 'wrn'], help='which architecture to form a sub-network')
50 | parser.add_argument('--selfcon_size', type=str, default='same',
51 | choices=['fc', 'same', 'small'], help='argument for num_blocks of a sub-network')
52 | parser.add_argument('--dim_out', default=128, type=int,
53 | help='feat dimension for CEResNet')
54 |
55 | # optimization
56 | parser.add_argument('--epochs', type=int, default=500)
57 | parser.add_argument('--learning_rate', type=float, default=0.2)
58 | parser.add_argument('--lr_decay_epochs', type=str, default='350,400,450')
59 | parser.add_argument('--lr_decay_rate', type=float, default=0.1)
60 | parser.add_argument('--weight_decay', type=float, default=1e-4)
61 | parser.add_argument('--momentum', type=float, default=0.9)
62 | parser.add_argument('--cosine', action='store_true',
63 | help='using cosine annealing')
64 | parser.add_argument('--warm', action='store_true',
65 | help='warm-up for large batch training')
66 |
67 | # important arguments
68 | parser.add_argument('--method', type=str,
69 | choices=['ce', 'subnet_ce', 'kd', 'selfcon'], help='choose method')
70 | parser.add_argument('--alpha', type=float, default=0., help='weight balance for subnet CE')
71 | parser.add_argument('--beta', type=float, default=0., help='weight balance for KD')
72 | parser.add_argument('--gamma', type=float, default=0., help='weight balance for other losses')
73 | parser.add_argument('--temperature', type=float, default=3.0, help='temperature for KD loss function')
74 |
75 | opt = parser.parse_args()
76 |
77 | if opt.model.startswith('vgg'):
78 | if opt.selfcon_pos == '[False,False,False]':
79 | opt.selfcon_pos = '[False,False,False,False]'
80 | opt.selfcon_arch = 'vgg'
81 | if opt.model.startswith('eff'):
82 | if opt.selfcon_pos == '[False,False,False]':
83 | opt.selfcon_pos = '[False]'
84 | opt.selfcon_arch = 'eff'
85 |
86 | # set the path according to the environment
87 | opt.model_path = './save/distill/%s/%s_models' % (opt.method, opt.dataset)
88 |
89 | iterations = opt.lr_decay_epochs.split(',')
90 | opt.lr_decay_epochs = list([])
91 | for it in iterations:
92 | opt.lr_decay_epochs.append(int(it))
93 |
94 | opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_seed_{}'.\
95 | format(opt.method, opt.dataset, opt.model, opt.learning_rate,
96 | opt.weight_decay, opt.batch_size, opt.seed)
97 |
98 | if opt.cosine:
99 | opt.model_name = '{}_cosine'.format(opt.model_name)
100 | if opt.exp_name:
101 | opt.model_name = '{}_{}'.format(opt.model_name, opt.exp_name)
102 |
103 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
104 | if not os.path.isdir(opt.save_folder):
105 | os.makedirs(opt.save_folder)
106 |
107 | if opt.dataset == 'cifar10':
108 | opt.n_cls = 10
109 | opt.n_data = 50000
110 | elif opt.dataset == 'cifar100':
111 | opt.n_cls = 100
112 | opt.n_data = 50000
113 | elif opt.dataset == 'tinyimagenet':
114 | opt.n_cls = 200
115 | opt.n_data = 100000
116 | elif opt.dataset == 'imagenet':
117 | opt.n_cls = 1000
118 | opt.n_data = 1200000
119 | elif opt.dataset == 'imagenet100':
120 | opt.n_cls = 100
121 | opt.n_data = 120000
122 | else:
123 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
124 |
125 | if opt.method == 'ce':
126 | opt.alpha, opt.beta, opt.gamma = 0, 0, 0
127 | elif opt.method == 'subnet_ce':
128 | opt.alpha, opt.beta, opt.gamma = 1.0, 0, 0
129 | elif opt.method == 'kd':
130 | opt.alpha, opt.beta, opt.gamma = 0.5, 0.5, 0
131 | elif opt.method == 'selfcon':
132 | opt.alpha, opt.beta, opt.gamma = 1.0, 0, 0.8
133 |
134 | return opt
135 |
136 |
137 | def set_loader(opt):
138 | # construct data loader
139 | if opt.dataset == 'cifar10':
140 | mean = (0.4914, 0.4822, 0.4465)
141 | std = (0.2023, 0.1994, 0.2010)
142 | size = 32
143 | elif opt.dataset == 'cifar100':
144 | mean = (0.5071, 0.4867, 0.4408)
145 | std = (0.2675, 0.2565, 0.2761)
146 | size = 32
147 | elif opt.dataset == 'tinyimagenet':
148 | mean = (0.485, 0.456, 0.406)
149 | std = (0.229, 0.224, 0.225)
150 | size = 64
151 | elif opt.dataset == 'imagenet' or opt.dataset == 'imagenet100':
152 | mean = (0.485, 0.456, 0.406)
153 | std = (0.229, 0.224, 0.225)
154 | size = 224
155 | else:
156 | raise ValueError('dataset not supported: {}'.format(opt.dataset))
157 | normalize = transforms.Normalize(mean=mean, std=std)
158 |
159 | train_transform = transforms.Compose([
160 | transforms.RandomResizedCrop(size=size, scale=(0.2, 1.)),
161 | transforms.RandomHorizontalFlip(),
162 | transforms.ToTensor(),
163 | normalize,
164 | ])
165 |
166 | if opt.dataset not in ['imagenet', 'imagenet100']:
167 | val_transform = transforms.Compose([
168 | transforms.ToTensor(),
169 | normalize,
170 | ])
171 | else:
172 | val_transform = transforms.Compose([transforms.Resize(256),
173 | transforms.CenterCrop(224),
174 | transforms.ToTensor(),
175 | normalize])
176 |
177 | if opt.dataset == 'cifar10':
178 | train_dataset = datasets.CIFAR10(root=opt.data_folder,
179 | transform=train_transform,
180 | download=True)
181 | val_dataset = datasets.CIFAR10(root=opt.data_folder,
182 | train=False,
183 | transform=val_transform)
184 | elif opt.dataset == 'cifar100':
185 | train_dataset = datasets.CIFAR100(root=opt.data_folder,
186 | transform=train_transform,
187 | download=True)
188 | val_dataset = datasets.CIFAR100(root=opt.data_folder,
189 | train=False,
190 | transform=val_transform)
191 | elif opt.dataset == 'tinyimagenet':
192 | train_dataset = TinyImageNet(root=opt.data_folder,
193 | transform=train_transform,
194 | download=True)
195 | val_dataset = TinyImageNet(root=opt.data_folder,
196 | train=False,
197 | transform=val_transform)
198 | elif opt.dataset == 'imagenet':
199 | traindir = os.path.join(opt.data_folder, 'train')
200 | valdir = os.path.join(opt.data_folder, 'val')
201 | train_dataset = datasets.ImageFolder(root=traindir, transform=train_transform)
202 | val_dataset = datasets.ImageFolder(root=valdir, transform=val_transform)
203 | elif opt.dataset == 'imagenet100':
204 | traindir = os.path.join(opt.data_folder, 'train')
205 | valdir = os.path.join(opt.data_folder, 'val')
206 |
207 | train_dataset = ImageNetSubset('./utils/imagenet100.txt',
208 | root=traindir,
209 | transform=train_transform)
210 | val_dataset = ImageNetSubset('./utils/imagenet100.txt',
211 | root=valdir,
212 | transform=val_transform)
213 | else:
214 | raise ValueError(opt.dataset)
215 |
216 | train_loader = torch.utils.data.DataLoader(
217 | train_dataset, batch_size=opt.batch_size, shuffle=True,
218 | num_workers=opt.num_workers, pin_memory=True, sampler=None)
219 | val_loader = torch.utils.data.DataLoader(
220 | val_dataset, batch_size=512, shuffle=False,
221 | num_workers=8, pin_memory=True)
222 |
223 | return train_loader, val_loader
224 |
225 |
226 | def set_model(opt):
227 | model_kwargs = {'name': opt.model,
228 | 'method': opt.method,
229 | 'num_classes': opt.n_cls,
230 | 'dim_out': opt.dim_out,
231 | 'dataset': opt.dataset,
232 | 'selfcon_pos': eval(opt.selfcon_pos),
233 | 'selfcon_arch': opt.selfcon_arch,
234 | 'selfcon_size': opt.selfcon_size
235 | }
236 |
237 | if opt.model.startswith('resnet'):
238 | model = CEResNet(**model_kwargs)
239 | elif opt.model.startswith('vgg'):
240 | model = CEVGG(**model_kwargs)
241 | elif opt.model.startswith('wrn'):
242 | model = CEWRN(**model_kwargs)
243 | elif opt.model.startswith('eff'):
244 | model = CEEffNet(**model_kwargs)
245 |     else: raise ValueError('model not supported: {}'.format(opt.model))
246 | criterion = nn.ModuleList([])
247 | criterion.append(torch.nn.CrossEntropyLoss())
248 | criterion.append(KLLoss(opt.temperature))
249 |
250 |     # Note that the student and teacher feature shapes are the same
251 | if opt.method in ['ce', 'subnet_ce', 'kd']:
252 | criterion.append(None)
253 | elif opt.method == 'selfcon':
254 | criterion.append(ConLoss(temperature=opt.temperature))
255 | else:
256 |         raise NotImplementedError('method not supported: {}'.format(opt.method))
257 |
258 | if torch.cuda.is_available():
259 | if torch.cuda.device_count() > 1:
260 | model = torch.nn.DataParallel(model)
261 | model = model.cuda()
262 | criterion = criterion.cuda()
263 | cudnn.benchmark = True
264 |
265 | return model, criterion, opt
266 |
267 |
268 | def train(train_loader, model, criterion, optimizer, epoch, opt):
269 | """one epoch training"""
270 | model.train()
271 |
272 | batch_time = AverageMeter()
273 | data_time = AverageMeter()
274 | losses = AverageMeter()
275 | top1 = AverageMeter()
276 | top1_s = AverageMeter()
277 |
278 |     only_backbone = not any(eval(opt.selfcon_pos))
279 |
280 | end = time.time()
281 | for idx, inputs in enumerate(train_loader):
282 | images, labels = inputs
283 |
284 | data_time.update(time.time() - end)
285 |
286 | images = images.cuda(non_blocking=True)
287 | labels = labels.cuda(non_blocking=True)
288 | bsz = labels.shape[0]
289 |
290 | # warm-up learning rate
291 | warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
292 |
293 | # compute loss
294 | if opt.method not in ['ce', 'subnet_ce', 'kd']:
295 | feats, logits = model(images)
296 | else:
297 | logits = model(images)
298 |
299 | loss = criterion[0](logits[-1], labels)
300 |
301 | for sub_logit in logits[0]:
302 | loss += opt.alpha * criterion[0](sub_logit, labels)
303 | loss += opt.beta * criterion[1](sub_logit, logits[-1])
304 | if criterion[2] is not None:
305 | for idx, feat_s in enumerate(feats[0]):
306 |                 # the MLP head of the backbone is always randomly initialized
307 | features = torch.cat([feat_s.unsqueeze(1), feats[-1].unsqueeze(1)], dim=1)
308 | loss += opt.gamma * criterion[2](features, labels)
309 |
310 | # update metric
311 | losses.update(loss.item(), bsz)
312 | acc1, _ = accuracy(logits[-1], labels, topk=(1, 5))
313 | top1.update(acc1[0], bsz)
314 | if not only_backbone:
315 | acc1_s, _ = accuracy(logits[0][0], labels, topk=(1, 5))
316 | top1_s.update(acc1_s[0], bsz)
317 | else:
318 | top1_s.update(torch.tensor(0.0).to(acc1[0].device), bsz)
319 |
320 | # SGD
321 | optimizer.zero_grad()
322 | loss.backward()
323 | optimizer.step()
324 |
325 | # measure elapsed time
326 | batch_time.update(time.time() - end)
327 | end = time.time()
328 |
329 | # print info
330 | if (idx + 1) % opt.print_freq == 0:
331 | print('Train: [{0}][{1}/{2}]\t'
332 | 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
333 | 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
334 | 'loss {loss.val:.3f} ({loss.avg:.3f})\t'
335 | 'Acc@1 {top1.avg:.3f} {top1_s.avg:.3f}'.format(
336 | epoch, idx + 1, len(train_loader), batch_time=batch_time,
337 | data_time=data_time, loss=losses, top1=top1, top1_s=top1_s))
338 | sys.stdout.flush()
339 |
340 | return losses.avg, top1.avg
341 |
342 |
343 | def validate(val_loader, model, criterion, opt):
344 | """validation"""
345 | model.eval()
346 |
347 | batch_time = AverageMeter()
348 | losses = AverageMeter()
349 | top1_b = AverageMeter()
350 | top5_b = AverageMeter()
351 | top1_s = AverageMeter()
352 | top5_s = AverageMeter()
353 |
354 |     only_backbone = not any(eval(opt.selfcon_pos))
355 |
356 | with torch.no_grad():
357 | end = time.time()
358 | for idx, (images, labels) in enumerate(val_loader):
359 | images = images.float().cuda()
360 | labels = labels.cuda()
361 | bsz = labels.shape[0]
362 |
363 | # forward
364 | if opt.method not in ['ce', 'subnet_ce', 'kd']:
365 | _, logits = model(images)
366 | else:
367 | logits = model(images)
368 |
369 | loss = criterion[0](logits[-1], labels)
370 |
371 | # update metric
372 | losses.update(loss.item(), bsz)
373 | acc1, acc5 = accuracy(logits[-1], labels, topk=(1, 5))
374 | top1_b.update(acc1[0], bsz)
375 | top5_b.update(acc5[0], bsz)
376 | if only_backbone:
377 | top1_s.update(torch.tensor(0.0).to(acc1[0].device), bsz)
378 | top5_s.update(torch.tensor(0.0).to(acc5[0].device), bsz)
379 | else:
380 |                 # report only the first sub-network (a single sub-network is used in practice)
381 | acc1_s, acc5_s = accuracy(logits[0][0], labels, topk=(1, 5))
382 | top1_s.update(acc1_s[0], bsz)
383 | top5_s.update(acc5_s[0], bsz)
384 |
385 | # measure elapsed time
386 | batch_time.update(time.time() - end)
387 | end = time.time()
388 |
389 | if idx % opt.print_freq == 0:
390 | print('Test: [{0}/{1}]\t'
391 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
392 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
393 | 'Acc@1 ({top1_b.avg:.3f}) ({top1_s.avg:.3f})'.format(
394 | idx, len(val_loader), batch_time=batch_time,
395 | loss=losses, top1_b=top1_b, top1_s=top1_s))
396 |
397 | print(' * Acc@1 {top1_b.avg:.3f} {top1_s.avg:.3f}'.format(top1_b=top1_b, top1_s=top1_s))
398 | return losses.avg, top1_b.avg, top5_b.avg, top1_s.avg, top5_s.avg
399 |
400 |
401 | def main():
402 | opt = parse_option()
403 |
404 | # fix seed
405 | np.random.seed(opt.seed)
406 | random.seed(opt.seed)
407 | torch.manual_seed(opt.seed)
408 | torch.cuda.manual_seed(opt.seed)
409 | cudnn.deterministic = True
410 |
411 | # build model and criterion
412 | model, criterion, opt = set_model(opt)
413 |
414 | # build data loader
415 | train_loader, val_loader = set_loader(opt)
416 |
417 | # build optimizer
418 | optimizer = set_optimizer(opt, model)
419 |
420 | if opt.resume:
421 | if os.path.isfile(opt.resume):
422 | print("=> loading checkpoint '{}'".format(opt.resume))
423 | checkpoint = torch.load(opt.resume)
424 | opt.start_epoch = checkpoint['epoch'] + 1
425 | model.load_state_dict(checkpoint['model'])
426 | optimizer.load_state_dict(checkpoint['optimizer'])
427 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch']))
428 | else:
429 | print("=> no checkpoint found at '{}'".format(opt.resume))
430 | else:
431 | opt.start_epoch = 1
432 |
433 |     # warm-up for large-batch training
434 | if opt.batch_size >= 1024:
435 | opt.warm = True
436 | if opt.warm:
437 | opt.model_name = '{}_warm'.format(opt.model_name)
438 | opt.warmup_from = 0.01
439 | opt.warm_epochs = 10
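440 |         # warmup_to is the target of the linear warm-up; with --cosine it equals the
441 |         # cosine-annealed LR evaluated at epoch == warm_epochs, so the warm-up meets
442 |         # the schedule without a discontinuity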
440 | if opt.cosine:
441 | eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
442 | opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
443 | 1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
444 | else:
445 | opt.warmup_to = opt.learning_rate
446 |
447 | # training routine
448 | best_acc1 = 0
449 | for epoch in range(opt.start_epoch, opt.epochs + 1):
450 | adjust_learning_rate(opt, optimizer, epoch)
451 |
452 | # train for one epoch
453 | time1 = time.time()
454 | loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt)
455 | time2 = time.time()
456 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
457 |
458 | # evaluation
459 | loss, val_acc1, val_acc5, val_acc1_s, val_acc5_s = validate(val_loader, model, criterion, opt)
460 |
461 | if val_acc1.item() > best_acc1:
462 | best_acc1 = val_acc1
463 | best_acc5 = val_acc5
464 | best_acc1_s = val_acc1_s
465 | best_acc5_s = val_acc5_s
466 |             # clone the tensors: state_dict() returns references that later epochs would overwrite
467 |             best_model = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
467 |
468 | # save the last model
469 | save_file = os.path.join(
470 | opt.save_folder, 'last.pth')
471 | save_model(model, optimizer, opt, epoch, save_file)
472 |
473 | # save the best model
474 |     # Note that the accuracy in results.json can differ from that of the saved best model
475 |     # because of the multiprocessing distributed setting
476 | model.load_state_dict(best_model)
477 | save_file = os.path.join(
478 | opt.save_folder, 'best.pth')
479 | save_model(model, optimizer, opt, opt.epochs, save_file)
480 |
481 | update_json_list(opt.save_folder, [best_acc1.item(), best_acc5.item(), best_acc1_s.item(), best_acc5_s.item(), train_acc.item()], path='./save/distill/results.json')
482 |
483 |
484 | if __name__ == '__main__':
485 | main()
486 |
--------------------------------------------------------------------------------
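Annotation (not part of the repository): the objective assembled in train() above is the backbone cross-entropy plus, for each sub-network, alpha times the sub-network cross-entropy, beta times a KD-style term pulling the sub-logits toward the backbone logits, and gamma times a contrastive term on the paired sub/backbone features. Below is a minimal sketch of the same computation; selfcon_objective is an illustrative helper name, and ce, kd, con stand for the criterion[0], criterion[1], criterion[2] modules used above.

import torch

def selfcon_objective(ce, kd, con, logits, feats, labels, alpha, beta, gamma):
    # logits = [list_of_sub_logits, backbone_logits]; feats has the same layout
    loss = ce(logits[-1], labels)                       # criterion[0] on the backbone
    for sub_logit in logits[0]:
        loss = loss + alpha * ce(sub_logit, labels)     # criterion[0] on each sub-network
        loss = loss + beta * kd(sub_logit, logits[-1])  # criterion[1]: backbone -> sub distillation
    if con is not None:                                 # criterion[2]: contrastive term
        for feat_s in feats[0]:
            # stack sub-network and backbone features as two views per sample: (bsz, 2, dim)
            pair = torch.cat([feat_s.unsqueeze(1), feats[-1].unsqueeze(1)], dim=1)
            loss = loss + gamma * con(pair, labels)
    return loss
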
/networks/efficient_big.py:
--------------------------------------------------------------------------------
1 | import re
2 | import math
3 | import collections
4 | from functools import partial
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | # Parameters for the entire model (stem, all blocks, and head)
11 | GlobalParams = collections.namedtuple('GlobalParams', [
12 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
13 | 'num_classes', 'width_coefficient', 'depth_coefficient',
14 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
15 |
16 | # Parameters for an individual model block
17 | BlockArgs = collections.namedtuple('BlockArgs', [
18 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
19 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
20 |
21 | # Change namedtuple defaults
22 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
23 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
24 |
25 |
26 | class Swish(nn.Module):
27 | def forward(self, x):
28 | return x * torch.sigmoid(x)
29 |
30 |
31 | def round_filters(filters, global_params):
32 | """ Calculate and round number of filters based on depth multiplier. """
33 | multiplier = global_params.width_coefficient
34 | if not multiplier:
35 | return filters
36 | divisor = global_params.depth_divisor
37 | min_depth = global_params.min_depth
38 | filters *= multiplier
39 | min_depth = min_depth or divisor
40 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
41 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
42 | new_filters += divisor
43 | return int(new_filters)
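44 |     # e.g. round_filters(32, GlobalParams(width_coefficient=1.2, depth_divisor=8)) == 40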
44 |
45 |
46 | def round_repeats(repeats, global_params):
47 | """ Round number of filters based on depth multiplier. """
48 | multiplier = global_params.depth_coefficient
49 | if not multiplier:
50 | return repeats
51 | return int(math.ceil(multiplier * repeats))
52 |
53 |
54 | def drop_connect(inputs, p, training):
55 |     """ Drop connect (stochastic depth): drops the whole residual branch per sample
56 |     with probability p and rescales survivors by 1 / keep_prob. """
57 |     if not training: return inputs
58 |     batch_size = inputs.shape[0]
59 |     keep_prob = 1 - p
60 |     # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0
61 |     random_tensor = keep_prob
62 |     random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
63 |     binary_tensor = torch.floor(random_tensor)
64 |     output = inputs / keep_prob * binary_tensor
65 |     return output
64 |
65 |
66 | def get_same_padding_conv2d(image_size=None):
67 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
68 | Static padding is necessary for ONNX exporting of models. """
69 | if image_size is None:
70 | return Conv2dDynamicSamePadding
71 | else:
72 | return partial(Conv2dStaticSamePadding, image_size=image_size)
73 |
74 |
75 | def get_width_and_height_from_size(x):
76 | """ Obtains width and height from a int or tuple """
77 | if isinstance(x, int): return x, x
78 | if isinstance(x, list) or isinstance(x, tuple): return x
79 | else: raise TypeError()
80 |
81 |
82 | def calculate_output_image_size(input_image_size, stride):
83 | """ Calculates the output image size when using Conv2dSamePadding with a stride.
84 | Necessary for static padding. Thanks to mannatsingh for pointing this out. """
85 | if input_image_size is None: return None
86 | image_height, image_width = get_width_and_height_from_size(input_image_size)
87 | stride = stride if isinstance(stride, int) else stride[0]
88 | image_height = int(math.ceil(image_height / stride))
89 | image_width = int(math.ceil(image_width / stride))
90 | return [image_height, image_width]
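91 | # e.g. calculate_output_image_size(32, 2) == [16, 16]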
91 |
92 |
93 | class Conv2dDynamicSamePadding(nn.Conv2d):
94 | """ 2D Convolutions like TensorFlow, for a dynamic image size """
95 |
96 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
97 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
98 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
99 |
100 | def forward(self, x):
101 | ih, iw = x.size()[-2:]
102 | kh, kw = self.weight.size()[-2:]
103 | sh, sw = self.stride
104 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
105 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
106 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
107 | if pad_h > 0 or pad_w > 0:
108 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
109 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
110 |
111 |
112 | class Conv2dStaticSamePadding(nn.Conv2d):
113 | """ 2D Convolutions like TensorFlow, for a fixed image size"""
114 |
115 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
116 | super().__init__(in_channels, out_channels, kernel_size, **kwargs)
117 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
118 |
119 | # Calculate padding based on image size and save it
120 | assert image_size is not None
121 | ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
122 | kh, kw = self.weight.size()[-2:]
123 | sh, sw = self.stride
124 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
125 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
126 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
127 | if pad_h > 0 or pad_w > 0:
128 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
129 | else:
130 | self.static_padding = Identity()
131 |
132 | def forward(self, x):
133 | x = self.static_padding(x)
134 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
135 | return x
136 |
137 |
138 | class Identity(nn.Module):
139 | def __init__(self, ):
140 | super(Identity, self).__init__()
141 |
142 | def forward(self, input):
143 | return input
144 |
145 |
146 | class BlockDecoder(object):
147 | """ Block Decoder for readability, straight from the official TensorFlow repository """
148 |
149 |     @staticmethod
150 |     def _decode_block_string(block_string):
150 | """ Gets a block through a string notation of arguments. """
151 | assert isinstance(block_string, str)
152 |
153 | ops = block_string.split('_')
154 | options = {}
155 | for op in ops:
156 | splits = re.split(r'(\d.*)', op)
157 | if len(splits) >= 2:
158 | key, value = splits[:2]
159 | options[key] = value
160 |
161 | # Check stride
162 | assert (('s' in options and len(options['s']) == 1) or
163 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
164 |
165 | return BlockArgs(
166 | kernel_size=int(options['k']),
167 | num_repeat=int(options['r']),
168 | input_filters=int(options['i']),
169 | output_filters=int(options['o']),
170 | expand_ratio=int(options['e']),
171 | id_skip=('noskip' not in block_string),
172 | se_ratio=float(options['se']) if 'se' in options else None,
173 | stride=[int(options['s'][0])])
174 |
175 |     @staticmethod
176 |     def decode(string_list):
176 | """
177 | Decodes a list of string notations to specify blocks inside the network.
178 | :param string_list: a list of strings, each string is a notation of block
179 | :return: a list of BlockArgs namedtuples of block args
180 | """
181 | assert isinstance(string_list, list)
182 | blocks_args = []
183 | for block_string in string_list:
184 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
185 | return blocks_args
186 |
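187 | # Example: BlockDecoder.decode(['r2_k3_s22_e6_i16_o24_se0.25'])[0] yields
188 | # BlockArgs(kernel_size=3, num_repeat=2, input_filters=16, output_filters=24,
189 | #           expand_ratio=6, id_skip=True, stride=[2], se_ratio=0.25)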
187 |
188 | class MBConvBlock(nn.Module):
189 | """
190 | Mobile Inverted Residual Bottleneck Block
191 | Args:
192 | block_args (namedtuple): BlockArgs, see above
193 | global_params (namedtuple): GlobalParam, see above
194 | Attributes:
195 | has_se (bool): Whether the block contains a Squeeze and Excitation layer.
196 | """
197 |
198 | def __init__(self, block_args, global_params, image_size=None, drop_connect_rate=0.2):
199 | super().__init__()
200 | self._block_args = block_args
201 | self._bn_mom = 1 - global_params.batch_norm_momentum
202 | self._bn_eps = global_params.batch_norm_epsilon
203 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
204 | self.id_skip = block_args.id_skip # skip connection and drop connect
205 | self.drop_connect_rate = drop_connect_rate
206 |
207 | # Expansion phase
208 | inp = self._block_args.input_filters # number of input channels
209 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
210 | if self._block_args.expand_ratio != 1:
211 | Conv2d = get_same_padding_conv2d(image_size=image_size)
212 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
213 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
214 | # image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing
215 |
216 | # Depthwise convolution phase
217 | k = self._block_args.kernel_size
218 | s = self._block_args.stride
219 | Conv2d = get_same_padding_conv2d(image_size=image_size)
220 | self._depthwise_conv = Conv2d(
221 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
222 | kernel_size=k, stride=s, bias=False)
223 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
224 | image_size = calculate_output_image_size(image_size, s)
225 |
226 | # Squeeze and Excitation layer, if desired
227 | if self.has_se:
228 | Conv2d = get_same_padding_conv2d(image_size=(1,1))
229 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
230 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
231 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
232 |
233 | # Output phase
234 | final_oup = self._block_args.output_filters
235 | Conv2d = get_same_padding_conv2d(image_size=image_size)
236 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
237 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
238 | self._swish = Swish()
239 |
240 | def forward(self, inputs):
241 | """
242 | :param inputs: input tensor
243 | :param drop_connect_rate: drop connect rate (float, between 0 and 1)
244 | :return: output of block
245 | """
246 | # Expansion and Depthwise Convolution
247 | x = inputs
248 | if self._block_args.expand_ratio != 1:
249 | x = self._swish(self._bn0(self._expand_conv(inputs)))
250 | x = self._swish(self._bn1(self._depthwise_conv(x)))
251 |
252 | # Squeeze and Excitation
253 | if self.has_se:
254 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
255 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
256 | x = torch.sigmoid(x_squeezed) * x
257 |
258 | x = self._bn2(self._project_conv(x))
259 |
260 | # Skip connection and drop connect
261 |         input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
262 |         # stride is an int for repeated blocks but a one-element list for the first block of a stage
263 |         if self.id_skip and self._block_args.stride in (1, [1]) and input_filters == output_filters:
263 | if self.drop_connect_rate:
264 | x = drop_connect(x, p=self.drop_connect_rate, training=self.training)
265 | x = x + inputs # skip connection
266 | return x
267 |
268 |
269 | class EfficientNet(nn.Module):
270 | """
271 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
272 | Args:
273 | blocks_args (list): A list of BlockArgs to construct blocks
274 | global_params (namedtuple): A set of GlobalParams shared between blocks
275 | Example:
276 | model = EfficientNet.from_pretrained('efficientnet-b0')
277 | """
278 |
279 | def __init__(self, selfcon_pos=[False]):
280 | super().__init__()
281 | blocks_args = [
282 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
283 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
284 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
285 | 'r1_k3_s11_e6_i192_o320_se0.25',
286 | ]
287 |
288 |         # override the second stage to stride 1, which keeps more spatial resolution for small inputs
289 |         blocks_args[1] = 'r2_k3_s11_e6_i16_o24_se0.25'
289 | blocks_args = BlockDecoder.decode(blocks_args)
290 |
291 | params = {'b0': (1.0, 1.0, 32, 0.2), 'b1': (1.0, 1.1, 34, 0.2), 'b2': (1.1, 1.2, 38, 0.3)}
292 | w, d, s, p = params['b0']
293 |
294 | global_params = GlobalParams(
295 | batch_norm_momentum=0.99,
296 | batch_norm_epsilon=1e-3,
297 | dropout_rate=p,
298 | drop_connect_rate=0.2,
299 | width_coefficient=w,
300 | depth_coefficient=d,
301 | depth_divisor=8,
302 | min_depth=None,
303 | image_size=s,
304 | )
305 |
306 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
307 |         assert len(blocks_args) > 0, 'blocks_args must not be empty'
308 | self._global_params = global_params
309 | self._blocks_args = blocks_args
310 |
311 | # Batch norm parameters
312 | bn_mom = 1 - self._global_params.batch_norm_momentum
313 | bn_eps = self._global_params.batch_norm_epsilon
314 |
315 | # Get stem static or dynamic convolution depending on image size
316 | image_size = global_params.image_size
317 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
318 |
319 | # Stem
320 | in_channels = 3 # rgb
321 | out_channels = round_filters(32, self._global_params) # number of output channels
322 |         stride = 1  # the original EfficientNet stem uses stride 2; stride 1 suits small inputs
323 |
324 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, bias=False)
325 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
326 |
327 | # Build blocks
328 | block_layers = []
329 | drop_connect_rate = self._global_params.drop_connect_rate
330 | num = 0
331 | for b in self._blocks_args:
332 | num += b.num_repeat
333 |
334 | index = 0
335 | for block_args in self._blocks_args:
336 | layers = []
337 | # Update block input and output filters based on depth multiplier.
338 | block_args = block_args._replace(
339 | input_filters=round_filters(block_args.input_filters, self._global_params),
340 | output_filters=round_filters(block_args.output_filters, self._global_params),
341 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
342 | )
343 |
344 | # The first block needs to take care of stride and filter size increase.
345 | layers.append(MBConvBlock(block_args, self._global_params, image_size=image_size, drop_connect_rate=drop_connect_rate*index/num))
346 | index += 1
347 | image_size = calculate_output_image_size(image_size, block_args.stride)
348 | if block_args.num_repeat > 1:
349 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
350 | for _ in range(block_args.num_repeat - 1):
351 | layers.append(MBConvBlock(block_args, self._global_params, image_size=image_size, drop_connect_rate=drop_connect_rate*index/num))
352 | index += 1
353 |
354 | block_layers.append(nn.Sequential(*layers))
355 |         self.block_layers = nn.ModuleList(block_layers)
356 |
357 | # Head
358 | in_channels = block_args.output_filters # output of final block
359 | out_channels = round_filters(512, self._global_params)
360 | self.final_channels = out_channels
361 |
362 | Conv2d = get_same_padding_conv2d(image_size=image_size)
363 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
364 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
365 |
366 | # Final linear layer
367 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
368 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
369 | self._swish = Swish()
370 |
371 | sub_conv = []
372 |         sub_conv.append(Conv2d(80, self.final_channels, kernel_size=1, bias=False))  # 80 = output filters of the fourth block group
373 | sub_conv.append(nn.BatchNorm2d(num_features=self.final_channels, momentum=bn_mom, eps=bn_eps))
374 | sub_conv.append(Swish())
375 |
376 | self.selfcon_layer = self._make_sub_layer(selfcon_pos, nn.Sequential(*sub_conv))
377 |
378 |     # keep the sub-network head simple: the conv block above followed by nn.Linear
379 | def _make_sub_layer(self, pos, sub_conv):
380 | pos = pos[0]
381 | if not pos:
382 | return None
383 | else:
384 | return nn.ModuleList([sub_conv, nn.Linear(self.final_channels, self.final_channels)])
385 |
386 | # Stem
387 | def conv_stem(self, x):
388 | x = self._swish(self._bn0(self._conv_stem(x)))
389 |
390 | return x
391 |
392 | def pool_linear(self, feat):
393 | # Head
394 | feat = self._swish(self._bn1(self._conv_head(feat)))
395 |
396 | # Pooling and final linear layer
397 | feat = self._avg_pooling(feat)
398 | features = feat.view(feat.size(0), -1)
399 | features = self._dropout(features)
400 |
401 | return features
402 |
403 | def forward(self, x):
404 | sub_out = []
405 |
406 | x = self.conv_stem(x)
407 |
408 |         for i in range(4):  # the sub-network branch taps features after the first four block groups
409 | x = self.block_layers[i](x)
410 |
411 | if self.selfcon_layer is not None:
412 | out = self.selfcon_layer[0](x)
413 | out = torch.flatten(self._avg_pooling(out), 1)
414 | out = self._dropout(out)
415 | out = self.selfcon_layer[1](out)
416 | sub_out.append(out)
417 |
418 | for i in range(4, len(self.block_layers)):
419 | x = self.block_layers[i](x)
420 |
421 | features = self.pool_linear(x)
422 |
423 | return sub_out, features
424 |
425 |
426 | def efficientnet(**kwargs):
427 | return EfficientNet(**kwargs)
428 |
429 |
430 | model_dict = {
431 | 'efficientnet': [efficientnet, 512]
432 | }
433 |
434 |
435 | class ConEfficientNet(nn.Module):
436 | """backbone + projection head"""
437 | def __init__(self, name='efficientnet', head='mlp', feat_dim=128, selfcon_pos=[False,False,False], selfcon_arch='resnet', selfcon_size='same', dataset=''):
438 | super(ConEfficientNet, self).__init__()
439 | model_fun, dim_in = model_dict[name]
440 | self.encoder = model_fun(selfcon_pos=selfcon_pos)
441 | if head == 'linear':
442 | self.head = nn.Linear(dim_in, feat_dim)
443 |
444 |             sub_heads = []
445 |             for pos in selfcon_pos:
446 |                 if pos:
447 |                     sub_heads.append(nn.Linear(dim_in, feat_dim))
448 |             # wrap in ModuleList so the sub-head parameters are registered with the model
449 |             self.sub_heads = nn.ModuleList(sub_heads)
448 | elif head == 'mlp':
449 | self.head = nn.Sequential(
450 | nn.Linear(dim_in, dim_in),
451 | nn.ReLU(inplace=True),
452 | nn.Linear(dim_in, feat_dim)
453 | )
454 |
455 | heads = []
456 | for pos in selfcon_pos:
457 | if pos:
458 | heads.append(nn.Sequential(
459 | nn.Linear(dim_in, dim_in),
460 | nn.ReLU(inplace=True),
461 | nn.Linear(dim_in, feat_dim)
462 | ))
463 | self.sub_heads = nn.ModuleList(heads)
464 | else:
465 | raise NotImplementedError(
466 | 'head not supported: {}'.format(head))
467 |
468 | def forward(self, x):
469 | sub_feat, feat = self.encoder(x)
470 |
471 | sh_feat = []
472 | for sf, sub_head in zip(sub_feat, self.sub_heads):
473 | sh_feat.append(F.normalize(sub_head(sf), dim=1))
474 |
475 | feat = F.normalize(self.head(feat), dim=1)
476 | return sh_feat, feat
477 |
478 |
479 | class CEEffNet(nn.Module):
480 | """encoder + classifier"""
481 | def __init__(self, name='efficientnet', method='ce', num_classes=10, dim_out=128, selfcon_pos=[False], selfcon_arch='resnet', selfcon_size='same', dataset=''):
482 | super(CEEffNet, self).__init__()
483 | self.method = method
484 |
485 | model_fun, dim_in = model_dict[name]
486 | self.encoder = model_fun(selfcon_pos=selfcon_pos)
487 |
488 | logit_fcs, feat_fcs = [], []
489 | for pos in selfcon_pos:
490 | if pos:
491 | logit_fcs.append(nn.Linear(dim_in, num_classes))
492 | feat_fcs.append(nn.Linear(dim_in, dim_out))
493 |
494 | self.logit_fc = nn.ModuleList(logit_fcs)
495 | self.l_fc = nn.Linear(dim_in, num_classes)
496 |
497 | if method not in ['ce', 'subnet_ce', 'kd']:
498 | self.feat_fc = nn.ModuleList(feat_fcs)
499 |             self.f_fc = nn.Linear(dim_in, dim_out)
500 |             # freeze the backbone projection head at its random initialization
501 |             for param in self.f_fc.parameters():
502 |                 param.requires_grad = False
502 |
503 | def forward(self, x):
504 | sub_feat, feat = self.encoder(x)
505 |
506 | feats, logits = [], []
507 |
508 | for idx, sh_feat in enumerate(sub_feat):
509 | logits.append(self.logit_fc[idx](sh_feat))
510 | if self.method not in ['ce', 'subnet_ce', 'kd']:
511 | out = self.feat_fc[idx](sh_feat)
512 | feats.append(F.normalize(out, dim=1))
513 |
514 | if self.method not in ['ce', 'subnet_ce', 'kd']:
515 | return [feats, F.normalize(self.f_fc(feat), dim=1)], [logits, self.l_fc(feat)]
516 | else:
517 | return [logits, self.l_fc(feat)]
518 |
519 |
520 | class LinearClassifier_EFF(nn.Module):
521 | """Linear classifier"""
522 | def __init__(self, name='efficientnet', num_classes=100):
523 | super(LinearClassifier_EFF, self).__init__()
524 | _, feat_dim = model_dict[name]
525 | self.fc = nn.Linear(feat_dim, num_classes)
526 |
527 | def forward(self, features):
528 | return self.fc(features)
--------------------------------------------------------------------------------
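Annotation (not part of the repository): a quick smoke test of the wrappers above, assuming 32x32 inputs (the static padding sizes are derived from image_size=32 in GlobalParams) and a single enabled sub-network position; the feature dimension 512 comes from model_dict['efficientnet'], and the import path follows the repository layout.

import torch
from networks.efficient_big import ConEfficientNet, CEEffNet

x = torch.randn(4, 3, 32, 32)

# contrastive wrapper: normalized sub-head features plus the backbone feature
model = ConEfficientNet(head='mlp', feat_dim=128, selfcon_pos=[True]).eval()
with torch.no_grad():
    sub_feats, feat = model(x)
print(len(sub_feats), sub_feats[0].shape, feat.shape)  # 1 torch.Size([4, 128]) torch.Size([4, 128])

# cross-entropy wrapper: returns [list_of_sub_logits, backbone_logits] for method='ce'
ce_model = CEEffNet(method='ce', num_classes=100, selfcon_pos=[True]).eval()
with torch.no_grad():
    logits = ce_model(x)
print(logits[0][0].shape, logits[-1].shape)            # torch.Size([4, 100]) torch.Size([4, 100])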