├── codebase
│   ├── __init__.py
│   ├── networks
│   │   ├── __init__.py
│   │   └── nsganetv2.py
│   ├── data_providers
│   │   ├── stl10.py
│   │   ├── flowers102.py
│   │   ├── imagenet.py
│   │   ├── pets.py
│   │   ├── dtd.py
│   │   ├── autoaugment.py
│   │   └── aircraft.py
│   └── run_manager
│       └── __init__.py
├── assets
│   ├── c10.gif
│   ├── c10.png
│   ├── c100.png
│   ├── dtd.png
│   ├── pets.png
│   ├── cinic10.png
│   ├── stl-10.png
│   ├── aircraft.png
│   ├── imagenet.gif
│   ├── imagenet.png
│   ├── overview.png
│   └── flowers102.png
├── scripts
│   ├── distributed_train.sh
│   └── search.sh
├── acc_predictor
│   ├── factory.py
│   ├── rbf.py
│   ├── gp.py
│   ├── adaptive_switching.py
│   ├── carts.py
│   └── mlp.py
├── .gitignore
├── search_space
│   └── ofa.py
├── validation.py
├── post_search.py
├── LICENSE
├── utils.py
├── evaluator.py
├── README.md
├── train_cifar.py
└── msunas.py
/codebase/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/c10.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/c10.gif
--------------------------------------------------------------------------------
/assets/c10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/c10.png
--------------------------------------------------------------------------------
/assets/c100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/c100.png
--------------------------------------------------------------------------------
/assets/dtd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/dtd.png
--------------------------------------------------------------------------------
/assets/pets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/pets.png
--------------------------------------------------------------------------------
/assets/cinic10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/cinic10.png
--------------------------------------------------------------------------------
/assets/stl-10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/stl-10.png
--------------------------------------------------------------------------------
/assets/aircraft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/aircraft.png
--------------------------------------------------------------------------------
/assets/imagenet.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/imagenet.gif
--------------------------------------------------------------------------------
/assets/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/imagenet.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/assets/flowers102.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/flowers102.png
--------------------------------------------------------------------------------
/scripts/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | python -m torch.distributed.launch --nproc_per_node=$NUM_PROC train_imagenet.py "$@"
--------------------------------------------------------------------------------
/codebase/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from ofa.imagenet_classification.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2
2 | from ofa.imagenet_classification.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large
3 | from codebase.networks.nsganetv2 import NSGANetV2
4 |
5 |
--------------------------------------------------------------------------------
/acc_predictor/factory.py:
--------------------------------------------------------------------------------
1 | def get_acc_predictor(model, inputs, targets):
2 |
3 | if model == 'rbf':
4 | from acc_predictor.rbf import RBF
5 | acc_predictor = RBF()
6 | acc_predictor.fit(inputs, targets)
7 |
8 | elif model == 'carts':
9 | from acc_predictor.carts import CART
10 | acc_predictor = CART(n_tree=5000)
11 | acc_predictor.fit(inputs, targets)
12 |
13 | elif model == 'gp':
14 | from acc_predictor.gp import GP
15 | acc_predictor = GP()
16 | acc_predictor.fit(inputs, targets)
17 |
18 | elif model == 'mlp':
19 | from acc_predictor.mlp import MLP
20 | acc_predictor = MLP(n_feature=inputs.shape[1])
21 | acc_predictor.fit(x=inputs, y=targets)
22 |
23 | elif model == 'as':
24 | from acc_predictor.adaptive_switching import AdaptiveSwitching
25 | acc_predictor = AdaptiveSwitching()
26 | acc_predictor.fit(inputs, targets)
27 |
28 | else:
29 | raise NotImplementedError
30 |
31 | return acc_predictor
32 |
33 |
--------------------------------------------------------------------------------
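A minimal usage sketch of the factory, assuming the surrogate dependencies it dispatches to (pySOT for 'rbf', pydacefit for 'gp', scikit-learn for 'carts') are installed; the encodings and accuracies below are random placeholders rather than real search data:

import numpy as np
from acc_predictor.factory import get_acc_predictor

# toy data: 50 architecture encodings (20-dim) with random "accuracies"
inputs = np.random.rand(50, 20)
targets = np.random.rand(50)

predictor = get_acc_predictor('rbf', inputs, targets)   # or 'gp', 'carts', 'mlp', 'as'
print(predictor.predict(inputs[:5]))                    # predicted accuracy for 5 encodings
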
/acc_predictor/rbf.py:
--------------------------------------------------------------------------------
1 | from pySOT.surrogate import RBFInterpolant, CubicKernel, TPSKernel, LinearTail, ConstantTail
2 |
3 |
4 | class RBF:
5 | """ Radial Basis Function """
6 |
7 | def __init__(self, kernel='cubic', tail='linear'):
8 | self.kernel = kernel
9 | self.tail = tail
10 | self.name = 'rbf'
11 | self.model = None
12 |
13 | def fit(self, train_data, train_label):
14 | if self.kernel == 'cubic':
15 | kernel = CubicKernel
16 | elif self.kernel == 'tps':
17 | kernel = TPSKernel
18 | else:
19 | raise NotImplementedError("unknown RBF kernel")
20 |
21 | if self.tail == 'linear':
22 | tail = LinearTail
23 | elif self.tail == 'constant':
24 | tail = ConstantTail
25 | else:
26 | raise NotImplementedError("unknown RBF tail")
27 |
28 | self.model = RBFInterpolant(
29 | dim=train_data.shape[1], kernel=kernel(), tail=tail(train_data.shape[1]))
30 |
31 | for i in range(len(train_data)):
32 | self.model.add_points(train_data[i, :], train_label[i])
33 |
34 | def predict(self, test_data):
35 | assert self.model is not None, "RBF model does not exist, call fit to obtain rbf model first"
36 | return self.model.predict(test_data)
37 |
--------------------------------------------------------------------------------
/acc_predictor/gp.py:
--------------------------------------------------------------------------------
1 | from pydacefit.regr import regr_constant
2 | from pydacefit.dace import DACE, regr_linear, regr_quadratic
3 | from pydacefit.corr import corr_gauss, corr_cubic, corr_exp, corr_expg, corr_spline, corr_spherical
4 |
5 |
6 | class GP:
7 | """ Gaussian Process (Kriging) """
8 | def __init__(self, regr='linear', corr='gauss'):
9 | self.regr = regr
10 | self.corr = corr
11 | self.name = 'gp'
12 | self.model = None
13 |
14 | def fit(self, train_data, train_label):
15 | if self.regr == 'linear':
16 | regr = regr_linear
17 | elif self.regr == 'constant':
18 | regr = regr_constant
19 | elif self.regr == 'quadratic':
20 | regr = regr_quadratic
21 | else:
22 | raise NotImplementedError("unknown GP regression")
23 |
24 | if self.corr == 'gauss':
25 | corr = corr_gauss
26 | elif self.corr == 'cubic':
27 | corr = corr_cubic
28 | elif self.corr == 'exp':
29 | corr = corr_exp
30 | elif self.corr == 'expg':
31 | corr = corr_expg
32 | elif self.corr == 'spline':
33 | corr = corr_spline
34 | elif self.corr == 'spherical':
35 | corr = corr_spherical
36 | else:
37 | raise NotImplementedError("unknown GP correlation")
38 |
39 | self.model = DACE(
40 | regr=regr, corr=corr, theta=1.0, thetaL=0.00001, thetaU=100)
41 | self.model.fit(train_data, train_label)
42 |
43 | def predict(self, test_data):
44 | assert self.model is not None, "GP does not exist, call fit to obtain GP first"
45 | return self.model.predict(test_data)
46 |
--------------------------------------------------------------------------------
/acc_predictor/adaptive_switching.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import numpy as np
3 | from acc_predictor.factory import get_acc_predictor
4 |
5 |
6 | class AdaptiveSwitching:
7 | """ ensemble surrogate model """
8 | """ try all available models, pick one based on 10-fold crx vld """
9 | def __init__(self, n_fold=10):
10 | # self.model_pool = ['rbf', 'gp', 'mlp', 'carts']
11 | self.model_pool = ['rbf', 'gp', 'carts']
12 | self.n_fold = n_fold
13 | self.name = 'adaptive switching'
14 | self.model = None
15 |
16 | def fit(self, train_data, train_target):
17 | self._n_fold_validation(train_data, train_target, n=self.n_fold)
18 |
19 | def _n_fold_validation(self, train_data, train_target, n=10):
20 |
21 | n_samples = len(train_data)
22 | perm = np.random.permutation(n_samples)
23 |
24 | kendall_tau = np.full((n, len(self.model_pool)), np.nan)
25 |
26 | for i, tst_split in enumerate(np.array_split(perm, n)):
27 | trn_split = np.setdiff1d(perm, tst_split, assume_unique=True)
28 |
29 | # loop over all considered surrogate model in pool
30 | for j, model in enumerate(self.model_pool):
31 |
32 | acc_predictor = get_acc_predictor(model, train_data[trn_split], train_target[trn_split])
33 |
34 | rmse, rho, tau = utils.get_correlation(
35 | acc_predictor.predict(train_data[tst_split]), train_target[tst_split])
36 |
37 | kendall_tau[i, j] = tau
38 |
39 | winner = int(np.argmax(np.mean(kendall_tau, axis=0) - np.std(kendall_tau, axis=0)))
40 | print("winner model = {}, tau = {}".format(self.model_pool[winner],
41 | np.mean(kendall_tau, axis=0)[winner]))
42 | self.winner = self.model_pool[winner]
43 | # re-fit the winner model with entire data
44 | acc_predictor = get_acc_predictor(self.model_pool[winner], train_data, train_target)
45 | self.model = acc_predictor
46 |
47 | def predict(self, test_data):
48 | return self.model.predict(test_data)
49 |
--------------------------------------------------------------------------------
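A sketch of the adaptive-switching predictor on the same kind of toy data, assuming the repo's utils.get_correlation and the surrogate dependencies above are available; note that the 'carts' candidate fits 5000 trees per fold, so even this toy run is not instantaneous:

import numpy as np
from acc_predictor.adaptive_switching import AdaptiveSwitching

x = np.random.rand(100, 20)   # toy architecture encodings
y = np.random.rand(100)       # toy accuracies

predictor = AdaptiveSwitching(n_fold=10)
predictor.fit(x, y)           # prints the winning model and its mean Kendall tau
print(predictor.predict(x[:3]))
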
/acc_predictor/carts.py:
--------------------------------------------------------------------------------
1 | # implementation based on
2 | # https://github.com/yn-sun/e2epp/blob/master/build_predict_model.py
3 | # and https://github.com/HandingWang/RF-CMOCO
4 | import numpy as np
5 | from sklearn.tree import DecisionTreeRegressor
6 |
7 |
8 | class CART:
9 | """ Classification and Regression Tree """
10 | def __init__(self, n_tree=1000):
11 | self.n_tree = n_tree
12 | self.name = 'carts'
13 | self.model = None
14 |
15 | @staticmethod
16 | def _make_decision_trees(train_data, train_label, n_tree):
17 | feature_record = []
18 | tree_record = []
19 |
20 | for i in range(n_tree):
21 | sample_idx = np.arange(train_data.shape[0])
22 | np.random.shuffle(sample_idx)
23 | train_data = train_data[sample_idx, :]
24 | train_label = train_label[sample_idx]
25 |
26 | feature_idx = np.arange(train_data.shape[1])
27 | np.random.shuffle(feature_idx)
28 | n_feature = np.random.randint(1, train_data.shape[1] + 1)
29 | selected_feature_ids = feature_idx[0:n_feature]
30 | feature_record.append(selected_feature_ids)
31 |
32 | dt = DecisionTreeRegressor()
33 | dt.fit(train_data[:, selected_feature_ids], train_label)
34 | tree_record.append(dt)
35 |
36 | return tree_record, feature_record
37 |
38 | def fit(self, train_data, train_label):
39 | self.model = self._make_decision_trees(train_data, train_label, self.n_tree)
40 |
41 | def predict(self, test_data):
42 | assert self.model is not None, "carts does not exist, call fit to obtain cart first"
43 |
44 | # redundant variable device
45 | trees, features = self.model[0], self.model[1]
46 | test_num, n_tree = len(test_data), len(trees)
47 |
48 | predict_labels = np.zeros((test_num, 1))
49 | for i in range(test_num):
50 | this_test_data = test_data[i, :]
51 | predict_this_list = np.zeros(n_tree)
52 |
53 | for j, (tree, feature) in enumerate(zip(trees, features)):
54 | predict_this_list[j] = tree.predict([this_test_data[feature]])[0]
55 |
56 | # find the top 100 prediction
57 | predict_this_list = np.sort(predict_this_list)
58 | predict_this_list = predict_this_list[::-1]
59 | this_predict = np.mean(predict_this_list)
60 | predict_labels[i, 0] = this_predict
61 |
62 | return predict_labels
63 |
64 |
--------------------------------------------------------------------------------
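A small sketch of the CART ensemble on toy data; model.model holds the fitted trees together with the random feature subset each tree was trained on:

import numpy as np
from acc_predictor.carts import CART

x = np.random.rand(50, 8)
y = np.random.rand(50)

model = CART(n_tree=10)       # a small ensemble, just for the demo
model.fit(x, y)

trees, features = model.model
print(len(trees), [len(f) for f in features])   # 10 trees, each with its own feature subset
print(model.predict(x[:3]).shape)               # (3, 1)
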
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .tmp/
6 | .idea
7 | ipynb/
8 | #subnets/
9 | # C extensions
10 | *.so
11 | .DS_Store
12 | subnets/
13 | data/
14 | search-*/
15 | sample_run/
16 | # Distribution / packaging
17 | .Python
18 | eccv20/
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | pip-wheel-metadata/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 | *.pdf
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 | *.pptx
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 | experiments/
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | *.o
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 | /pretrained/
141 | .idea
142 |
--------------------------------------------------------------------------------
/search_space/ofa.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class OFASearchSpace:
5 | def __init__(self):
6 | self.num_blocks = 5 # number of blocks
7 | self.kernel_size = [3, 5, 7] # depth-wise conv kernel size
8 | self.exp_ratio = [3, 4, 6] # expansion rate
9 | self.depth = [2, 3, 4] # number of Inverted Residual Bottleneck layers repetition
10 | self.resolution = list(range(192, 257, 4)) # input image resolutions
11 |
12 | def sample(self, n_samples=1, nb=None, ks=None, e=None, d=None, r=None):
13 | """ randomly sample a architecture"""
14 | nb = self.num_blocks if nb is None else nb
15 | ks = self.kernel_size if ks is None else ks
16 | e = self.exp_ratio if e is None else e
17 | d = self.depth if d is None else d
18 | r = self.resolution if r is None else r
19 |
20 | data = []
21 | for n in range(n_samples):
22 | # first sample layers
23 | depth = np.random.choice(d, nb, replace=True).tolist()
24 | # then sample kernel size, expansion rate and resolution
25 | kernel_size = np.random.choice(ks, size=int(np.sum(depth)), replace=True).tolist()
26 | exp_ratio = np.random.choice(e, size=int(np.sum(depth)), replace=True).tolist()
27 | resolution = int(np.random.choice(r))
28 |
29 | data.append({'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'r': resolution})
30 | return data
31 |
32 | def initialize(self, n_doe):
33 | # sample one arch with least (lb of hyperparameters) and most complexity (ub of hyperparameters)
34 | data = [
35 | self.sample(1, ks=[min(self.kernel_size)], e=[min(self.exp_ratio)],
36 | d=[min(self.depth)], r=[min(self.resolution)])[0],
37 | self.sample(1, ks=[max(self.kernel_size)], e=[max(self.exp_ratio)],
38 | d=[max(self.depth)], r=[max(self.resolution)])[0]
39 | ]
40 | data.extend(self.sample(n_samples=n_doe - 2))
41 | return data
42 |
43 | def pad_zero(self, x, depth):
44 | # pad zeros to make bit-string of equal length
45 | new_x, counter = [], 0
46 | for d in depth:
47 | for _ in range(d):
48 | new_x.append(x[counter])
49 | counter += 1
50 | if d < max(self.depth):
51 | new_x += [0] * (max(self.depth) - d)
52 | return new_x
53 |
54 | def encode(self, config):
55 | # encode config ({'ks': , 'd': , etc}) to integer bit-string [1, 0, 2, 1, ...]
56 | x = []
57 | depth = [np.argwhere(_x == np.array(self.depth))[0, 0] for _x in config['d']]
58 | kernel_size = [np.argwhere(_x == np.array(self.kernel_size))[0, 0] for _x in config['ks']]
59 | exp_ratio = [np.argwhere(_x == np.array(self.exp_ratio))[0, 0] for _x in config['e']]
60 |
61 | kernel_size = self.pad_zero(kernel_size, config['d'])
62 | exp_ratio = self.pad_zero(exp_ratio, config['d'])
63 |
64 | for i in range(len(depth)):
65 | x = x + [depth[i]] + kernel_size[i * max(self.depth):i * max(self.depth) + max(self.depth)] \
66 | + exp_ratio[i * max(self.depth):i * max(self.depth) + max(self.depth)]
67 | x.append(np.argwhere(config['r'] == np.array(self.resolution))[0, 0])
68 |
69 | return x
70 |
71 | def decode(self, x):
72 | """
73 | remove un-expressed part of the chromosome
74 | assumes x = [block1, block2, ..., block5, resolution, width_mult];
75 | block_i = [depth, kernel_size, exp_rate]
76 | """
77 | depth, kernel_size, exp_rate = [], [], []
78 | for i in range(0, len(x) - 2, 9):
79 | depth.append(self.depth[x[i]])
80 | kernel_size.extend(np.array(self.kernel_size)[x[i + 1:i + 1 + self.depth[x[i]]]].tolist())
81 | exp_rate.extend(np.array(self.exp_ratio)[x[i + 5:i + 5 + self.depth[x[i]]]].tolist())
82 | return {'ks': kernel_size, 'e': exp_rate, 'd': depth, 'r': self.resolution[x[-1]]}
83 |
84 |
--------------------------------------------------------------------------------
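A quick sketch of the encode/decode round trip over this search space; the sampled architecture is random, and the 46-entry length follows from 5 blocks x (1 depth + 4 kernel + 4 expansion slots) plus one resolution index:

from search_space.ofa import OFASearchSpace

space = OFASearchSpace()
arch = space.sample(n_samples=1)[0]
print(arch)                      # e.g. {'ks': [...], 'e': [...], 'd': [...], 'r': 224}

x = space.encode(arch)           # fixed-length integer string
print(len(x))                    # 46
print(space.decode(x) == arch)   # True: decoding recovers the sampled architecture
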
/scripts/search.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Search Examples
4 | ## In general, set `--n_iter` to `--n_gpus`, which is the # of available gpus you have
5 |
6 | # Maximize Top-1 Accuracy and Minimize #FLOPs on ImageNet
7 | #python msunas.py --sec_obj flops \
8 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \
9 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \
10 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
11 | # --save search-imagenet-flops-w1.0 --iterations 30 --vld_size 10000
12 |
13 | # Maximize Top-1 Accuracy and Minimize #Params on ImageNet
14 | #python msunas.py --sec_obj params \
15 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \
16 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \
17 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
18 | # --save search-imagenet-params-w1.0 --iterations 30 --vld_size 10000
19 |
20 | # Maximize Top-1 Accuracy and Minimize CPU Latency on ImageNet
21 | #python msunas.py --sec_obj cpu \
22 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \
23 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \
24 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
25 | # --save search-imagenet-cpu-w1.0 --iterations 30 --vld_size 10000
26 |
27 | # Maximize Top-1 Accuracy and Minimize #FLOPs on CIFAR-10
28 | #python msunas.py --sec_obj flops \
29 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \
30 | # --dataset cifar10 --n_classes 10 \
31 | # --data /usr/local/soft/temp-datastore/CIFAR/ \
32 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
33 | # --save search-cifar10-flops-w1.0 --iterations 30 --vld_size 5000
34 |
35 | # Maximize Top-1 Accuracy and Minimize #FLOPs on CIFAR-100
36 | #python msunas.py --sec_obj flops \
37 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \
38 | # --dataset cifar100 --n_classes 100 \
39 | # --data /usr/local/soft/temp-datastore/CIFAR/ \
40 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
41 | # --save search-cifar100-flops-w1.0 --iterations 30 --vld_size 5000
42 |
43 | # Maximize Top-1 Accuracy and Minimize #FLOPs on CINIC10
44 | #python msunas.py --sec_obj flops \
45 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \
46 | # --dataset cinic10 --n_classes 10 \
47 | # --data /usr/local/soft/temp-datastore/CINIC/ \
48 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
49 | # --save search-cinic10-flops-w1.0 --iterations 30 --vld_size 10000
50 |
51 | # Maximize Top-1 Accuracy and Minimize #FLOPs on STL-10
52 | #python msunas.py --sec_obj flops \
53 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \
54 | # --dataset stl10 --n_classes 10 \
55 | # --data /usr/local/soft/temp-datastore/STL10/ \
56 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
57 | # --save search-stl10-flops-w1.0 --iterations 30 --vld_size 500
58 |
59 | # Maximize Top-1 Accuracy and Minimize #FLOPs on Aircraft
60 | #python msunas.py --sec_obj flops \
61 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \
62 | # --dataset aircraft --n_classes 100 \
63 | # --data /usr/local/soft/temp-datastore/Aircraft/ \
64 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
65 | # --save search-aircraft-flops-w1.0 --iterations 30 --vld_size 500
66 |
67 | # Random Search on ImageNet (set `--n_doe` to the total # of archs you want to sample)
68 | #python msunas.py --n_doe 350 --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \
69 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \
70 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
71 | # --save random-imagenet-w1.0 --iterations 0 --vld_size 10000
72 |
73 | # Random Search on CIFAR-10 (set `--n_doe` to the total # of archs you want to sample)
74 | #python msunas.py --n_doe 350 --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \
75 | # --dataset cifar10 --n_classes 10 \
76 | # --data /usr/local/soft/temp-datastore/CIFAR/ \
77 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \
78 | # --save random-cifar10-w1.0 --iterations 0 --vld_size 5000
79 |
80 |
--------------------------------------------------------------------------------
/validation.py:
--------------------------------------------------------------------------------
1 | import time
2 | import json
3 | import torch
4 | import logging
5 | import argparse
6 | from collections import OrderedDict
7 |
8 | from timm.utils import accuracy, AverageMeter, setup_default_logging
9 |
10 | from codebase.run_manager import get_run_config
11 | from codebase.networks.nsganetv2 import NSGANetV2
12 |
13 |
14 | def validate(model, loader, criterion, log_freq=50):
15 | batch_time = AverageMeter()
16 | losses = AverageMeter()
17 | top1 = AverageMeter()
18 | top5 = AverageMeter()
19 |
20 | model.eval()
21 | end = time.time()
22 | with torch.no_grad():
23 | for i, (input, target) in enumerate(loader):
24 | target = target.cuda()
25 | input = input.cuda()
26 |
27 | # compute output
28 | output = model(input)
29 | loss = criterion(output, target)
30 |
31 | # measure accuracy and record loss
32 | acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
33 | losses.update(loss.item(), input.size(0))
34 | top1.update(acc1.item(), input.size(0))
35 | top5.update(acc5.item(), input.size(0))
36 |
37 | # measure elapsed time
38 | batch_time.update(time.time() - end)
39 | end = time.time()
40 |
41 | if i % log_freq == 0:
42 | logging.info(
43 | 'Test: [{0:>4d}/{1}] '
44 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
45 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
46 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
47 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
48 | i, len(loader), batch_time=batch_time,
49 | rate_avg=input.size(0) / batch_time.avg,
50 | loss=losses, top1=top1, top5=top5))
51 |
52 | results = OrderedDict(
53 | top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
54 | top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4))
55 |
56 | logging.info(' * Acc@1 {:.1f} ({:.3f}) Acc@5 {:.1f} ({:.3f})'.format(
57 | results['top1'], results['top1_err'], results['top5'], results['top5_err']))
58 |
59 |
60 | def main(args):
61 | setup_default_logging()
62 |
63 | logging.info('Running validation on {}'.format(args.dataset))
64 |
65 | net_config = json.load(open(args.model))
66 | if 'img_size' in net_config:
67 | img_size = net_config['img_size']
68 | else:
69 | img_size = args.img_size
70 |
71 | run_config = get_run_config(
72 | dataset=args.dataset, data_path=args.data, image_size=img_size, n_epochs=0,
73 | train_batch_size=args.batch_size, test_batch_size=args.batch_size,
74 | n_worker=args.workers, valid_size=None)
75 |
76 | model = NSGANetV2.build_from_config(net_config)
77 | try:
78 | model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
79 | except:
80 | model.load_state_dict(torch.load(args.pretrained, map_location='cpu')['state_dict'])
81 |
82 | param_count = sum([m.numel() for m in model.parameters()])
83 | logging.info('Model created, param count: %d' % param_count)
84 |
85 | model = model.cuda()
86 | criterion = torch.nn.CrossEntropyLoss().cuda()
87 |
88 | validate(model, run_config.test_loader, criterion)
89 |
90 | return
91 |
92 |
93 | if __name__ == '__main__':
94 | parser = argparse.ArgumentParser()
95 | # data related settings
96 | parser.add_argument('--data', type=str, default='/mnt/datastore/ILSVRC2012',
97 | help='location of the data corpus')
98 | parser.add_argument('--dataset', type=str, default='imagenet',
99 | help='name of the dataset (imagenet, cifar10, cifar100, ...)')
100 | parser.add_argument('-j', '--workers', type=int, default=6,
101 | help='number of workers for data loading')
102 | parser.add_argument('-b', '--batch-size', type=int, default=256,
103 | help='test batch size for inference')
104 | parser.add_argument('--img-size', type=int, default=224,
105 | help='input resolution (192 -> 256)')
106 | # model related settings
107 | parser.add_argument('--model', '-m', metavar='MODEL', default='', type=str,
108 | help='model configuration file')
109 | parser.add_argument('--pretrained', type=str, default='',
110 | help='path to pretrained weights')
111 | cfgs = parser.parse_args()
112 |
113 | main(cfgs)
114 |
--------------------------------------------------------------------------------
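Since validate() only needs a model, a loader, and a criterion, it can be exercised without ImageNet; a toy sketch (assumes a CUDA device and the timm/ofa dependencies) using random tensors and a throwaway linear model:

import torch
from torch.utils.data import DataLoader, TensorDataset
from timm.utils import setup_default_logging
from validation import validate

setup_default_logging()

# 64 random 32x32 "images" over 10 classes, and a tiny linear classifier
images, labels = torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, labels), batch_size=16)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10)).cuda()
criterion = torch.nn.CrossEntropyLoss().cuda()

validate(model, loader, criterion, log_freq=1)
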
/post_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from pymoo.factory import get_decomposition
6 | from pymoo.visualization.scatter import Scatter
7 | from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
8 | from pymoo.model.decision_making import DecisionMaking, normalize, find_outliers_upper_tail, NeighborFinder
9 |
10 | _DEBUG = False
11 |
12 |
13 | class HighTradeoffPoints(DecisionMaking):
14 |
15 | def __init__(self, epsilon=0.125, n_survive=None, **kwargs) -> None:
16 | super().__init__(**kwargs)
17 | self.epsilon = epsilon
18 | self.n_survive = n_survive # number of points to be selected
19 |
20 | def _do(self, F, **kwargs):
21 | n, m = F.shape
22 |
23 | if self.normalize:
24 | F = normalize(F, self.ideal_point, self.nadir_point, estimate_bounds_if_none=True)
25 |
26 | neighbors_finder = NeighborFinder(F, epsilon=0.125, n_min_neigbors="auto", consider_2d=False)
27 |
28 | mu = np.full(n, - np.inf)
29 |
30 | # for each solution in the set calculate the least amount of improvement per unit deterioration
31 | for i in range(n):
32 |
33 | # for each neighbour in a specific radius of that solution
34 | neighbors = neighbors_finder.find(i)
35 |
36 | # calculate the trade-off to all neighbours
37 | diff = F[neighbors] - F[i]
38 |
39 | # calculate sacrifice and gain
40 | sacrifice = np.maximum(0, diff).sum(axis=1)
41 | gain = np.maximum(0, -diff).sum(axis=1)
42 |
43 | np.warnings.filterwarnings('ignore')
44 | tradeoff = sacrifice / gain
45 |
46 | # otherwise find the one with the smalled one
47 | mu[i] = np.nanmin(tradeoff)
48 | if self.n_survive is not None:
49 | return np.argsort(mu)[-self.n_survive:]
50 | else:
51 | return find_outliers_upper_tail(mu) # return points with trade-off > 2*sigma
52 |
53 |
54 | def main(args):
55 | # preferences
56 | if args.prefer is not None:
57 | preferences = {}
58 | for p in args.prefer.split("+"):
59 | k, v = p.split("#")
60 | if k == 'top1':
61 | preferences[k] = 100 - float(v) # assuming top-1 accuracy
62 | else:
63 | preferences[k] = float(v)
64 | weights = np.fromiter(preferences.values(), dtype=float)
65 |
66 | archive = json.load(open(args.expr))['archive']
67 | subnets, top1, sec_obj = [v[0] for v in archive], [v[1] for v in archive], [v[2] for v in archive]
68 | sort_idx = np.argsort(top1)
69 | F = np.column_stack((top1, sec_obj))[sort_idx, :]
70 | front = NonDominatedSorting().do(F, only_non_dominated_front=True)
71 | pf = F[front, :]
72 | ps = np.array(subnets)[sort_idx][front]
73 |
74 | if args.prefer is not None:
75 | # choose the architectures thats closest to the preferences
76 | I = get_decomposition("asf").do(pf, weights).argsort()[:args.n]
77 | else:
78 | # choose the architectures with highest trade-off
79 | dm = HighTradeoffPoints(n_survive=args.n)
80 | I = dm.do(pf)
81 |
82 | # always add most accurate architectures
83 | I = np.append(I, 0)
84 |
85 | # create the supernet
86 | from evaluator import OFAEvaluator
87 | supernet = OFAEvaluator(model_path=args.supernet_path)
88 |
89 | for idx in I:
90 | save = os.path.join(args.save, "net-flops@{:.0f}".format(pf[idx, 1]))
91 | os.makedirs(save, exist_ok=True)
92 | subnet, _ = supernet.sample({'ks': ps[idx]['ks'], 'e': ps[idx]['e'], 'd': ps[idx]['d']})
93 | with open(os.path.join(save, "net.subnet"), 'w') as handle:
94 | json.dump(ps[idx], handle)
95 | supernet.save_net_config(save, subnet, "net.config")
96 | supernet.save_net(save, subnet, "net.inherited")
97 |
98 | if _DEBUG:
99 | print(ps[I])
100 | plot = Scatter()
101 | plot.add(pf, alpha=0.2)
102 | plot.add(pf[I, :], color="red", s=100)
103 | plot.show()
104 |
105 | return
106 |
107 |
108 | if __name__ == '__main__':
109 | parser = argparse.ArgumentParser()
110 | parser.add_argument('--save', type=str, default='.tmp',
111 | help='location of dir to save')
112 | parser.add_argument('--expr', type=str, default='',
113 | help='location of search experiment dir')
114 | parser.add_argument('--prefer', type=str, default=None,
115 | help='preferences in choosing architectures (top1#80+flops#150)')
116 | parser.add_argument('-n', type=int, default=1,
117 | help='number of architectures desired')
118 | parser.add_argument('--supernet_path', type=str, default='./data/ofa_mbv3_d234_e346_k357_w1.0',
119 | help='file path to supernet weights')
120 |
121 | cfgs = parser.parse_args()
122 | main(cfgs)
123 |
--------------------------------------------------------------------------------
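For reference, a small sketch of how the --prefer string is interpreted (the top1 value is flipped to an error so that both objectives are minimized); the preference values are made up:

prefer = "top1#80+flops#150"           # hypothetical preference: ~80% top-1, ~150M FLOPs

preferences = {}
for p in prefer.split("+"):
    k, v = p.split("#")
    preferences[k] = 100 - float(v) if k == 'top1' else float(v)

print(preferences)                     # {'top1': 20.0, 'flops': 150.0}
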
/acc_predictor/mlp.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | from utils import get_correlation
6 |
7 |
8 | class Net(nn.Module):
9 | # N-layer MLP
10 | def __init__(self, n_feature, n_layers=2, n_hidden=300, n_output=1, drop=0.2):
11 | super(Net, self).__init__()
12 |
13 | self.stem = nn.Sequential(nn.Linear(n_feature, n_hidden), nn.ReLU())
14 |
15 | hidden_layers = []
16 | for _ in range(n_layers):
17 | hidden_layers.append(nn.Linear(n_hidden, n_hidden))
18 | hidden_layers.append(nn.ReLU())
19 | self.hidden = nn.Sequential(*hidden_layers)
20 |
21 | self.regressor = nn.Linear(n_hidden, n_output) # output layer
22 | self.drop = nn.Dropout(p=drop)
23 |
24 | def forward(self, x):
25 | x = self.stem(x)
26 | x = self.hidden(x)
27 | x = self.drop(x)
28 | x = self.regressor(x) # linear output
29 | return x
30 |
31 | @staticmethod
32 | def init_weights(m):
33 | if type(m) == nn.Linear:
34 | n = m.in_features
35 | y = 1.0 / np.sqrt(n)
36 | m.weight.data.uniform_(-y, y)
37 | m.bias.data.fill_(0)
38 |
39 |
40 | class MLP:
41 | """ Multi Layer Perceptron """
42 | def __init__(self, **kwargs):
43 | self.model = Net(**kwargs)
44 | self.name = 'mlp'
45 |
46 | def fit(self, **kwargs):
47 | self.model = train(self.model, **kwargs)
48 |
49 | def predict(self, test_data, device='cpu'):
50 | return predict(self.model, test_data, device=device)
51 |
52 |
53 | def train(net, x, y, trn_split=0.8, pretrained=None, device='cpu',
54 | lr=8e-4, epochs=2000, verbose=False):
55 |
56 | n_samples = x.shape[0]
57 | target = torch.zeros(n_samples, 1)
58 | perm = torch.randperm(target.size(0))
59 | trn_idx = perm[:int(n_samples * trn_split)]
60 | vld_idx = perm[int(n_samples * trn_split):]
61 |
62 | inputs = torch.from_numpy(x).float()
63 | target[:, 0] = torch.from_numpy(y).float()
64 |
65 | # back-propagation training of a NN
66 | if pretrained is not None:
67 | print("Constructing MLP surrogate model with pre-trained weights")
68 | init = torch.load(pretrained, map_location='cpu')
69 | net.load_state_dict(init)
70 | best_net = copy.deepcopy(net)
71 | else:
72 | # print("Constructing MLP surrogate model with "
73 | # "sample size = {}, epochs = {}".format(x.shape[0], epochs))
74 |
75 | # initialize the weights
76 | # net.apply(Net.init_weights)
77 | net = net.to(device)
78 | optimizer = torch.optim.Adam(net.parameters(), lr=lr)
79 | criterion = nn.SmoothL1Loss()
80 | # criterion = nn.MSELoss()
81 |
82 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs), eta_min=0)
83 |
84 | best_loss = 1e33
85 | for epoch in range(epochs):
86 | trn_inputs = inputs[trn_idx]
87 | trn_labels = target[trn_idx]
88 | loss_trn = train_one_epoch(net, trn_inputs, trn_labels, criterion, optimizer, device)
89 | loss_vld = infer(net, inputs[vld_idx], target[vld_idx], criterion, device)
90 | scheduler.step()
91 |
92 | # if epoch % 500 == 0 and verbose:
93 | # print("Epoch {:4d}: trn loss = {:.4E}, vld loss = {:.4E}".format(epoch, loss_trn, loss_vld))
94 |
95 | if loss_vld < best_loss:
96 | best_loss = loss_vld
97 | best_net = copy.deepcopy(net)
98 |
99 | validate(best_net, inputs, target, device=device)
100 |
101 | return best_net.to('cpu')
102 |
103 |
104 | def train_one_epoch(net, data, target, criterion, optimizer, device):
105 | net.train()
106 | optimizer.zero_grad()
107 |
108 | data, target = data.to(device), target.to(device)
109 | pred = net(data)
110 | loss = criterion(pred, target)
111 | loss.backward()
112 | optimizer.step()
113 |
114 | return loss.item()
115 |
116 |
117 | def infer(net, data, target, criterion, device):
118 | net.eval()
119 |
120 | with torch.no_grad():
121 | data, target = data.to(device), target.to(device)
122 | pred = net(data)
123 | loss = criterion(pred, target)
124 |
125 | return loss.item()
126 |
127 |
128 | def validate(net, data, target, device):
129 | net.eval()
130 |
131 | with torch.no_grad():
132 | data, target = data.to(device), target.to(device)
133 | pred = net(data)
134 | pred, target = pred.cpu().detach().numpy(), target.cpu().detach().numpy()
135 |
136 | rmse, rho, tau = get_correlation(pred, target)
137 |
138 | # print("Validation RMSE = {:.4f}, Spearman's Rho = {:.4f}, Kendall’s Tau = {:.4f}".format(rmse, rho, tau))
139 | return rmse, rho, tau, pred, target
140 |
141 |
142 | def predict(net, query, device):
143 |
144 | if query.ndim < 2:
145 | data = torch.zeros(1, query.shape[0])
146 | data[0, :] = torch.from_numpy(query).float()
147 | else:
148 | data = torch.from_numpy(query).float()
149 |
150 | net = net.to(device)
151 | net.eval()
152 | with torch.no_grad():
153 | data = data.to(device)
154 | pred = net(data)
155 |
156 | return pred.cpu().detach().numpy()
--------------------------------------------------------------------------------
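A minimal sketch of the MLP surrogate on random data, following the non-pretrained path of train() (assumes the repo's utils.get_correlation is importable); fit() forwards its keyword arguments to train(), so the epoch budget can be shrunk for a quick check:

import numpy as np
from acc_predictor.mlp import MLP

x = np.random.rand(64, 20)            # toy architecture encodings
y = np.random.rand(64)                # toy accuracies

predictor = MLP(n_feature=x.shape[1])
predictor.fit(x=x, y=y, epochs=100)   # much shorter than the default 2000 epochs
print(predictor.predict(x[:5]).shape) # (5, 1)
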
/codebase/networks/nsganetv2.py:
--------------------------------------------------------------------------------
1 | from timm.models.layers import drop_path
2 | from ofa.utils.layers import *
3 | from ofa.utils import MyModule
4 | from ofa.imagenet_classification.networks import MobileNetV3
5 |
6 |
7 | class MobileInvertedResidualBlock(MyModule):
8 | """
9 | Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
10 | imagenet_codebase/networks/proxyless_nets.py to include drop path in training
11 |
12 | """
13 | def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0):
14 | super(MobileInvertedResidualBlock, self).__init__()
15 |
16 | self.mobile_inverted_conv = mobile_inverted_conv
17 | self.shortcut = shortcut
18 | self.drop_connect_rate = drop_connect_rate
19 |
20 | def forward(self, x):
21 | if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
22 | res = x
23 | elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
24 | res = self.mobile_inverted_conv(x)
25 | else:
26 | # res = self.mobile_inverted_conv(x) + self.shortcut(x)
27 | res = self.mobile_inverted_conv(x)
28 |
29 | if self.drop_connect_rate > 0.:
30 | res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training)
31 |
32 | res += self.shortcut(x)
33 |
34 | return res
35 |
36 | @property
37 | def module_str(self):
38 | return '(%s, %s)' % (
39 | self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
40 | self.shortcut.module_str if self.shortcut is not None else None
41 | )
42 |
43 | @property
44 | def config(self):
45 | return {
46 | 'name': MobileInvertedResidualBlock.__name__,
47 | 'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
48 | 'shortcut': self.shortcut.config if self.shortcut is not None else None,
49 | }
50 |
51 | @staticmethod
52 | def build_from_config(config):
53 | mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
54 | shortcut = set_layer_from_config(config['shortcut'])
55 | return MobileInvertedResidualBlock(
56 | mobile_inverted_conv, shortcut, drop_connect_rate=config['drop_connect_rate'])
57 |
58 |
59 | class NSGANetV2(MobileNetV3):
60 | """
61 | Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
62 | imagenet_codebase/networks/mobilenet_v3.py to include drop path in training
63 | and option to reset classification layer
64 | """
65 | @staticmethod
66 | def build_from_config(config, drop_connect_rate=0.0):
67 | first_conv = set_layer_from_config(config['first_conv'])
68 | final_expand_layer = set_layer_from_config(config['final_expand_layer'])
69 | feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
70 | classifier = set_layer_from_config(config['classifier'])
71 |
72 | blocks = []
73 | for block_idx, block_config in enumerate(config['blocks']):
74 | block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks'])
75 | blocks.append(MobileInvertedResidualBlock.build_from_config(block_config))
76 |
77 | net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
78 | if 'bn' in config:
79 | net.set_bn_param(**config['bn'])
80 | else:
81 | net.set_bn_param(momentum=0.1, eps=1e-3)
82 |
83 | return net
84 |
85 | def zero_last_gamma(self):
86 | for m in self.modules():
87 | if isinstance(m, MobileInvertedResidualBlock):
88 | if isinstance(m.mobile_inverted_conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
89 | m.mobile_inverted_conv.point_linear.bn.weight.data.zero_()
90 |
91 | @staticmethod
92 | def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
93 | # first conv layer
94 | first_conv = ConvLayer(
95 | 3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
96 | )
97 | # build mobile blocks
98 | feature_dim = input_channel
99 | blocks = []
100 | for stage_id, block_config_list in cfg.items():
101 | for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
102 | mb_conv = MBConvLayer(
103 | feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
104 | )
105 | if stride == 1 and out_channel == feature_dim:
106 | shortcut = IdentityLayer(out_channel, out_channel)
107 | else:
108 | shortcut = None
109 | blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut))
110 | feature_dim = out_channel
111 | # final expand layer
112 | final_expand_layer = ConvLayer(
113 | feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
114 | )
115 | feature_dim = feature_dim * 6
116 | # feature mix layer
117 | feature_mix_layer = ConvLayer(
118 | feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
119 | )
120 | # classifier
121 | classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
122 |
123 | return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
124 |
125 | @staticmethod
126 | def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0):
127 | model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
--------------------------------------------------------------------------------
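A small sketch of the modified residual block in isolation, using the OFA layer classes this file already relies on (MBConvLayer, IdentityLayer); the drop_connect_rate value is arbitrary, and with an identity shortcut the drop path is applied to the conv branch before the residual addition:

import torch
from ofa.utils.layers import MBConvLayer, IdentityLayer
from codebase.networks.nsganetv2 import MobileInvertedResidualBlock

conv = MBConvLayer(in_channels=16, out_channels=16, kernel_size=3, stride=1, expand_ratio=3)
block = MobileInvertedResidualBlock(conv, IdentityLayer(16, 16), drop_connect_rate=0.1)

block.train()                          # drop path is only active in training mode
x = torch.randn(2, 16, 32, 32)
print(block(x).shape)                  # torch.Size([2, 16, 32, 32])
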
/codebase/data_providers/stl10.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 |
5 | import torchvision
6 | import torch.utils.data
7 | import torchvision.transforms as transforms
8 |
9 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
10 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider
11 |
12 |
13 | class STL10DataProvider(DataProvider):
14 |
15 | def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
16 | n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):
17 |
18 | self._save_path = save_path
19 |
20 | self.image_size = image_size # int or list of int
21 | self.distort_color = distort_color
22 | self.resize_scale = resize_scale
23 |
24 | self._valid_transform_dict = {}
25 | if not isinstance(self.image_size, int):
26 | assert isinstance(self.image_size, list)
27 | from ofa.utils.my_dataloader import MyDataLoader
28 | self.image_size.sort() # e.g., 160 -> 224
29 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
30 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
31 |
32 | for img_size in self.image_size:
33 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
34 | self.active_img_size = max(self.image_size)
35 | valid_transforms = self._valid_transform_dict[self.active_img_size]
36 |             train_loader_class = MyDataLoader # randomly sample the image size for each batch of training images
37 | else:
38 | self.active_img_size = self.image_size
39 | valid_transforms = self.build_valid_transform()
40 | train_loader_class = torch.utils.data.DataLoader
41 |
42 | train_transforms = self.build_train_transform()
43 | train_dataset = self.train_dataset(train_transforms)
44 |
45 | if valid_size is not None:
46 | if not isinstance(valid_size, int):
47 | assert isinstance(valid_size, float) and 0 < valid_size < 1
48 | valid_size = int(len(train_dataset.data) * valid_size)
49 |
50 | valid_dataset = self.train_dataset(valid_transforms)
51 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)
52 |
53 | if num_replicas is not None:
54 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
55 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
56 | else:
57 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
58 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
59 |
60 | self.train = train_loader_class(
61 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
62 | num_workers=n_worker, pin_memory=True,
63 | )
64 | self.valid = torch.utils.data.DataLoader(
65 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
66 | num_workers=n_worker, pin_memory=True,
67 | )
68 | else:
69 | if num_replicas is not None:
70 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
71 | self.train = train_loader_class(
72 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
73 | num_workers=n_worker, pin_memory=True
74 | )
75 | else:
76 | self.train = train_loader_class(
77 | train_dataset, batch_size=train_batch_size, shuffle=True,
78 | num_workers=n_worker, pin_memory=True,
79 | )
80 | self.valid = None
81 |
82 | test_dataset = self.test_dataset(valid_transforms)
83 | if num_replicas is not None:
84 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
85 | self.test = torch.utils.data.DataLoader(
86 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
87 | )
88 | else:
89 | self.test = torch.utils.data.DataLoader(
90 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
91 | )
92 |
93 | if self.valid is None:
94 | self.valid = self.test
95 |
96 | @staticmethod
97 | def name():
98 | return 'stl10'
99 |
100 | @property
101 | def data_shape(self):
102 | return 3, self.active_img_size, self.active_img_size # C, H, W
103 |
104 | @property
105 | def n_classes(self):
106 | return 10
107 |
108 | @property
109 | def save_path(self):
110 | if self._save_path is None:
111 | self._save_path = '/mnt/datastore/STL10' # home server
112 |
113 | if not os.path.exists(self._save_path):
114 | self._save_path = '/mnt/datastore/STL10' # home server
115 | return self._save_path
116 |
117 | @property
118 | def data_url(self):
119 | raise ValueError('unable to download %s' % self.name())
120 |
121 | def train_dataset(self, _transforms):
122 | # dataset = datasets.ImageFolder(self.train_path, _transforms)
123 | dataset = torchvision.datasets.STL10(
124 | root=self.valid_path, split='train', download=False, transform=_transforms)
125 | return dataset
126 |
127 | def test_dataset(self, _transforms):
128 | # dataset = datasets.ImageFolder(self.valid_path, _transforms)
129 | dataset = torchvision.datasets.STL10(
130 | root=self.valid_path, split='test', download=False, transform=_transforms)
131 | return dataset
132 |
133 | @property
134 | def train_path(self):
135 | # return os.path.join(self.save_path, 'train')
136 | return self.save_path
137 |
138 | @property
139 | def valid_path(self):
140 | # return os.path.join(self.save_path, 'val')
141 | return self.save_path
142 |
143 | @property
144 | def normalize(self):
145 | return transforms.Normalize(
146 | mean=[0.44671097, 0.4398105, 0.4066468],
147 | std=[0.2603405, 0.25657743, 0.27126738])
148 |
149 | def build_train_transform(self, image_size=None, print_log=True):
150 | if image_size is None:
151 | image_size = self.image_size
152 | if print_log:
153 | print('Color jitter: %s, resize_scale: %s, img_size: %s' %
154 | (self.distort_color, self.resize_scale, image_size))
155 |
156 | if self.distort_color == 'torch':
157 | color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
158 | elif self.distort_color == 'tf':
159 | color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
160 | else:
161 | color_transform = None
162 |
163 | if isinstance(image_size, list):
164 | resize_transform_class = MyRandomResizedCrop
165 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
166 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
167 | else:
168 | resize_transform_class = transforms.RandomResizedCrop
169 |
170 | train_transforms = [
171 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
172 | transforms.RandomHorizontalFlip(),
173 | ]
174 | if color_transform is not None:
175 | train_transforms.append(color_transform)
176 | train_transforms += [
177 | transforms.ToTensor(),
178 | self.normalize,
179 | ]
180 |
181 | train_transforms = transforms.Compose(train_transforms)
182 | return train_transforms
183 |
184 | def build_valid_transform(self, image_size=None):
185 | if image_size is None:
186 | image_size = self.active_img_size
187 | return transforms.Compose([
188 | transforms.Resize(int(math.ceil(image_size / 0.875))),
189 | transforms.CenterCrop(image_size),
190 | transforms.ToTensor(),
191 | self.normalize,
192 | ])
193 |
194 | def assign_active_img_size(self, new_img_size):
195 | self.active_img_size = new_img_size
196 | if self.active_img_size not in self._valid_transform_dict:
197 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
198 | # change the transform of the valid and test set
199 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
200 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
201 |
202 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
203 | # used for resetting running statistics
204 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
205 | if num_worker is None:
206 | num_worker = self.train.num_workers
207 |
208 | n_samples = len(self.train.dataset.data)
209 | g = torch.Generator()
210 | g.manual_seed(DataProvider.SUB_SEED)
211 | rand_indexes = torch.randperm(n_samples, generator=g).tolist()
212 |
213 | new_train_dataset = self.train_dataset(
214 | self.build_train_transform(image_size=self.active_img_size, print_log=False))
215 | chosen_indexes = rand_indexes[:n_images]
216 | if num_replicas is not None:
217 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
218 | else:
219 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
220 | sub_data_loader = torch.utils.data.DataLoader(
221 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
222 | num_workers=num_worker, pin_memory=True,
223 | )
224 | self.__dict__['sub_train_%d' % self.active_img_size] = []
225 | for images, labels in sub_data_loader:
226 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
227 | return self.__dict__['sub_train_%d' % self.active_img_size]
228 |
--------------------------------------------------------------------------------
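A usage sketch of the provider, assuming the STL-10 binaries have already been placed under a local directory (the provider calls torchvision with download=False, so the data must exist beforehand); the path below is hypothetical:

from codebase.data_providers.stl10 import STL10DataProvider

# hypothetical local path containing the extracted stl10_binary folder
provider = STL10DataProvider(save_path='./data/STL10', train_batch_size=96,
                             test_batch_size=256, valid_size=500, image_size=224)

images, labels = next(iter(provider.train))
print(images.shape, provider.n_classes)   # torch.Size([96, 3, 224, 224]) 10
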
/codebase/data_providers/flowers102.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import os
3 | import math
4 | import numpy as np
5 |
6 | import PIL
7 |
8 | import torch.utils.data
9 | import torchvision.transforms as transforms
10 | import torchvision.datasets as datasets
11 |
12 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
13 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider
14 |
15 |
16 | class Flowers102DataProvider(DataProvider):
17 |
18 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32,
19 | resize_scale=0.08, distort_color=None, image_size=224,
20 | num_replicas=None, rank=None):
21 |
22 | # warnings.filterwarnings('ignore')
23 | self._save_path = save_path
24 |
25 | self.image_size = image_size # int or list of int
26 | self.distort_color = distort_color
27 | self.resize_scale = resize_scale
28 |
29 | self._valid_transform_dict = {}
30 | if not isinstance(self.image_size, int):
31 | assert isinstance(self.image_size, list)
32 | from ofa.utils.my_dataloader import MyDataLoader
33 | self.image_size.sort() # e.g., 160 -> 224
34 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
35 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
36 |
37 | for img_size in self.image_size:
38 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
39 | self.active_img_size = max(self.image_size)
40 | valid_transforms = self._valid_transform_dict[self.active_img_size]
41 |             train_loader_class = MyDataLoader # randomly sample the image size for each batch of training images
42 | else:
43 | self.active_img_size = self.image_size
44 | valid_transforms = self.build_valid_transform()
45 | train_loader_class = torch.utils.data.DataLoader
46 |
47 | train_transforms = self.build_train_transform()
48 | train_dataset = self.train_dataset(train_transforms)
49 |
50 | weights = self.make_weights_for_balanced_classes(
51 | train_dataset.imgs, self.n_classes)
52 | weights = torch.DoubleTensor(weights)
53 | train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
54 |
55 | if valid_size is not None:
56 | raise NotImplementedError("validation dataset not yet implemented")
57 | # valid_dataset = self.valid_dataset(valid_transforms)
58 |
59 | # self.train = train_loader_class(
60 | # train_dataset, batch_size=train_batch_size, sampler=train_sampler,
61 | # num_workers=n_worker, pin_memory=True)
62 | # self.valid = torch.utils.data.DataLoader(
63 | # valid_dataset, batch_size=test_batch_size,
64 | # num_workers=n_worker, pin_memory=True)
65 | else:
66 | self.train = train_loader_class(
67 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
68 | num_workers=n_worker, pin_memory=True,
69 | )
70 | self.valid = None
71 |
72 | test_dataset = self.test_dataset(valid_transforms)
73 | self.test = torch.utils.data.DataLoader(
74 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
75 | )
76 |
77 | if self.valid is None:
78 | self.valid = self.test
79 |
80 | @staticmethod
81 | def name():
82 | return 'flowers102'
83 |
84 | @property
85 | def data_shape(self):
86 | return 3, self.active_img_size, self.active_img_size # C, H, W
87 |
88 | @property
89 | def n_classes(self):
90 | return 102
91 |
92 | @property
93 | def save_path(self):
94 | if self._save_path is None:
95 | # self._save_path = '/mnt/datastore/Oxford102Flowers' # home server
96 | self._save_path = '/mnt/datastore/Flowers102' # home server
97 |
98 | if not os.path.exists(self._save_path):
99 | # self._save_path = '/mnt/datastore/Oxford102Flowers' # home server
100 | self._save_path = '/mnt/datastore/Flowers102' # home server
101 | return self._save_path
102 |
103 | @property
104 | def data_url(self):
105 | raise ValueError('unable to download %s' % self.name())
106 |
107 | def train_dataset(self, _transforms):
108 | dataset = datasets.ImageFolder(self.train_path, _transforms)
109 | return dataset
110 |
111 | # def valid_dataset(self, _transforms):
112 | # dataset = datasets.ImageFolder(self.valid_path, _transforms)
113 | # return dataset
114 |
115 | def test_dataset(self, _transforms):
116 | dataset = datasets.ImageFolder(self.test_path, _transforms)
117 | return dataset
118 |
119 | @property
120 | def train_path(self):
121 | return os.path.join(self.save_path, 'train')
122 |
123 | # @property
124 | # def valid_path(self):
125 | # return os.path.join(self.save_path, 'train')
126 |
127 | @property
128 | def test_path(self):
129 | return os.path.join(self.save_path, 'test')
130 |
131 | @property
132 | def normalize(self):
133 | return transforms.Normalize(
134 | mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803],
135 | std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665])
136 |
137 | @staticmethod
138 | def make_weights_for_balanced_classes(images, nclasses):
139 | count = [0] * nclasses
140 |
141 | # Counts per label
142 | for item in images:
143 | count[item[1]] += 1
144 |
145 | weight_per_class = [0.] * nclasses
146 |
147 | # Total number of images.
148 | N = float(sum(count))
149 |
150 | # super-sample the smaller classes.
151 | for i in range(nclasses):
152 | weight_per_class[i] = N / float(count[i])
153 |
154 | weight = [0] * len(images)
155 |
156 | # Calculate a weight per image.
157 | for idx, val in enumerate(images):
158 | weight[idx] = weight_per_class[val[1]]
159 |
160 | return weight
161 |
162 | def build_train_transform(self, image_size=None, print_log=True):
163 | if image_size is None:
164 | image_size = self.image_size
165 | if print_log:
166 | print('Color jitter: %s, resize_scale: %s, img_size: %s' %
167 | (self.distort_color, self.resize_scale, image_size))
168 |
169 | if self.distort_color == 'torch':
170 | color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
171 | elif self.distort_color == 'tf':
172 | color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
173 | else:
174 | color_transform = None
175 |
176 | if isinstance(image_size, list):
177 | resize_transform_class = MyRandomResizedCrop
178 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
179 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
180 | else:
181 | resize_transform_class = transforms.RandomResizedCrop
182 |
183 | train_transforms = [
184 | transforms.RandomAffine(
185 | 45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0),
186 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
187 | # transforms.RandomHorizontalFlip(),
188 | ]
189 | if color_transform is not None:
190 | train_transforms.append(color_transform)
191 | train_transforms += [
192 | transforms.ToTensor(),
193 | self.normalize,
194 | ]
195 |
196 | train_transforms = transforms.Compose(train_transforms)
197 | return train_transforms
198 |
199 | def build_valid_transform(self, image_size=None):
200 | if image_size is None:
201 | image_size = self.active_img_size
202 | return transforms.Compose([
203 | transforms.Resize(int(math.ceil(image_size / 0.875))),
204 | transforms.CenterCrop(image_size),
205 | transforms.ToTensor(),
206 | self.normalize,
207 | ])
208 |
209 | def assign_active_img_size(self, new_img_size):
210 | self.active_img_size = new_img_size
211 | if self.active_img_size not in self._valid_transform_dict:
212 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
213 | # change the transform of the valid and test set
214 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
215 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
216 |
217 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
218 | # used for resetting running statistics
219 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
220 | if num_worker is None:
221 | num_worker = self.train.num_workers
222 |
223 | n_samples = len(self.train.dataset.samples)
224 | g = torch.Generator()
225 | g.manual_seed(DataProvider.SUB_SEED)
226 | rand_indexes = torch.randperm(n_samples, generator=g).tolist()
227 |
228 | new_train_dataset = self.train_dataset(
229 | self.build_train_transform(image_size=self.active_img_size, print_log=False))
230 | chosen_indexes = rand_indexes[:n_images]
231 | if num_replicas is not None:
232 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
233 | else:
234 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
235 | sub_data_loader = torch.utils.data.DataLoader(
236 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
237 | num_workers=num_worker, pin_memory=True,
238 | )
239 | self.__dict__['sub_train_%d' % self.active_img_size] = []
240 | for images, labels in sub_data_loader:
241 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
242 | return self.__dict__['sub_train_%d' % self.active_img_size]
243 |
--------------------------------------------------------------------------------
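
Unlike the other providers, the flowers102 provider above rebalances its training set with a WeightedRandomSampler built from make_weights_for_balanced_classes. A toy sketch of that weighting rule (not part of the repository; the file names and class counts below are made up):

import torch

# Three images of class 0 and one of class 1, as (path, label) pairs.
imgs = [('a.jpg', 0), ('b.jpg', 0), ('c.jpg', 0), ('d.jpg', 1)]
counts = [3, 1]
n_total = float(sum(counts))
# Each image is weighted by N / count(its class), so the rarer class is drawn more often.
weights = [n_total / counts[label] for _, label in imgs]   # [1.33, 1.33, 1.33, 4.0]
sampler = torch.utils.data.WeightedRandomSampler(
    torch.DoubleTensor(weights), num_samples=len(weights))
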
/codebase/data_providers/imagenet.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import os
3 | import math
4 | import numpy as np
5 |
6 | import torch.utils.data
7 | import torchvision.transforms as transforms
8 | import torchvision.datasets as datasets
9 |
10 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
11 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider
12 |
13 |
14 | class ImagenetDataProvider(DataProvider):
15 |
16 | def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
17 | resize_scale=0.08, distort_color=None, image_size=224,
18 | num_replicas=None, rank=None):
19 |
20 | warnings.filterwarnings('ignore')
21 | self._save_path = save_path
22 |
23 | self.image_size = image_size # int or list of int
24 | self.distort_color = distort_color
25 | self.resize_scale = resize_scale
26 |
27 | self._valid_transform_dict = {}
28 | if not isinstance(self.image_size, int):
29 | assert isinstance(self.image_size, list)
30 | from ofa.utils.my_dataloader import MyDataLoader
31 | self.image_size.sort() # e.g., 160 -> 224
32 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
33 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
34 |
35 | for img_size in self.image_size:
36 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
37 | self.active_img_size = max(self.image_size)
38 | valid_transforms = self._valid_transform_dict[self.active_img_size]
39 |             train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
40 | else:
41 | self.active_img_size = self.image_size
42 | valid_transforms = self.build_valid_transform()
43 | train_loader_class = torch.utils.data.DataLoader
44 |
45 | train_transforms = self.build_train_transform()
46 | train_dataset = self.train_dataset(train_transforms)
47 |
48 | if valid_size is not None:
49 | if not isinstance(valid_size, int):
50 | assert isinstance(valid_size, float) and 0 < valid_size < 1
51 | valid_size = int(len(train_dataset.samples) * valid_size)
52 |
53 | valid_dataset = self.train_dataset(valid_transforms)
54 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
55 |
56 | if num_replicas is not None:
57 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
58 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
59 | else:
60 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
61 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
62 |
63 | self.train = train_loader_class(
64 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
65 | num_workers=n_worker, pin_memory=True,
66 | )
67 | self.valid = torch.utils.data.DataLoader(
68 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
69 | num_workers=n_worker, pin_memory=True,
70 | )
71 | else:
72 | if num_replicas is not None:
73 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
74 | self.train = train_loader_class(
75 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
76 | num_workers=n_worker, pin_memory=True
77 | )
78 | else:
79 | self.train = train_loader_class(
80 | train_dataset, batch_size=train_batch_size, shuffle=True,
81 | num_workers=n_worker, pin_memory=True,
82 | )
83 | self.valid = None
84 |
85 | test_dataset = self.test_dataset(valid_transforms)
86 | if num_replicas is not None:
87 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
88 | self.test = torch.utils.data.DataLoader(
89 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
90 | )
91 | else:
92 | self.test = torch.utils.data.DataLoader(
93 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
94 | )
95 |
96 | if self.valid is None:
97 | self.valid = self.test
98 |
99 | @staticmethod
100 | def name():
101 | return 'imagenet'
102 |
103 | @property
104 | def data_shape(self):
105 | return 3, self.active_img_size, self.active_img_size # C, H, W
106 |
107 | @property
108 | def n_classes(self):
109 | return 1000
110 |
111 | @property
112 | def save_path(self):
113 | if self._save_path is None:
114 | # self._save_path = '/dataset/imagenet'
115 | # self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012' # servers
116 | self._save_path = '/mnt/datastore/ILSVRC2012' # home server
117 |
118 | if not os.path.exists(self._save_path):
119 | # self._save_path = os.path.expanduser('~/dataset/imagenet')
120 | # self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012')
121 | self._save_path = '/mnt/datastore/ILSVRC2012' # home server
122 | return self._save_path
123 |
124 | @property
125 | def data_url(self):
126 | raise ValueError('unable to download %s' % self.name())
127 |
128 | def train_dataset(self, _transforms):
129 | dataset = datasets.ImageFolder(self.train_path, _transforms)
130 | return dataset
131 |
132 | def test_dataset(self, _transforms):
133 | dataset = datasets.ImageFolder(self.valid_path, _transforms)
134 | return dataset
135 |
136 | @property
137 | def train_path(self):
138 | return os.path.join(self.save_path, 'train')
139 |
140 | @property
141 | def valid_path(self):
142 | return os.path.join(self.save_path, 'val')
143 |
144 | @property
145 | def normalize(self):
146 | return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
147 |
148 | def build_train_transform(self, image_size=None, print_log=True):
149 | if image_size is None:
150 | image_size = self.image_size
151 | if print_log:
152 | print('Color jitter: %s, resize_scale: %s, img_size: %s' %
153 | (self.distort_color, self.resize_scale, image_size))
154 |
155 | if self.distort_color == 'torch':
156 | color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
157 | elif self.distort_color == 'tf':
158 | color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
159 | else:
160 | color_transform = None
161 |
162 | if isinstance(image_size, list):
163 | resize_transform_class = MyRandomResizedCrop
164 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
165 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
166 | else:
167 | resize_transform_class = transforms.RandomResizedCrop
168 |
169 | train_transforms = [
170 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
171 | transforms.RandomHorizontalFlip(),
172 | ]
173 | if color_transform is not None:
174 | train_transforms.append(color_transform)
175 | train_transforms += [
176 | transforms.ToTensor(),
177 | self.normalize,
178 | ]
179 |
180 | train_transforms = transforms.Compose(train_transforms)
181 | return train_transforms
182 |
183 | def build_valid_transform(self, image_size=None):
184 | if image_size is None:
185 | image_size = self.active_img_size
186 | return transforms.Compose([
187 | transforms.Resize(int(math.ceil(image_size / 0.875))),
188 | transforms.CenterCrop(image_size),
189 | transforms.ToTensor(),
190 | self.normalize,
191 | ])
192 |
193 | def assign_active_img_size(self, new_img_size):
194 | self.active_img_size = new_img_size
195 | if self.active_img_size not in self._valid_transform_dict:
196 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
197 | # change the transform of the valid and test set
198 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
199 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
200 |
201 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
202 | # used for resetting running statistics
203 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
204 | if num_worker is None:
205 | num_worker = self.train.num_workers
206 |
207 | n_samples = len(self.train.dataset.samples)
208 | g = torch.Generator()
209 | g.manual_seed(DataProvider.SUB_SEED)
210 | rand_indexes = torch.randperm(n_samples, generator=g).tolist()
211 |
212 | new_train_dataset = self.train_dataset(
213 | self.build_train_transform(image_size=self.active_img_size, print_log=False))
214 | chosen_indexes = rand_indexes[:n_images]
215 | if num_replicas is not None:
216 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
217 | else:
218 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
219 | sub_data_loader = torch.utils.data.DataLoader(
220 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
221 | num_workers=num_worker, pin_memory=True,
222 | )
223 | self.__dict__['sub_train_%d' % self.active_img_size] = []
224 | for images, labels in sub_data_loader:
225 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
226 | return self.__dict__['sub_train_%d' % self.active_img_size]
227 |
--------------------------------------------------------------------------------
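
A minimal usage sketch for ImagenetDataProvider with elastic image sizes (not part of the repository; it assumes the ofa package is installed, that the codebase package is importable from the repo root, and that the dataset path below holds the usual train/ and val/ ImageFolder layout):

from codebase.data_providers.imagenet import ImagenetDataProvider

provider = ImagenetDataProvider(
    save_path='/mnt/datastore/ILSVRC2012',     # hypothetical local dataset root
    train_batch_size=256, test_batch_size=512,
    image_size=[160, 176, 192, 208, 224],      # a list makes MyDataLoader sample a size per batch
    n_worker=8,
)
provider.assign_active_img_size(192)           # switch valid/test transforms to 192x192
calib = provider.build_sub_train_loader(n_images=2000, batch_size=200)  # batches for resetting BN statistics
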
/codebase/data_providers/pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import warnings
4 | import numpy as np
5 |
6 | from timm.data.transforms import str_to_pil_interp
7 | from timm.data.auto_augment import rand_augment_transform
8 |
9 | import torch.utils.data
10 | import torchvision.transforms as transforms
11 | import torchvision.datasets as datasets
12 |
13 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
14 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider
15 |
16 |
17 | class OxfordIIITPetsDataProvider(DataProvider):
18 |
19 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
20 | resize_scale=0.08, distort_color=None, image_size=224,
21 | num_replicas=None, rank=None):
22 |
23 | warnings.filterwarnings('ignore')
24 | self._save_path = save_path
25 |
26 | self.image_size = image_size # int or list of int
27 | self.distort_color = distort_color
28 | self.resize_scale = resize_scale
29 |
30 | self._valid_transform_dict = {}
31 | if not isinstance(self.image_size, int):
32 | assert isinstance(self.image_size, list)
33 | from ofa.utils.my_dataloader import MyDataLoader
34 | self.image_size.sort() # e.g., 160 -> 224
35 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
36 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
37 |
38 | for img_size in self.image_size:
39 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
40 | self.active_img_size = max(self.image_size)
41 | valid_transforms = self._valid_transform_dict[self.active_img_size]
42 |             train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
43 | else:
44 | self.active_img_size = self.image_size
45 | valid_transforms = self.build_valid_transform()
46 | train_loader_class = torch.utils.data.DataLoader
47 |
48 | train_transforms = self.build_train_transform()
49 | train_dataset = self.train_dataset(train_transforms)
50 |
51 | if valid_size is not None:
52 | if not isinstance(valid_size, int):
53 | assert isinstance(valid_size, float) and 0 < valid_size < 1
54 | valid_size = int(len(train_dataset.samples) * valid_size)
55 |
56 | valid_dataset = self.train_dataset(valid_transforms)
57 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
58 |
59 | if num_replicas is not None:
60 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
61 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
62 | else:
63 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
64 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
65 |
66 | self.train = train_loader_class(
67 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
68 | num_workers=n_worker, pin_memory=True,
69 | )
70 | self.valid = torch.utils.data.DataLoader(
71 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
72 | num_workers=n_worker, pin_memory=True,
73 | )
74 | else:
75 | if num_replicas is not None:
76 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
77 | self.train = train_loader_class(
78 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
79 | num_workers=n_worker, pin_memory=True
80 | )
81 | else:
82 | self.train = train_loader_class(
83 | train_dataset, batch_size=train_batch_size, shuffle=True,
84 | num_workers=n_worker, pin_memory=True,
85 | )
86 | self.valid = None
87 |
88 | test_dataset = self.test_dataset(valid_transforms)
89 | if num_replicas is not None:
90 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
91 | self.test = torch.utils.data.DataLoader(
92 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
93 | )
94 | else:
95 | self.test = torch.utils.data.DataLoader(
96 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
97 | )
98 |
99 | if self.valid is None:
100 | self.valid = self.test
101 |
102 | @staticmethod
103 | def name():
104 | return 'pets'
105 |
106 | @property
107 | def data_shape(self):
108 | return 3, self.active_img_size, self.active_img_size # C, H, W
109 |
110 | @property
111 | def n_classes(self):
112 | return 37
113 |
114 | @property
115 | def save_path(self):
116 | if self._save_path is None:
117 | self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server
118 |
119 | if not os.path.exists(self._save_path):
120 | self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server
121 | return self._save_path
122 |
123 | @property
124 | def data_url(self):
125 | raise ValueError('unable to download %s' % self.name())
126 |
127 | def train_dataset(self, _transforms):
128 | dataset = datasets.ImageFolder(self.train_path, _transforms)
129 | return dataset
130 |
131 | def test_dataset(self, _transforms):
132 | dataset = datasets.ImageFolder(self.valid_path, _transforms)
133 | return dataset
134 |
135 | @property
136 | def train_path(self):
137 | return os.path.join(self.save_path, 'train')
138 |
139 | @property
140 | def valid_path(self):
141 | return os.path.join(self.save_path, 'valid')
142 |
143 | @property
144 | def normalize(self):
145 | return transforms.Normalize(
146 | mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
147 | std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917])
148 |
149 | def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
150 | if image_size is None:
151 | image_size = self.image_size
152 | # if print_log:
153 | # print('Color jitter: %s, resize_scale: %s, img_size: %s' %
154 | # (self.distort_color, self.resize_scale, image_size))
155 |
156 | # if self.distort_color == 'torch':
157 | # color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
158 | # elif self.distort_color == 'tf':
159 | # color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
160 | # else:
161 | # color_transform = None
162 |
163 | if isinstance(image_size, list):
164 | resize_transform_class = MyRandomResizedCrop
165 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
166 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
167 | img_size_min = min(image_size)
168 | else:
169 | resize_transform_class = transforms.RandomResizedCrop
170 | img_size_min = image_size
171 |
172 | train_transforms = [
173 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
174 | transforms.RandomHorizontalFlip(),
175 | ]
176 |
177 | aa_params = dict(
178 | translate_const=int(img_size_min * 0.45),
179 | img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807,
180 | 0.39566558230789783]]),
181 | )
182 | aa_params['interpolation'] = str_to_pil_interp('bicubic')
183 | train_transforms += [rand_augment_transform(auto_augment, aa_params)]
184 |
185 | # if color_transform is not None:
186 | # train_transforms.append(color_transform)
187 | train_transforms += [
188 | transforms.ToTensor(),
189 | self.normalize,
190 | ]
191 |
192 | train_transforms = transforms.Compose(train_transforms)
193 | return train_transforms
194 |
195 | def build_valid_transform(self, image_size=None):
196 | if image_size is None:
197 | image_size = self.active_img_size
198 | return transforms.Compose([
199 | transforms.Resize(int(math.ceil(image_size / 0.875))),
200 | transforms.CenterCrop(image_size),
201 | transforms.ToTensor(),
202 | self.normalize,
203 | ])
204 |
205 | def assign_active_img_size(self, new_img_size):
206 | self.active_img_size = new_img_size
207 | if self.active_img_size not in self._valid_transform_dict:
208 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
209 | # change the transform of the valid and test set
210 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
211 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
212 |
213 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
214 | # used for resetting running statistics
215 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
216 | if num_worker is None:
217 | num_worker = self.train.num_workers
218 |
219 | n_samples = len(self.train.dataset.samples)
220 | g = torch.Generator()
221 | g.manual_seed(DataProvider.SUB_SEED)
222 | rand_indexes = torch.randperm(n_samples, generator=g).tolist()
223 |
224 | new_train_dataset = self.train_dataset(
225 | self.build_train_transform(image_size=self.active_img_size, print_log=False))
226 | chosen_indexes = rand_indexes[:n_images]
227 | if num_replicas is not None:
228 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
229 | else:
230 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
231 | sub_data_loader = torch.utils.data.DataLoader(
232 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
233 | num_workers=num_worker, pin_memory=True,
234 | )
235 | self.__dict__['sub_train_%d' % self.active_img_size] = []
236 | for images, labels in sub_data_loader:
237 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
238 | return self.__dict__['sub_train_%d' % self.active_img_size]
--------------------------------------------------------------------------------
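
The pets provider (and the dtd provider below) replaces the color-jitter recipe with timm's RandAugment. A standalone sketch of that augmentation outside the provider (assumes timm and Pillow are installed; the input image path is hypothetical):

from PIL import Image
from timm.data.transforms import str_to_pil_interp
from timm.data.auto_augment import rand_augment_transform

aa_params = dict(
    translate_const=int(224 * 0.45),        # same 0.45 * image-size rule used above
    img_mean=(123, 113, 101),               # dataset mean scaled to the 0-255 range
    interpolation=str_to_pil_interp('bicubic'),
)
augment = rand_augment_transform('rand-m9-mstd0.5', aa_params)  # magnitude 9, magnitude std 0.5
img = Image.open('example.jpg').convert('RGB')                  # hypothetical input image
augmented = augment(img)
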
/codebase/data_providers/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import numpy as np
4 |
5 | from timm.data.transforms import str_to_pil_interp
6 | from timm.data.auto_augment import rand_augment_transform
7 |
8 | import torch.utils.data
9 | import torchvision.transforms as transforms
10 | import torchvision.datasets as datasets
11 |
12 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
13 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider
14 |
15 |
16 | class DTDDataProvider(DataProvider):
17 |
18 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
19 | resize_scale=0.08, distort_color=None, image_size=224,
20 | num_replicas=None, rank=None):
21 |
22 | warnings.filterwarnings('ignore')
23 | self._save_path = save_path
24 |
25 | self.image_size = image_size # int or list of int
26 | self.distort_color = distort_color
27 | self.resize_scale = resize_scale
28 |
29 | self._valid_transform_dict = {}
30 | if not isinstance(self.image_size, int):
31 | assert isinstance(self.image_size, list)
32 | from ofa.utils.my_dataloader import MyDataLoader
33 | self.image_size.sort() # e.g., 160 -> 224
34 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
35 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
36 |
37 | for img_size in self.image_size:
38 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
39 | self.active_img_size = max(self.image_size)
40 | valid_transforms = self._valid_transform_dict[self.active_img_size]
41 |             train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
42 | else:
43 | self.active_img_size = self.image_size
44 | valid_transforms = self.build_valid_transform()
45 | train_loader_class = torch.utils.data.DataLoader
46 |
47 | train_transforms = self.build_train_transform()
48 | train_dataset = self.train_dataset(train_transforms)
49 |
50 | if valid_size is not None:
51 | if not isinstance(valid_size, int):
52 | assert isinstance(valid_size, float) and 0 < valid_size < 1
53 | valid_size = int(len(train_dataset.samples) * valid_size)
54 |
55 | valid_dataset = self.train_dataset(valid_transforms)
56 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
57 |
58 | if num_replicas is not None:
59 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
60 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
61 | else:
62 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
63 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
64 |
65 | self.train = train_loader_class(
66 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
67 | num_workers=n_worker, pin_memory=True,
68 | )
69 | self.valid = torch.utils.data.DataLoader(
70 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
71 | num_workers=n_worker, pin_memory=True,
72 | )
73 | else:
74 | if num_replicas is not None:
75 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
76 | self.train = train_loader_class(
77 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
78 | num_workers=n_worker, pin_memory=True
79 | )
80 | else:
81 | self.train = train_loader_class(
82 | train_dataset, batch_size=train_batch_size, shuffle=True,
83 | num_workers=n_worker, pin_memory=True,
84 | )
85 | self.valid = None
86 |
87 | test_dataset = self.test_dataset(valid_transforms)
88 | if num_replicas is not None:
89 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
90 | self.test = torch.utils.data.DataLoader(
91 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
92 | )
93 | else:
94 | self.test = torch.utils.data.DataLoader(
95 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
96 | )
97 |
98 | if self.valid is None:
99 | self.valid = self.test
100 |
101 | @staticmethod
102 | def name():
103 | return 'dtd'
104 |
105 | @property
106 | def data_shape(self):
107 | return 3, self.active_img_size, self.active_img_size # C, H, W
108 |
109 | @property
110 | def n_classes(self):
111 | return 47
112 |
113 | @property
114 | def save_path(self):
115 | if self._save_path is None:
116 | self._save_path = '/mnt/datastore/dtd' # home server
117 |
118 | if not os.path.exists(self._save_path):
119 | self._save_path = '/mnt/datastore/dtd' # home server
120 | return self._save_path
121 |
122 | @property
123 | def data_url(self):
124 | raise ValueError('unable to download %s' % self.name())
125 |
126 | def train_dataset(self, _transforms):
127 | dataset = datasets.ImageFolder(self.train_path, _transforms)
128 | return dataset
129 |
130 | def test_dataset(self, _transforms):
131 | dataset = datasets.ImageFolder(self.valid_path, _transforms)
132 | return dataset
133 |
134 | @property
135 | def train_path(self):
136 | return os.path.join(self.save_path, 'train')
137 |
138 | @property
139 | def valid_path(self):
140 | return os.path.join(self.save_path, 'valid')
141 |
142 | @property
143 | def normalize(self):
144 | return transforms.Normalize(
145 | mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676],
146 | std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832])
147 |
148 | def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
149 | if image_size is None:
150 | image_size = self.image_size
151 | # if print_log:
152 | # print('Color jitter: %s, resize_scale: %s, img_size: %s' %
153 | # (self.distort_color, self.resize_scale, image_size))
154 |
155 | # if self.distort_color == 'torch':
156 | # color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
157 | # elif self.distort_color == 'tf':
158 | # color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
159 | # else:
160 | # color_transform = None
161 |
162 | if isinstance(image_size, list):
163 | resize_transform_class = MyRandomResizedCrop
164 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
165 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
166 | img_size_min = min(image_size)
167 | else:
168 | resize_transform_class = transforms.RandomResizedCrop
169 | img_size_min = image_size
170 |
171 | train_transforms = [
172 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
173 | transforms.RandomHorizontalFlip(),
174 | ]
175 |
176 | aa_params = dict(
177 | translate_const=int(img_size_min * 0.45),
178 | img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454,
179 | 0.42627281899380676]]),
180 | )
181 | aa_params['interpolation'] = str_to_pil_interp('bicubic')
182 | train_transforms += [rand_augment_transform(auto_augment, aa_params)]
183 |
184 | # if color_transform is not None:
185 | # train_transforms.append(color_transform)
186 | train_transforms += [
187 | transforms.ToTensor(),
188 | self.normalize,
189 | ]
190 |
191 | train_transforms = transforms.Compose(train_transforms)
192 | return train_transforms
193 |
194 | def build_valid_transform(self, image_size=None):
195 | if image_size is None:
196 | image_size = self.active_img_size
197 | return transforms.Compose([
198 | # transforms.Resize(int(math.ceil(image_size / 0.875))),
199 | transforms.Resize((image_size, image_size), interpolation=3),
200 | transforms.CenterCrop(image_size),
201 | transforms.ToTensor(),
202 | self.normalize,
203 | ])
204 |
205 | def assign_active_img_size(self, new_img_size):
206 | self.active_img_size = new_img_size
207 | if self.active_img_size not in self._valid_transform_dict:
208 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
209 | # change the transform of the valid and test set
210 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
211 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
212 |
213 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
214 | # used for resetting running statistics
215 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
216 | if num_worker is None:
217 | num_worker = self.train.num_workers
218 |
219 | n_samples = len(self.train.dataset.samples)
220 | g = torch.Generator()
221 | g.manual_seed(DataProvider.SUB_SEED)
222 | rand_indexes = torch.randperm(n_samples, generator=g).tolist()
223 |
224 | new_train_dataset = self.train_dataset(
225 | self.build_train_transform(image_size=self.active_img_size, print_log=False))
226 | chosen_indexes = rand_indexes[:n_images]
227 | if num_replicas is not None:
228 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
229 | else:
230 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
231 | sub_data_loader = torch.utils.data.DataLoader(
232 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
233 | num_workers=num_worker, pin_memory=True,
234 | )
235 | self.__dict__['sub_train_%d' % self.active_img_size] = []
236 | for images, labels in sub_data_loader:
237 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
238 | return self.__dict__['sub_train_%d' % self.active_img_size]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/codebase/data_providers/autoaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
3 | """
4 |
5 | from PIL import Image, ImageEnhance, ImageOps
6 | import numpy as np
7 | import random
8 |
9 |
10 | class ImageNetPolicy(object):
11 | """ Randomly choose one of the best 24 Sub-policies on ImageNet.
12 |
13 | Example:
14 | >>> policy = ImageNetPolicy()
15 | >>> transformed = policy(image)
16 |
17 | Example as a PyTorch Transform:
18 | >>> transform=transforms.Compose([
19 | >>> transforms.Resize(256),
20 | >>> ImageNetPolicy(),
21 | >>> transforms.ToTensor()])
22 | """
23 | def __init__(self, fillcolor=(128, 128, 128)):
24 | self.policies = [
25 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
26 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
27 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
28 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
29 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
30 |
31 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
32 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
33 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
34 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
35 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
36 |
37 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
38 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
39 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
40 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
41 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
42 |
43 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
44 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
45 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
46 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
47 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
48 |
49 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
50 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
51 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
52 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
53 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
54 | ]
55 |
56 |
57 | def __call__(self, img):
58 | policy_idx = random.randint(0, len(self.policies) - 1)
59 | return self.policies[policy_idx](img)
60 |
61 | def __repr__(self):
62 | return "AutoAugment ImageNet Policy"
63 |
64 |
65 | class CIFAR10Policy(object):
66 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
67 |
68 | Example:
69 | >>> policy = CIFAR10Policy()
70 | >>> transformed = policy(image)
71 |
72 | Example as a PyTorch Transform:
73 | >>> transform=transforms.Compose([
74 | >>> transforms.Resize(256),
75 | >>> CIFAR10Policy(),
76 | >>> transforms.ToTensor()])
77 | """
78 | def __init__(self, fillcolor=(128, 128, 128)):
79 | self.policies = [
80 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
81 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
82 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
83 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
84 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
85 |
86 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
87 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
88 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
89 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
90 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
91 |
92 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
93 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
94 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
95 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
96 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
97 |
98 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
99 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
100 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
101 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
102 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
103 |
104 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
105 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
106 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
107 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
108 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
109 | ]
110 |
111 |
112 | def __call__(self, img):
113 | policy_idx = random.randint(0, len(self.policies) - 1)
114 | return self.policies[policy_idx](img)
115 |
116 | def __repr__(self):
117 | return "AutoAugment CIFAR10 Policy"
118 |
119 |
120 | class SVHNPolicy(object):
121 | """ Randomly choose one of the best 25 Sub-policies on SVHN.
122 |
123 | Example:
124 | >>> policy = SVHNPolicy()
125 | >>> transformed = policy(image)
126 |
127 | Example as a PyTorch Transform:
128 | >>> transform=transforms.Compose([
129 | >>> transforms.Resize(256),
130 | >>> SVHNPolicy(),
131 | >>> transforms.ToTensor()])
132 | """
133 | def __init__(self, fillcolor=(128, 128, 128)):
134 | self.policies = [
135 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
136 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
137 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
138 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
139 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
140 |
141 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
142 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
143 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
144 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
145 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
146 |
147 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
148 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
149 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
150 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
151 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
152 |
153 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
154 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
155 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
156 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
157 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
158 |
159 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
160 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
161 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
162 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
163 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
164 | ]
165 |
166 |
167 | def __call__(self, img):
168 | policy_idx = random.randint(0, len(self.policies) - 1)
169 | return self.policies[policy_idx](img)
170 |
171 | def __repr__(self):
172 | return "AutoAugment SVHN Policy"
173 |
174 |
175 | class SubPolicy(object):
176 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
177 | ranges = {
178 | "shearX": np.linspace(0, 0.3, 10),
179 | "shearY": np.linspace(0, 0.3, 10),
180 | "translateX": np.linspace(0, 150 / 331, 10),
181 | "translateY": np.linspace(0, 150 / 331, 10),
182 | "rotate": np.linspace(0, 30, 10),
183 | "color": np.linspace(0.0, 0.9, 10),
184 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
185 | "solarize": np.linspace(256, 0, 10),
186 | "contrast": np.linspace(0.0, 0.9, 10),
187 | "sharpness": np.linspace(0.0, 0.9, 10),
188 | "brightness": np.linspace(0.0, 0.9, 10),
189 | "autocontrast": [0] * 10,
190 | "equalize": [0] * 10,
191 | "invert": [0] * 10
192 | }
193 |
194 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
195 | def rotate_with_fill(img, magnitude):
196 | rot = img.convert("RGBA").rotate(magnitude)
197 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
198 |
199 | func = {
200 | "shearX": lambda img, magnitude: img.transform(
201 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
202 | Image.BICUBIC, fillcolor=fillcolor),
203 | "shearY": lambda img, magnitude: img.transform(
204 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
205 | Image.BICUBIC, fillcolor=fillcolor),
206 | "translateX": lambda img, magnitude: img.transform(
207 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
208 | fillcolor=fillcolor),
209 | "translateY": lambda img, magnitude: img.transform(
210 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
211 | fillcolor=fillcolor),
212 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
213 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
214 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
215 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
216 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
217 | 1 + magnitude * random.choice([-1, 1])),
218 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
219 | 1 + magnitude * random.choice([-1, 1])),
220 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
221 | 1 + magnitude * random.choice([-1, 1])),
222 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
223 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
224 | "invert": lambda img, magnitude: ImageOps.invert(img)
225 | }
226 |
227 | self.p1 = p1
228 | self.operation1 = func[operation1]
229 | self.magnitude1 = ranges[operation1][magnitude_idx1]
230 | self.p2 = p2
231 | self.operation2 = func[operation2]
232 | self.magnitude2 = ranges[operation2][magnitude_idx2]
233 |
234 |
235 | def __call__(self, img):
236 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
237 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
238 | return img
--------------------------------------------------------------------------------
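
Each policy above draws one SubPolicy per call; a SubPolicy applies up to two PIL operations, each gated by its own probability and parameterized by a magnitude index into the ranges table. A quick standalone sketch (not from the repository; it assumes the codebase package is importable and the image path is hypothetical):

from PIL import Image
from codebase.data_providers.autoaugment import SubPolicy

# With probability 0.4 posterize at magnitude index 8 (-> 4 bits kept),
# then with probability 0.6 rotate at magnitude index 9 (-> 30 degrees).
sub = SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9)
img = Image.open('example.jpg').convert('RGB')
out = sub(img)   # both coin flips are re-sampled on every call
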
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import json
4 | import yaml
5 | import numpy as np
6 | from collections import OrderedDict
7 | from torchprofile import profile_macs
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.backends.cudnn as cudnn
12 |
13 | from pymoo.model.mutation import Mutation
14 | from pymoo.model.sampling import Sampling
15 | from pymoo.model.crossover import Crossover
16 |
17 | DEFAULT_CFG = {
18 | 'gpus': '0', 'config': None, 'init': None, 'trn_batch_size': 128, 'vld_batch_size': 250, 'num_workers': 4,
19 | 'n_epochs': 0, 'save': None, 'resolution': 224, 'valid_size': 10000, 'test': True, 'latency': None,
20 | 'verbose': False, 'classifier_only': False, 'reset_running_statistics': True,
21 | }
22 |
23 |
24 | def get_correlation(prediction, target):
25 | import scipy.stats as stats
26 |
27 | rmse = np.sqrt(((prediction - target) ** 2).mean())
28 | rho, _ = stats.spearmanr(prediction, target)
29 | tau, _ = stats.kendalltau(prediction, target)
30 |
31 | return rmse, rho, tau
32 |
33 |
34 | def bash_command_template(**kwargs):
35 | gpus = kwargs.pop('gpus', DEFAULT_CFG['gpus'])
36 | cfg = OrderedDict()
37 | cfg['subnet'] = kwargs['subnet']
38 | cfg['data'] = kwargs['data']
39 | cfg['dataset'] = kwargs['dataset']
40 | cfg['n_classes'] = kwargs['n_classes']
41 | cfg['supernet_path'] = kwargs['supernet_path']
42 | cfg['config'] = kwargs.pop('config', DEFAULT_CFG['config'])
43 | cfg['init'] = kwargs.pop('init', DEFAULT_CFG['init'])
44 | cfg['save'] = kwargs.pop('save', DEFAULT_CFG['save'])
45 | cfg['trn_batch_size'] = kwargs.pop('trn_batch_size', DEFAULT_CFG['trn_batch_size'])
46 | cfg['vld_batch_size'] = kwargs.pop('vld_batch_size', DEFAULT_CFG['vld_batch_size'])
47 | cfg['num_workers'] = kwargs.pop('num_workers', DEFAULT_CFG['num_workers'])
48 | cfg['n_epochs'] = kwargs.pop('n_epochs', DEFAULT_CFG['n_epochs'])
49 | cfg['resolution'] = kwargs.pop('resolution', DEFAULT_CFG['resolution'])
50 | cfg['valid_size'] = kwargs.pop('valid_size', DEFAULT_CFG['valid_size'])
51 | cfg['test'] = kwargs.pop('test', DEFAULT_CFG['test'])
52 | cfg['latency'] = kwargs.pop('latency', DEFAULT_CFG['latency'])
53 | cfg['verbose'] = kwargs.pop('verbose', DEFAULT_CFG['verbose'])
54 | cfg['classifier_only'] = kwargs.pop('classifier_only', DEFAULT_CFG['classifier_only'])
55 | cfg['reset_running_statistics'] = kwargs.pop(
56 | 'reset_running_statistics', DEFAULT_CFG['reset_running_statistics'])
57 |
58 | execution_line = "CUDA_VISIBLE_DEVICES={} python evaluator.py".format(gpus)
59 | for k, v in cfg.items():
60 | if v is not None:
61 | if isinstance(v, bool):
62 | if v:
63 | execution_line += " --{}".format(k)
64 | else:
65 | execution_line += " --{} {}".format(k, v)
66 | execution_line += ' &'
67 | return execution_line
68 |
69 |
70 | def prepare_eval_folder(path, configs, gpu=2, n_gpus=8, **kwargs):
71 | """ create a folder for parallel evaluation of a population of architectures """
72 | os.makedirs(path, exist_ok=True)
73 |     # build GPU id strings such as '0,1', '2,3', ... (one group of `gpu` devices per job)
74 |     gpus = [','.join(str(i + j) for j in range(gpu)) for i in range(0, n_gpus, gpu)]
75 | bash_file = ['#!/bin/bash']
76 | for i in range(0, len(configs), n_gpus//gpu):
77 | for j in range(n_gpus//gpu):
78 | if i + j < len(configs):
79 | job = os.path.join(path, "net_{}.subnet".format(i + j))
80 | with open(job, 'w') as handle:
81 | json.dump(configs[i + j], handle)
82 | bash_file.append(bash_command_template(
83 | gpus=gpus[j], subnet=job, save=os.path.join(
84 | path, "net_{}.stats".format(i + j)), **kwargs))
85 | bash_file.append('wait')
86 |
87 | with open(os.path.join(path, 'run_bash.sh'), 'w') as handle:
88 | for line in bash_file:
89 | handle.write(line + os.linesep)
90 |
91 |
92 | class MySampling(Sampling):
93 |
94 | def _do(self, problem, n_samples, **kwargs):
95 |         X = np.full((n_samples, problem.n_var), False, dtype=bool)  # np.bool is deprecated in recent NumPy
96 |
97 | for k in range(n_samples):
98 | I = np.random.permutation(problem.n_var)[:problem.n_max]
99 | X[k, I] = True
100 |
101 | return X
102 |
103 |
104 | class BinaryCrossover(Crossover):
105 | def __init__(self):
106 | super().__init__(2, 1)
107 |
108 | def _do(self, problem, X, **kwargs):
109 | n_parents, n_matings, n_var = X.shape
110 |
111 | _X = np.full((self.n_offsprings, n_matings, problem.n_var), False)
112 |
113 | for k in range(n_matings):
114 | p1, p2 = X[0, k], X[1, k]
115 |
116 | both_are_true = np.logical_and(p1, p2)
117 | _X[0, k, both_are_true] = True
118 |
119 | n_remaining = problem.n_max - np.sum(both_are_true)
120 |
121 | I = np.where(np.logical_xor(p1, p2))[0]
122 |
123 | S = I[np.random.permutation(len(I))][:n_remaining]
124 | _X[0, k, S] = True
125 |
126 | return _X
127 |
128 |
129 | class MyMutation(Mutation):
130 | def _do(self, problem, X, **kwargs):
131 | for i in range(X.shape[0]):
132 | X[i, :] = X[i, :]
133 | is_false = np.where(np.logical_not(X[i, :]))[0]
134 | is_true = np.where(X[i, :])[0]
135 | try:
136 | X[i, np.random.choice(is_false)] = True
137 | X[i, np.random.choice(is_true)] = False
138 | except ValueError:
139 | pass
140 |
141 | return X
142 |
143 |
144 | class LatencyEstimator(object):
145 | """
146 | Modified from https://github.com/mit-han-lab/proxylessnas/blob/
147 | f273683a77c4df082dd11cc963b07fc3613079a0/search/utils/latency_estimator.py#L29
148 | """
149 | def __init__(self, fname):
150 | # fname = download_url(url, overwrite=True)
151 |
152 | with open(fname, 'r') as fp:
153 | self.lut = yaml.load(fp, yaml.SafeLoader)
154 |
155 | @staticmethod
156 | def repr_shape(shape):
157 | if isinstance(shape, (list, tuple)):
158 | return 'x'.join(str(_) for _ in shape)
159 | elif isinstance(shape, str):
160 | return shape
161 | else:
162 |             raise TypeError('unsupported shape: {}'.format(shape))
163 |
164 | def predict(self, ltype: str, _input, output, expand=None,
165 | kernel=None, stride=None, idskip=None, se=None):
166 | """
167 | :param ltype:
168 |             Layer type must be one of the following:
169 |                 1. `first_conv`: the initial stem 3x3 conv with stride 2
170 |                 2. `final_expand_layer`: (only for MobileNet-V3)
171 |                     the 1x1 conv that increases num_filters by 6 times, followed by global average pooling
172 |                 3. `feature_mix_layer`:
173 |                     the 1x1 conv that increases num_filters to num_features, followed by torch.squeeze
174 |                 4. `classifier`: fully connected linear layer (num_features to num_classes)
175 |                 5. `MBConv`: MobileInvertedResidual block
176 |         :param _input: input shape (h, w, #channels)
177 |         :param output: output shape (h, w, #channels)
178 |         :param expand: expansion ratio
179 |         :param kernel: kernel size
180 |         :param stride: stride
181 |         :param idskip: whether the block has a residual (identity skip) connection
182 |         :param se: whether the block uses squeeze-and-excitation
183 | """
184 | infos = [ltype, 'input:%s' % self.repr_shape(_input),
185 | 'output:%s' % self.repr_shape(output), ]
186 | if ltype in ('MBConv',):
187 | assert None not in (expand, kernel, stride, idskip, se)
188 | infos += ['expand:%d' % expand, 'kernel:%d' % kernel,
189 | 'stride:%d' % stride, 'idskip:%d' % idskip, 'se:%d' % se]
190 | key = '-'.join(infos)
191 | return self.lut[key]['mean']
192 |
193 |
194 | def look_up_latency(net, lut, resolution=224):
195 | def _half(x, times=1):
196 | for _ in range(times):
197 | x = np.ceil(x / 2)
198 | return int(x)
199 |
200 | predicted_latency = 0
201 |
202 | # first_conv
203 | predicted_latency += lut.predict(
204 | 'first_conv', [resolution, resolution, 3],
205 | [resolution // 2, resolution // 2, net.first_conv.out_channels])
206 |
207 | # final_expand_layer (only for MobileNet V3 models)
208 | input_resolution = _half(resolution, times=5)
209 | predicted_latency += lut.predict(
210 | 'final_expand_layer',
211 | [input_resolution, input_resolution, net.final_expand_layer.in_channels],
212 | [input_resolution, input_resolution, net.final_expand_layer.out_channels]
213 | )
214 |
215 | # feature_mix_layer
216 | predicted_latency += lut.predict(
217 | 'feature_mix_layer',
218 | [1, 1, net.feature_mix_layer.in_channels],
219 | [1, 1, net.feature_mix_layer.out_channels]
220 | )
221 |
222 | # classifier
223 | predicted_latency += lut.predict(
224 | 'classifier',
225 | [net.classifier.in_features],
226 | [net.classifier.out_features]
227 | )
228 |
229 | # blocks
230 | fsize = _half(resolution)
231 | for block in net.blocks:
232 | idskip = 0 if block.config['shortcut'] is None else 1
233 | se = 1 if block.config['mobile_inverted_conv']['use_se'] else 0
234 | stride = block.config['mobile_inverted_conv']['stride']
235 | out_fz = _half(fsize) if stride > 1 else fsize
236 | block_latency = lut.predict(
237 | 'MBConv',
238 | [fsize, fsize, block.config['mobile_inverted_conv']['in_channels']],
239 | [out_fz, out_fz, block.config['mobile_inverted_conv']['out_channels']],
240 | expand=block.config['mobile_inverted_conv']['expand_ratio'],
241 | kernel=block.config['mobile_inverted_conv']['kernel_size'],
242 | stride=stride, idskip=idskip, se=se
243 | )
244 | predicted_latency += block_latency
245 | fsize = out_fz
246 |
247 | return predicted_latency
248 |
249 |
250 | def get_net_info(net, input_shape=(3, 224, 224), measure_latency=None, print_info=True, clean=False, lut=None):
251 | """
252 | Modified from https://github.com/mit-han-lab/once-for-all/blob/
253 | 35ddcb9ca30905829480770a6a282d49685aa282/ofa/imagenet_codebase/utils/pytorch_utils.py#L139
254 | """
255 | from ofa.imagenet_codebase.utils.pytorch_utils import count_parameters, measure_net_latency
256 |
257 | # artificial input data
258 | inputs = torch.randn(1, 3, input_shape[-2], input_shape[-1])
259 |
260 | # move network to GPU if available
261 | if torch.cuda.is_available():
262 | device = torch.device('cuda:0')
263 | net = net.to(device)
264 | cudnn.benchmark = True
265 | inputs = inputs.to(device)
266 |
267 | net_info = {}
268 | if isinstance(net, nn.DataParallel):
269 | net = net.module
270 |
271 | # parameters
272 | net_info['params'] = count_parameters(net)
273 |
274 | # flops
275 | net_info['flops'] = int(profile_macs(copy.deepcopy(net), inputs))
276 |
277 | # latencies
278 | latency_types = [] if measure_latency is None else measure_latency.split('#')
279 |
280 | # print(latency_types)
281 | for l_type in latency_types:
282 | if lut is not None and l_type in lut:
283 | latency_estimator = LatencyEstimator(lut[l_type])
284 | latency = look_up_latency(net, latency_estimator, input_shape[2])
285 | measured_latency = None
286 | else:
287 | latency, measured_latency = measure_net_latency(
288 | net, l_type, fast=False, input_shape=input_shape, clean=clean)
289 | net_info['%s latency' % l_type] = {
290 | 'val': latency,
291 | 'hist': measured_latency
292 | }
293 |
294 | if print_info:
295 | # print(net)
296 | print('Total training params: %.2fM' % (net_info['params'] / 1e6))
297 | print('Total FLOPs: %.2fM' % (net_info['flops'] / 1e6))
298 | for l_type in latency_types:
299 | print('Estimated %s latency: %.3fms' % (l_type, net_info['%s latency' % l_type]['val']))
300 |
301 | return net_info
302 |
--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import argparse
5 | import numpy as np
6 |
7 | import utils
8 | from codebase.networks import NSGANetV2
9 | from codebase.run_manager import get_run_config
10 | from ofa.elastic_nn.networks import OFAMobileNetV3
11 | from ofa.imagenet_codebase.run_manager import RunManager
12 | from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d
13 |
14 | import warnings
15 | warnings.simplefilter("ignore")
16 |
17 | DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1
18 |
19 |
20 | def parse_string_list(string):
21 | if isinstance(string, str):
22 | # convert '[5 5 5 7 7 7 3 3 7 7 7 3 3]' to [5, 5, 5, 7, 7, 7, 3, 3, 7, 7, 7, 3, 3]
23 | return list(map(int, string[1:-1].split()))
24 | else:
25 | return string
26 |
27 |
28 | def pad_none(x, depth, max_depth):
29 | new_x, counter = [], 0
30 | for d in depth:
31 | for _ in range(d):
32 | new_x.append(x[counter])
33 | counter += 1
34 | if d < max_depth:
35 | new_x += [None] * (max_depth - d)
36 | return new_x
37 |
38 |
39 | def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None):
40 |
41 | net_info = utils.get_net_info(
42 | net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut)
43 |
44 | gpu_latency, cpu_latency = None, None
45 | for k in net_info.keys():
46 | if 'gpu' in k:
47 | gpu_latency = np.round(net_info[k]['val'], 2)
48 | if 'cpu' in k:
49 | cpu_latency = np.round(net_info[k]['val'], 2)
50 |
51 | return {
52 | 'params': np.round(net_info['params'] / 1e6, 2),
53 | 'flops': np.round(net_info['flops'] / 1e6, 2),
54 | 'gpu': gpu_latency, 'cpu': cpu_latency
55 | }
56 |
57 |
58 | def validate_config(config, max_depth=4):
59 | kernel_size, exp_ratio, depth = config['ks'], config['e'], config['d']
60 |
61 | if isinstance(kernel_size, str): kernel_size = parse_string_list(kernel_size)
62 | if isinstance(exp_ratio, str): exp_ratio = parse_string_list(exp_ratio)
63 | if isinstance(depth, str): depth = parse_string_list(depth)
64 |
65 | assert (isinstance(kernel_size, list) or isinstance(kernel_size, int))
66 | assert (isinstance(exp_ratio, list) or isinstance(exp_ratio, int))
67 | assert isinstance(depth, list)
68 |
69 | if len(kernel_size) < len(depth) * max_depth:
70 | kernel_size = pad_none(kernel_size, depth, max_depth)
71 | if len(exp_ratio) < len(depth) * max_depth:
72 | exp_ratio = pad_none(exp_ratio, depth, max_depth)
73 |
74 | # return {'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'w': config['w']}
75 | return {'ks': kernel_size, 'e': exp_ratio, 'd': depth}
76 |
77 |
78 | class OFAEvaluator:
79 | """ based on OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """
80 | def __init__(self,
81 | n_classes=1000,
82 | model_path='./data/ofa_mbv3_d234_e346_k357_w1.0',
83 | kernel_size=None, exp_ratio=None, depth=None):
84 | # default configurations
85 | self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size # depth-wise conv kernel size
86 | self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio # expansion rate
87 | self.depth = [2, 3, 4] if depth is None else depth # number of MB block repetition
88 |
89 | if 'w1.0' in model_path:
90 | self.width_mult = 1.0
91 | elif 'w1.2' in model_path:
92 | self.width_mult = 1.2
93 | else:
94 | raise ValueError
95 |
96 | self.engine = OFAMobileNetV3(
97 | n_classes=n_classes,
98 | dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size,
99 | expand_ratio_list=self.exp_ratio, depth_list=self.depth)
100 |
101 | init = torch.load(model_path, map_location='cpu')['state_dict']
102 | self.engine.load_weights_from_net(init)
103 |
104 | def sample(self, config=None):
105 | """ randomly sample a sub-network """
106 | if config is not None:
107 | config = validate_config(config)
108 | self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
109 | else:
110 | config = self.engine.sample_active_subnet()
111 |
112 | subnet = self.engine.get_active_subnet(preserve_weight=True)
113 | return subnet, config
114 |
115 | @staticmethod
116 | def save_net_config(path, net, config_name='net.config'):
117 | """ dump run_config and net_config to the model_folder """
118 | net_save_path = os.path.join(path, config_name)
119 | json.dump(net.config, open(net_save_path, 'w'), indent=4)
120 |         print('Network config dumped to %s' % net_save_path)
121 |
122 | @staticmethod
123 | def save_net(path, net, model_name):
124 | """ dump net weight as checkpoint """
125 | if isinstance(net, torch.nn.DataParallel):
126 | checkpoint = {'state_dict': net.module.state_dict()}
127 | else:
128 | checkpoint = {'state_dict': net.state_dict()}
129 | model_path = os.path.join(path, model_name)
130 | torch.save(checkpoint, model_path)
131 |         print('Network model dumped to %s' % model_path)
132 |
133 | @staticmethod
134 | def eval(subnet, data_path, dataset='imagenet', n_epochs=0, resolution=224, trn_batch_size=128, vld_batch_size=250,
135 | num_workers=4, valid_size=None, is_test=True, log_dir='.tmp/eval', measure_latency=None, no_logs=False,
136 | reset_running_statistics=True):
137 |
138 | lut = {'cpu': 'data/i7-8700K_lut.yaml'}
139 |
140 | info = get_net_info(
141 | subnet, (3, resolution, resolution), measure_latency=measure_latency,
142 | print_info=False, clean=True, lut=lut)
143 |
144 | run_config = get_run_config(
145 | dataset=dataset, data_path=data_path, image_size=resolution, n_epochs=n_epochs,
146 | train_batch_size=trn_batch_size, test_batch_size=vld_batch_size,
147 | n_worker=num_workers, valid_size=valid_size)
148 |
149 | # set the image size. You can set any image size from 192 to 256 here
150 | run_config.data_provider.assign_active_img_size(resolution)
151 |
152 | if n_epochs > 0:
153 | # for datasets other than the one supernet was trained on (ImageNet)
154 | # a few epochs of training need to be applied
155 | subnet.reset_classifier(
156 | last_channel=subnet.classifier.in_features,
157 |                 n_classes=run_config.data_provider.n_classes, dropout_rate=cfgs.drop_rate)  # cfgs: argparse namespace from __main__
158 |
159 | run_manager = RunManager(log_dir, subnet, run_config, init=False)
160 | if reset_running_statistics:
161 | # run_manager.reset_running_statistics(net=subnet, batch_size=vld_batch_size)
162 | run_manager.reset_running_statistics(net=subnet)
163 |
164 | if n_epochs > 0:
165 | subnet = run_manager.train(cfgs)
166 |
167 | loss, top1, top5 = run_manager.validate(net=subnet, is_test=is_test, no_logs=no_logs)
168 |
169 | info['loss'], info['top1'], info['top5'] = loss, top1, top5
170 |
171 | save_path = os.path.join(log_dir, 'net.stats') if cfgs.save is None else cfgs.save
172 | if cfgs.save_config:
173 | OFAEvaluator.save_net_config(log_dir, subnet, "net.config")
174 | OFAEvaluator.save_net(log_dir, subnet, "net.init")
175 | with open(save_path, 'w') as handle:
176 | json.dump(info, handle)
177 |
178 | print(info)
179 |
180 |
181 | def main(args):
182 | """ one evaluation of a subnet or a config from a file """
183 | mode = 'subnet'
184 | if args.config is not None:
185 | if args.init is not None:
186 | mode = 'config'
187 |
188 | print('Evaluation mode: {}'.format(mode))
189 | if mode == 'config':
190 | net_config = json.load(open(args.config))
191 | subnet = NSGANetV2.build_from_config(net_config, drop_connect_rate=args.drop_connect_rate)
192 | init = torch.load(args.init, map_location='cpu')['state_dict']
193 | subnet.load_state_dict(init)
194 | subnet.classifier.dropout_rate = args.drop_rate
195 | try:
196 | resolution = net_config['resolution']
197 | except KeyError:
198 | resolution = args.resolution
199 |
200 | elif mode == 'subnet':
201 | config = json.load(open(args.subnet))
202 | evaluator = OFAEvaluator(n_classes=args.n_classes, model_path=args.supernet_path)
203 | subnet, _ = evaluator.sample({'ks': config['ks'], 'e': config['e'], 'd': config['d']})
204 | resolution = config['r']
205 |
206 | else:
207 | raise NotImplementedError
208 |
209 | OFAEvaluator.eval(
210 | subnet, log_dir=args.log_dir, data_path=args.data, dataset=args.dataset, n_epochs=args.n_epochs,
211 | resolution=resolution, trn_batch_size=args.trn_batch_size, vld_batch_size=args.vld_batch_size,
212 | num_workers=args.num_workers, valid_size=args.valid_size, is_test=args.test, measure_latency=args.latency,
213 | no_logs=(not args.verbose), reset_running_statistics=args.reset_running_statistics)
214 |
215 |
216 | if __name__ == '__main__':
217 | parser = argparse.ArgumentParser()
218 | parser.add_argument('--data', type=str, default='/mnt/datastore/ILSVRC2012',
219 | help='location of the data corpus')
220 | parser.add_argument('--log_dir', type=str, default='.tmp',
221 | help='directory for logging')
222 | parser.add_argument('--dataset', type=str, default='imagenet',
223 | help='name of the dataset (imagenet, cifar10, cifar100, ...)')
224 | parser.add_argument('--n_classes', type=int, default=1000,
225 | help='number of classes for the given dataset')
226 | parser.add_argument('--supernet_path', type=str, default='./data/ofa_mbv3_d234_e346_k357_w1.0',
227 | help='file path to supernet weights')
228 | parser.add_argument('--subnet', type=str, default=None,
229 |                         help='location of a json file specifying ks, e, d, and r')
230 | parser.add_argument('--config', type=str, default=None,
231 | help='location of a json file of specific model declaration')
232 | parser.add_argument('--init', type=str, default=None,
233 | help='location of initial weight to load')
234 |     parser.add_argument('--trn_batch_size', type=int, default=128,
235 |                         help='training batch size')
236 |     parser.add_argument('--vld_batch_size', type=int, default=256,
237 |                         help='validation/test batch size for inference')
238 | parser.add_argument('--num_workers', type=int, default=6,
239 | help='number of workers for data loading')
240 | parser.add_argument('--n_epochs', type=int, default=0,
241 | help='number of training epochs')
242 | parser.add_argument('--save', type=str, default=None,
243 | help='location to save the evaluated metrics')
244 | parser.add_argument('--resolution', type=int, default=224,
245 | help='input resolution (192 -> 256)')
246 | parser.add_argument('--valid_size', type=int, default=None,
247 | help='validation set size, randomly sampled from training set')
248 | parser.add_argument('--test', action='store_true', default=False,
249 | help='evaluation performance on testing set')
250 | parser.add_argument('--latency', type=str, default=None,
251 | help='latency measurement settings (gpu64#cpu)')
252 | parser.add_argument('--verbose', action='store_true', default=False,
253 | help='whether to display evaluation progress')
254 | parser.add_argument('--reset_running_statistics', action='store_true', default=False,
255 | help='reset the running mean / std of BN')
256 | parser.add_argument('--drop_rate', type=float, default=0.2,
257 | help='dropout rate')
258 | parser.add_argument('--drop_connect_rate', type=float, default=0.0,
259 | help='connection dropout rate')
260 | parser.add_argument('--save_config', action='store_true', default=False,
261 | help='save config file')
262 | cfgs = parser.parse_args()
263 |
264 | cfgs.teacher_model = None
265 |
266 | main(cfgs)
267 |
268 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NSGANetV2: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search [[slides]](https://www.zhichaolu.com/assets/nsganetv2/0343-long.pdf)[[arXiv]](https://arxiv.org/abs/2007.10396)
2 | ```BibTex
3 | @inproceedings{
4 | lu2020nsganetv2,
5 | title={{NSGANetV2}: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search},
6 | author={Zhichao Lu and Kalyanmoy Deb and Erik Goodman and Wolfgang Banzhaf and Vishnu Naresh Boddeti},
7 | booktitle={European Conference on Computer Vision (ECCV)},
8 | year={2020}
9 | }
10 | ```
11 |
12 | ## Overview
13 |
14 | NSGANetV2 is an efficient NAS algorithm for generating task-specific models that are competitive under multiple competing objectives. It comprises two surrogates: one at the architecture level to improve sample efficiency, and one at the weights level (through a supernet) to improve the efficiency of gradient-descent training.
15 |
16 | ## Datasets
17 | Download the datasets from the links embedded in the names. Datasets with * can be automatically downloaded.
18 |
19 | | Dataset | Type | Train Size | Test Size | #Classes |
20 | |:-:|:-:|:-:|:-:|:-:|
21 | | [ImageNet](http://www.image-net.org/) | multi-class | 1,281,167 | 50,000 | 1,000 |
22 | | [CINIC-10](https://github.com/BayesWatch/cinic-10) | | 180,000 | 90,000 | 10 |
23 | | [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* | | 50,000 | 10,000 | 10 |
24 | | [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html)* | | 50,000 | 10,000 | 100 |
25 | | [STL-10](https://ai.stanford.edu/~acoates/stl10/)* | | 5,000 | 8,000 | 10 |
26 | | [FGVC Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)* | fine-grained | 6,667 | 3,333 | 100 |
27 | | [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) | | 3,760 | 1,880 | 47 |
28 | | [Oxford-IIIT Pets](https://www.robots.ox.ac.uk/~vgg/data/pets/) | | 3,680 | 3,669 | 37 |
29 | | [Oxford Flowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) | | 2,040 | 6,149 | 102 |
30 |
31 | ## How to evaluate NSGANetV2 models
32 | Download the models (`net.config`) and weights (`net.init`) from [[Google Drive]](https://drive.google.com/drive/folders/1owwmRNYQc8hIKOOFCYFCl2dukF5y1o-d?usp=sharing) or [[Baidu Yun]](https://pan.baidu.com/s/126FysPlOVTtDb6GQgx7eaA)(提取码:4isq).
33 | ```python
34 | """ NSGANetV2 pretrained models
35 | Syntax: python validation.py \
36 | --dataset [imagenet/cifar10/...] --data /path/to/data \
37 | --model /path/to/model/config/file --pretrained /path/to/model/weights
38 | """
39 | ```
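For example, a minimal invocation might look like the following (the dataset location and the downloaded `net.config` / `net.init` paths are placeholders):
```python
""" Example (hypothetical paths):
python validation.py \
  --dataset cifar10 --data ~/datasets/cifar10 \
  --model ./pretrained/net.config --pretrained ./pretrained/net.init
"""
```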
40 | ImageNet | CIFAR-10 | CINIC10
41 | :-------------------------:|:-------------------------:|:-------------------------:
42 |  |  | 
43 | FLOPs@225: [[Google Drive]](https://drive.google.com/drive/folders/1jNvM_JgWBcnIR8mLDmW_0Q6H_EMlZ_97?usp=sharing)
FLOPs@312: [[Google Drive]](https://drive.google.com/drive/folders/1w2bRfAGCcDWjVzIPwn9Olqa8sbC5UpjZ?usp=sharing)
FLOPs@400: [[Google Drive]](https://drive.google.com/drive/folders/1vwOUzO6iNfyVNi3hgWSx4TANebrCkhri?usp=sharing)
FLOPs@593: [[Google Drive]](https://drive.google.com/drive/folders/1r6jk3ausee2PqHsOqessnmmS_3RZYvhY?usp=sharing) | FLOPs@232: [[Google Drive]](https://drive.google.com/drive/folders/1-bdVtZdzTCTPWYS4B-03qaoQycYs266Z?usp=sharing)
FLOPs@291: [[Google Drive]](https://drive.google.com/drive/folders/1SFuM5zSeqJjSxrNxLEzmixzaU4MvM19t?usp=sharing)
FLOPs@392: [[Google Drive]](https://drive.google.com/drive/folders/1sHrTlAB6fQVnAzOIOLUhKWcOQNusTfot?usp=sharing)
FLOPs@468: [[Google Drive]](https://drive.google.com/drive/folders/1F5Ar8h3DbpV1WmDfhrtPKmpuQYLFVXq-?usp=sharing) | FLOPs@317: [[Google Drive]](https://drive.google.com/drive/folders/1J_vOo1ym9LlATtAQXSehCQt2aJx2Hfm-?usp=sharing)
FLOPs@411: [[Google Drive]](https://drive.google.com/drive/folders/1tzjSn_wF9AwiArV7rusCOylu6IxZnCXq?usp=sharing)
FLOPs@501: [[Google Drive]](https://drive.google.com/drive/folders/1_u7ZCqMr_3HeOQJfbsq-5oW93ZqNpNpS?usp=sharing)
FLOPs@710: [[Google Drive]](https://drive.google.com/drive/folders/1BBnPx-nEPtGNMeixSYkGb2omW-QHO6o_?usp=sharing)
44 |
45 | Flowers102 | Aircraft | Oxford-IIIT Pets
46 | :-------------------------:|:-------------------------:|:-------------------------:
47 |  |  | 
48 | FLOPs@151: [[Google Drive]](https://drive.google.com/drive/folders/12yaZcl_wEYyUkRbLsrTotYtFjlwjIY7g?usp=sharing)
FLOPs@218: [[Google Drive]](https://drive.google.com/drive/folders/1i2nWhz0rMeRSLjZ-JcWylssdEQCX_EfT?usp=sharing)
FLOPs@249: [[Google Drive]](https://drive.google.com/drive/folders/1maqXm9yRv69tbElGKqHbZVIZXuSQ0uRy?usp=sharing)
FLOPs@317: [[Google Drive]](https://drive.google.com/drive/folders/14HK-zEYKr5sySmKbFRbEaQ66IX88e4uC?usp=sharing) | FLOPs@176: [[Google Drive]](https://drive.google.com/drive/folders/1q8AcI-zU_z3PIMjbqUEgszvAwL1p-n-5?usp=sharing)
FLOPs@271: [[Google Drive]](https://drive.google.com/drive/folders/1VIeOJXaVk6NXrNkGkcShIKP_fEfaZv3Y?usp=sharing)
FLOPs@331: [[Google Drive]](https://drive.google.com/drive/folders/1QQ4Lnpoe2obCnf0r6O5yfF3Rh4llZmTg?usp=sharing)
FLOPs@502: [[Google Drive]](https://drive.google.com/drive/folders/1yM4nmZMSz7EcB3xHX4QxKV7Ifcpk_KRS?usp=sharing) | FLOPs@137: [[Google Drive]](https://drive.google.com/drive/folders/1TWNkSBUQTrpq8IU6eRTHVoVtBtdjThoQ?usp=sharing)
FLOPs@189: [[Google Drive]](https://drive.google.com/drive/folders/1Xu-Dh6rDBS608a2p0ebti6mBs33GUSsc?usp=sharing)
FLOPs@284: [[Google Drive]](https://drive.google.com/drive/folders/1sNV4FquZifPw-VonjWmogS3LojmwPPd1?usp=sharing)
FLOPs@391: [[Google Drive]](https://drive.google.com/drive/folders/1E5PYuI-kb7Gr308kAzIwpo9fzFO4TXVQ?usp=sharing)
49 |
50 | CIFAR-100 | DTD | STL-10
51 | :-------------------------:|:-------------------------:|:-------------------------:
52 |  |  | 
53 | FLOPs@261: [[Google Drive]](https://drive.google.com/drive/folders/1RlnOUhKKaexpwG2pSWiOGSZVGedbf9jd?usp=sharing)
FLOPs@398: [[Google Drive]](https://drive.google.com/drive/folders/1aqbQZvX-Zr6OgNi1-k_x-YzWBnzZyL92?usp=sharing)
FLOPs@492: [[Google Drive]](https://drive.google.com/drive/folders/1aqbQZvX-Zr6OgNi1-k_x-YzWBnzZyL92?usp=sharing)
FLOPs@796: [[Google Drive]](https://drive.google.com/drive/folders/1PJE6rtoJoKChXhw40PHLrT6tjy276ohJ?usp=sharing) | FLOPs@123: [[Google Drive]](https://drive.google.com/drive/folders/1isUPK0pHjVmXqDUqXgR3JOvP3zmq-PgJ?usp=sharing)
FLOPs@164: [[Google Drive]](https://drive.google.com/drive/folders/1pRGHy-bbw9Gz2g2QHau4z7dNFOdKT_w2?usp=sharing)
FLOPs@202: [[Google Drive]](https://drive.google.com/drive/folders/1CsGdcIv79rTDQbXSEJEgTOYZ6ViRwz1S?usp=sharing)
FLOPs@213: [[Google Drive]](https://drive.google.com/drive/folders/1MJVapbO7V5fT5Vjf1G0wjkgCOgkLKWEa?usp=sharing) | FLOPs@240: [[Google Drive]](https://drive.google.com/drive/folders/1tHOiEJhOpplOPnNqaDkHGPGfD5YRZrsJ?usp=sharing)
FLOPs@303: [[Google Drive]](https://drive.google.com/drive/folders/1IfFeg0CkPEFNN3iseOyeR1u-sozgSaya?usp=sharing)
FLOPs@436: [[Google Drive]](https://drive.google.com/drive/folders/13naG888dNouXj8Bd-gL4pMzrskDwdBC8?usp=sharing)
FLOPs@573: [[Google Drive]](https://drive.google.com/drive/folders/1iig97a1Xr-K40xbgwgBvMOPe8aLaqW2a?usp=sharing)
54 |
55 | ## How to use MSuNAS to search
56 | ```python
57 | """ Bi-objective search
58 | Syntax: python msunas.py \
59 | --dataset [imagenet/cifar10/...] --data /path/to/dataset/images \
60 | --save search-xxx \ # dir to save search results
61 | --sec_obj [params/flops/cpu] \ # objective (in addition to top-1 acc)
62 | --n_gpus 8 \ # number of available gpus
63 | --supernet_path /path/to/supernet/weights \
64 | --vld_size [10000/5000/...] \ # number of subset images from training set to guide search
65 | --n_epochs [0/5]
66 | """
67 | ```
68 | - Download the pre-trained (on ImageNet) supernet from [here](https://hanlab.mit.edu/files/OnceForAll/ofa_nets/).
69 | - It supports searching for *FLOPs*, *Params*, and *Latency* as the second objective.
70 | - To optimize latency on your own device, first construct a latency `look-up-table` for that device, like [this](https://github.com/mikelzc1990/nsganetv2/blob/master/data/i7-8700K_lut.yaml); a sketch of the key format it uses appears at the end of this section.
71 | - Choose an appropriate `--vld_size` to guide the search, e.g. 10,000 for ImageNet, 5,000 for CIFAR-10/100.
72 | - Set `--n_epochs` to `0` for ImageNet and `5` for all other datasets.
73 | - See [here](https://github.com/mikelzc1990/nsganetv2/blob/master/scripts/search.sh) for some examples.
74 | - Output file structure:
75 | - Every architecture sampled during search has `net_x.subnet` and `net_x.stats` stored in the corresponding iteration dir.
76 |   - A stats file, `iter_x.stats`, is generated at the end of each iteration; it stores every architecture evaluated so far under `["archive"]`, along with iteration-wise statistics, e.g. the hypervolume under `["hv"]` and accuracy-predictor metrics under `["surrogate"]` (see the loading sketch after this list).
77 |   - In case any architecture fails to evaluate during search, you can revisit it in the `failed` sub-directory under the experiment directory.
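
As a minimal sketch (not a script shipped with the repo), an `iter_x.stats` file can be inspected with a few lines of Python; the file path below is a placeholder and only the keys described above are assumed:
```python
# Sketch: inspect a MSuNAS iteration stats file (path is a placeholder).
import json

with open('search-imagenet/iter_30.stats') as fh:
    stats = json.load(fh)

print('architectures evaluated so far:', len(stats["archive"]))
print('hypervolume:', stats["hv"])                # iteration-wise hypervolume
print('accuracy predictor:', stats["surrogate"])  # surrogate-related statistics
```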
78 |
79 | ImageNet | CIFAR-10
80 | :-------------------------:|:-------------------------:
81 |  | 
82 |
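As referenced in the latency bullet above, the look-up-table consumed by `LatencyEstimator` in `utils.py` is a YAML mapping whose keys are built as `<ltype>-input:<HxWxC>-output:<HxWxC>`, with `-expand:..-kernel:..-stride:..-idskip:..-se:..` appended for `MBConv` entries, and whose values carry a `mean` latency. A hedged sketch of that key format in Python (the channel counts and latency numbers below are made up for illustration):
```python
# Sketch: how utils.LatencyEstimator forms look-up keys; values are illustrative only.
lut = {
    'first_conv-input:224x224x3-output:112x112x16': {'mean': 1.2},
    'MBConv-input:112x112x16-output:56x56x24-expand:3-kernel:5-stride:2-idskip:0-se:1': {'mean': 3.4},
}
key = '-'.join(['first_conv', 'input:224x224x3', 'output:112x112x16'])
print(lut[key]['mean'])  # LatencyEstimator.predict returns the 'mean' entry for the matching key
```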
83 | ## How to choose architectures
84 | Once the search is completed, you can choose suitable architectures as follows:
85 | - If you have preferences, e.g. architectures with xx.x% top-1 acc. and xxxM FLOPs:
86 | ```python
87 | """ Find architectures with objectives close to your preferences
88 | Syntax: python post_search.py \
89 | -n 3 \ # number of architectures desired; the most accurate architecture is always selected
90 | --save search-imagenet/final \ # path to the dir to store the selected architectures
91 | --expr search-imagenet/iter_30.stats \ # path to last iteration stats file in experiment dir
92 | --prefer top1#80+flops#150 \ # your preferences, i.e. you want an architecture with 80% top-1 acc. and 150M FLOPs
93 | --supernet_path /path/to/imagenet/supernet/weights \
94 | """
95 | ```
96 | - If you do not have preferences, pass `None` to the `--prefer` argument; architectures will then be selected based on trade-offs.
97 | - All selected architectures should have three files created:
98 |   - `net.subnet`: used to sample the architecture from the supernet (see the sketch after this list)
99 | - `net.config`: configuration file that defines the full architectural components
100 | - `net.inherited`: the inherited weights from supernet
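
As a minimal sketch (paths are placeholders), a selected `net.subnet` can be instantiated from the supernet the same way `evaluator.py` does it:
```python
# Sketch: sample a selected architecture from the OFA supernet (paths are placeholders).
import json
from evaluator import OFAEvaluator

config = json.load(open('search-imagenet/final/net.subnet'))
evaluator = OFAEvaluator(n_classes=1000, model_path='./data/ofa_mbv3_d234_e346_k357_w1.0')
subnet, _ = evaluator.sample({'ks': config['ks'], 'e': config['e'], 'd': config['d']})
print(subnet)  # a torch.nn.Module that can be fine-tuned or loaded with the `net.inherited` weights
```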
101 |
102 | ## How to validate architectures
103 | To realize the full potential of the searched architectures, we further fine-tune them from the inherited weights. The commands below assume that you have both the `net.config` and `net.inherited` files.
104 | ```python
105 | """ Fine-tune on ImageNet from inherited weights
106 | Syntax: sh scripts/distributed_train.sh 8 \ # of available gpus
107 | /path/to/imagenet/data/ \
108 | --model [nsganetv2_s/nsganetv2_m/...] \ # just for naming the output dir
109 | --model-config /path/to/model/.config/file \
110 | --initial-checkpoint /path/to/model/.inherited/file \
111 | --img-size [192, ..., 224, ..., 256] \ # image resolution, check "r" in net.subnet
112 | -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 \
113 | --opt rmsproptf --opt-eps .001 -j 6 --warmup-lr 1e-6 \
114 | --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 \
115 | --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .024 \
116 | --teacher /path/to/supernet/weights \
117 | """
118 | ```
119 | - Adjust the learning rate as `(batch_size_per_gpu * #GPUs / 256) * 0.006` depending on your system configuration; e.g. the command above uses 8 GPUs with `-b 128`, giving `(128 * 8 / 256) * 0.006 = 0.024`, which matches `--lr .024`.
120 | ```python
121 | """ Fine-tune on CIFAR-10 from inherited weights
122 | Syntax: python train_cifar.py \
123 | --data /path/to/CIFAR-10/data/ \
124 | --model [nsganetv2_s/nsganetv2_m/...] \ # just for naming the output dir
125 | --model-config /path/to/model/.config/file \
126 | --img-size [192, ..., 224, ..., 256] \ # image resolution, check "r" in net.subnet
127 | --drop 0.2 --drop-path 0.2 \
128 | --cutout --autoaugment --save
129 | """
130 | ```
131 |
132 | ## More Use Cases (coming soon)
133 | - [ ] With a different supernet (search space).
134 | - [ ] NASBench 101/201.
135 | - [ ] Architecture visualization.
136 |
137 | ## Requirements
138 | - Python 3.7
139 | - Cython 0.29 (optional; makes `non_dominated_sorting` faster in pymoo)
140 | - PyTorch 1.5.1
141 | - [pymoo](https://github.com/msu-coinlab/pymoo) 0.4.1
142 | - [torchprofile](https://github.com/zhijian-liu/torchprofile) 0.0.1 (for FLOPs calculation)
143 | - [OnceForAll](https://github.com/mit-han-lab/once-for-all) 0.0.4 (lower level supernet)
144 | - [timm](https://github.com/rwightman/pytorch-image-models) 0.1.30
145 | - [pySOT](https://github.com/dme65/pySOT) 0.2.3 (RBF surrogate model)
146 | - [pydacefit](https://github.com/msu-coinlab/pydacefit) 1.0.1 (GP surrogate model)
147 |
--------------------------------------------------------------------------------
/train_cifar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import copy
5 | import logging
6 | import argparse
7 | import numpy as np
8 | from datetime import datetime
9 |
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torchvision.utils
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 | import torchvision.transforms as transforms
17 |
18 | from codebase.data_providers.autoaugment import CIFAR10Policy
19 |
20 | from evaluator import OFAEvaluator
21 | from torchprofile import profile_macs
22 | from codebase.networks import NSGANetV2
23 |
24 |
25 | torch.backends.cudnn.benchmark = True
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
28 | parser.add_argument('--seed', type=int, default=None, help='random seed')
29 | parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
30 | parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, cifar100, or cinic10')
31 | parser.add_argument('--batch-size', type=int, default=96, help='batch size')
32 | parser.add_argument('--num_workers', type=int, default=2, help='number of workers for data loading')
33 | parser.add_argument('--n_gpus', type=int, default=1, help='number of available gpus for training')
34 | parser.add_argument('--lr', type=float, default=0.01, help='init learning rate')
35 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
36 | parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay')
37 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
38 | parser.add_argument('--epochs', type=int, default=150, help='num of training epochs')
39 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
40 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
41 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
42 | parser.add_argument('--autoaugment', action='store_true', default=False, help='use auto augmentation')
43 | parser.add_argument('--save', action='store_true', default=False, help='dump output')
44 | parser.add_argument('--topk', type=int, default=10, help='top k checkpoints to save')
45 | parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate a pretrained model')
46 | # model related
47 | parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
48 |                     help='name of the model, used for naming the output dir (default: "resnet101")')
49 | parser.add_argument('--model-config', type=str, default=None,
50 | help='location of a json file of specific model declaration')
51 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
52 | help='Initialize model from this checkpoint (default: none)')
53 | parser.add_argument('--drop', type=float, default=0.2,
54 | help='dropout rate')
55 | parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
56 |                     help='drop path rate (default: 0.2)')
57 | parser.add_argument('--img-size', type=int, default=224,
58 | help='input resolution (192 -> 256)')
59 | args = parser.parse_args()
60 |
61 | dataset = args.dataset
62 |
63 | log_format = '%(asctime)s %(message)s'
64 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
65 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
66 |
67 | if args.save:
68 | args.save = '-'.join([
69 | datetime.now().strftime("%Y%m%d-%H%M%S"),
70 | args.dataset,
71 | args.model,
72 | str(args.img_size)
73 | ])
74 |
75 | if not os.path.exists(args.save):
76 | os.makedirs(args.save, exist_ok=True)
77 | print('Experiment dir : {}'.format(args.save))
78 |
79 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
80 | fh.setFormatter(logging.Formatter(log_format))
81 | logging.getLogger().addHandler(fh)
82 |
83 | device = 'cuda'
84 |
85 | NUM_CLASSES = 100 if 'cifar100' in dataset else 10
86 |
87 |
88 | def main():
89 | if not torch.cuda.is_available():
90 | logging.info('no gpu device available')
91 | sys.exit(1)
92 |
93 | logging.info("args = %s", args)
94 |
95 | if args.seed is not None:
96 | np.random.seed(args.seed)
97 | torch.manual_seed(args.seed)
98 |
99 |     best_acc = 0  # best validation accuracy seen so far
100 |     top_checkpoints = []  # (checkpoint path, accuracy) tuples for the top-k checkpoints kept so far
101 |
102 | # Data
103 | train_transform, valid_transform = _data_transforms(args)
104 | if dataset == 'cifar100':
105 | train_data = torchvision.datasets.CIFAR100(
106 | root=args.data, train=True, download=True, transform=train_transform)
107 | valid_data = torchvision.datasets.CIFAR100(
108 | root=args.data, train=False, download=True, transform=valid_transform)
109 | elif dataset == 'cifar10':
110 | train_data = torchvision.datasets.CIFAR10(
111 | root=args.data, train=True, download=True, transform=train_transform)
112 | valid_data = torchvision.datasets.CIFAR10(
113 | root=args.data, train=False, download=True, transform=valid_transform)
114 | elif dataset == 'cinic10':
115 | train_data = torchvision.datasets.ImageFolder(
116 |             os.path.join(args.data, 'train_and_valid'), transform=train_transform)
117 | valid_data = torchvision.datasets.ImageFolder(
118 |             os.path.join(args.data, 'test'), transform=valid_transform)
119 | else:
120 | raise KeyError
121 |
122 | train_queue = torch.utils.data.DataLoader(
123 | train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers)
124 |
125 | valid_queue = torch.utils.data.DataLoader(
126 | valid_data, batch_size=200, shuffle=False, pin_memory=True, num_workers=args.num_workers)
127 |
128 | net_config = json.load(open(args.model_config))
129 | net = NSGANetV2.build_from_config(net_config, drop_connect_rate=args.drop_path)
130 | init = torch.load(args.initial_checkpoint, map_location='cpu')['state_dict']
131 | net.load_state_dict(init)
132 |
133 | NSGANetV2.reset_classifier(
134 | net, last_channel=net.classifier.in_features,
135 | n_classes=NUM_CLASSES, dropout_rate=args.drop)
136 |
137 |     # calculate #Parameters and #FLOPs
138 | inputs = torch.randn(1, 3, args.img_size, args.img_size)
139 | flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
140 | params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
141 | net_name = "net_flops@{:.0f}".format(flops)
142 | logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
143 |
144 | if args.n_gpus > 1:
145 | net = nn.DataParallel(net) # data parallel in case more than 1 gpu available
146 |
147 | net = net.to(device)
148 |
149 | n_epochs = args.epochs
150 |
151 | parameters = filter(lambda p: p.requires_grad, net.parameters())
152 |
153 | criterion = nn.CrossEntropyLoss().to(device)
154 |
155 | optimizer = optim.SGD(parameters,
156 | lr=args.lr,
157 | momentum=args.momentum,
158 | weight_decay=args.weight_decay)
159 |
160 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
161 |
162 | if args.evaluate:
163 | infer(valid_queue, net, criterion)
164 | sys.exit(0)
165 |
166 | for epoch in range(n_epochs):
167 |
168 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
169 |
170 | train(train_queue, net, criterion, optimizer)
171 | _, valid_acc = infer(valid_queue, net, criterion)
172 |
173 | # checkpoint saving
174 | if args.save:
175 | if len(top_checkpoints) < args.topk:
176 | OFAEvaluator.save_net(args.save, net, net_name+'.ckpt{}'.format(epoch))
177 | top_checkpoints.append((os.path.join(args.save, net_name+'.ckpt{}'.format(epoch)), valid_acc))
178 | else:
179 | idx = np.argmin([x[1] for x in top_checkpoints])
180 | if valid_acc > top_checkpoints[idx][1]:
181 | OFAEvaluator.save_net(args.save, net, net_name + '.ckpt{}'.format(epoch))
182 | top_checkpoints.append((os.path.join(args.save, net_name+'.ckpt{}'.format(epoch)), valid_acc))
183 | # remove the idx
184 | os.remove(top_checkpoints[idx][0])
185 | top_checkpoints.pop(idx)
186 | print(top_checkpoints)
187 |
188 | if valid_acc > best_acc:
189 | OFAEvaluator.save_net(args.save, net, net_name + '.best')
190 | best_acc = valid_acc
191 |
192 | scheduler.step()
193 |
194 | OFAEvaluator.save_net_config(args.save, net, net_name+'.config')
195 |
196 |
197 | # Training
198 | def train(train_queue, net, criterion, optimizer):
199 | net.train()
200 | train_loss = 0
201 | correct = 0
202 | total = 0
203 |
204 | for step, (inputs, targets) in enumerate(train_queue):
205 | # upsample by bicubic to match imagenet training size
206 | inputs = F.interpolate(inputs, size=args.img_size, mode='bicubic', align_corners=False)
207 | inputs, targets = inputs.to(device), targets.to(device)
208 | optimizer.zero_grad()
209 | outputs = net(inputs)
210 | loss = criterion(outputs, targets)
211 |
212 | loss.backward()
213 | nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
214 | optimizer.step()
215 |
216 | train_loss += loss.item()
217 | _, predicted = outputs.max(1)
218 | total += targets.size(0)
219 | correct += predicted.eq(targets).sum().item()
220 |
221 | if step % args.report_freq == 0:
222 | logging.info('train %03d %e %f', step, train_loss/total, 100.*correct/total)
223 |
224 | logging.info('train acc %f', 100. * correct / total)
225 |
226 | return train_loss/total, 100.*correct/total
227 |
228 |
229 | def infer(valid_queue, net, criterion):
230 | net.eval()
231 | test_loss = 0
232 | correct = 0
233 | total = 0
234 |
235 | with torch.no_grad():
236 | for step, (inputs, targets) in enumerate(valid_queue):
237 | inputs, targets = inputs.to(device), targets.to(device)
238 | outputs = net(inputs)
239 | loss = criterion(outputs, targets)
240 |
241 | test_loss += loss.item()
242 | _, predicted = outputs.max(1)
243 | total += targets.size(0)
244 | correct += predicted.eq(targets).sum().item()
245 |
246 | if step % args.report_freq == 0:
247 | logging.info('valid %03d %e %f', step, test_loss/total, 100.*correct/total)
248 |
249 | acc = 100.*correct/total
250 | logging.info('valid acc %f', 100. * correct / total)
251 |
252 | return test_loss/total, acc
253 |
254 |
255 | class Cutout(object):
256 | def __init__(self, length):
257 | self.length = length
258 |
259 | def __call__(self, img):
260 | h, w = img.size(1), img.size(2)
261 | mask = np.ones((h, w), np.float32)
262 | y = np.random.randint(h)
263 | x = np.random.randint(w)
264 |
265 | y1 = np.clip(y - self.length // 2, 0, h)
266 | y2 = np.clip(y + self.length // 2, 0, h)
267 | x1 = np.clip(x - self.length // 2, 0, w)
268 | x2 = np.clip(x + self.length // 2, 0, w)
269 |
270 | mask[y1: y2, x1: x2] = 0.
271 | mask = torch.from_numpy(mask)
272 | mask = mask.expand_as(img)
273 | img *= mask
274 | return img
275 |
276 |
277 | def _data_transforms(args):
278 |
279 | if 'cifar' in args.dataset:
280 | norm_mean = [0.49139968, 0.48215827, 0.44653124]
281 | norm_std = [0.24703233, 0.24348505, 0.26158768]
282 | elif 'cinic' in args.dataset:
283 | norm_mean = [0.47889522, 0.47227842, 0.43047404]
284 | norm_std = [0.24205776, 0.23828046, 0.25874835]
285 | else:
286 | raise KeyError
287 |
288 | train_transform = transforms.Compose([
289 | transforms.RandomCrop(32, padding=4),
290 | # transforms.Resize(224, interpolation=3), # BICUBIC interpolation
291 | transforms.RandomHorizontalFlip(),
292 | ])
293 |
294 | if args.autoaugment:
295 | train_transform.transforms.append(CIFAR10Policy())
296 |
297 | train_transform.transforms.append(transforms.ToTensor())
298 |
299 | if args.cutout:
300 | train_transform.transforms.append(Cutout(args.cutout_length))
301 |
302 | train_transform.transforms.append(transforms.Normalize(norm_mean, norm_std))
303 |
304 | valid_transform = transforms.Compose([
305 | transforms.Resize(args.img_size, interpolation=3), # BICUBIC interpolation
306 | transforms.ToTensor(),
307 | transforms.Normalize(norm_mean, norm_std),
308 | ])
309 | return train_transform, valid_transform
310 |
311 |
312 | if __name__ == '__main__':
313 | main()
314 |
--------------------------------------------------------------------------------
/codebase/run_manager/__init__.py:
--------------------------------------------------------------------------------
1 | from codebase.data_providers.imagenet import *
2 | from codebase.data_providers.cifar import *
3 | from codebase.data_providers.flowers102 import *
4 | from codebase.data_providers.stl10 import *
5 | from codebase.data_providers.dtd import *
6 | from codebase.data_providers.pets import *
7 | from codebase.data_providers.aircraft import *
8 |
9 | from ofa.imagenet_codebase.run_manager import RunConfig  # consistent with the OFA 0.0.4 import paths used elsewhere in this repo
10 |
11 |
12 | class ImagenetRunConfig(RunConfig):
13 |
14 | def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None,
15 | dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None,
16 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
17 | mixup_alpha=None,
18 | model_init='he_fout', validation_frequency=1, print_frequency=10,
19 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
20 | data_path='/mnt/datastore/ILSVRC2012',
21 | **kwargs):
22 | super(ImagenetRunConfig, self).__init__(
23 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
24 | dataset, train_batch_size, test_batch_size, valid_size,
25 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
26 | mixup_alpha,
27 | model_init, validation_frequency, print_frequency
28 | )
29 | self.n_worker = n_worker
30 | self.resize_scale = resize_scale
31 | self.distort_color = distort_color
32 | self.image_size = image_size
33 | self.imagenet_data_path = data_path
34 |
35 | @property
36 | def data_provider(self):
37 | if self.__dict__.get('_data_provider', None) is None:
38 | if self.dataset == ImagenetDataProvider.name():
39 | DataProviderClass = ImagenetDataProvider
40 | else:
41 | raise NotImplementedError
42 | self.__dict__['_data_provider'] = DataProviderClass(
43 | save_path=self.imagenet_data_path,
44 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
45 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
46 | distort_color=self.distort_color, image_size=self.image_size,
47 | )
48 | return self.__dict__['_data_provider']
49 |
50 |
51 | class CIFARRunConfig(RunConfig):
52 | def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None,
53 | dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None,
54 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
55 | mixup_alpha=None,
56 | model_init='he_fout', validation_frequency=1, print_frequency=10,
57 | n_worker=2, resize_scale=0.08, distort_color=None, image_size=224,
58 | data_path='/mnt/datastore/CIFAR',
59 | **kwargs):
60 | super(CIFARRunConfig, self).__init__(
61 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
62 | dataset, train_batch_size, test_batch_size, valid_size,
63 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
64 | mixup_alpha,
65 | model_init, validation_frequency, print_frequency
66 | )
67 |
68 | self.n_worker = n_worker
69 | self.resize_scale = resize_scale
70 | self.distort_color = distort_color
71 | self.image_size = image_size
72 | self.cifar_data_path = data_path
73 |
74 | @property
75 | def data_provider(self):
76 | if self.__dict__.get('_data_provider', None) is None:
77 | if self.dataset == CIFAR10DataProvider.name():
78 | DataProviderClass = CIFAR10DataProvider
79 | elif self.dataset == CIFAR100DataProvider.name():
80 | DataProviderClass = CIFAR100DataProvider
81 | elif self.dataset == CINIC10DataProvider.name():
82 | DataProviderClass = CINIC10DataProvider
83 | else:
84 | raise NotImplementedError
85 | self.__dict__['_data_provider'] = DataProviderClass(
86 | save_path=self.cifar_data_path,
87 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
88 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
89 | distort_color=self.distort_color, image_size=self.image_size,
90 | )
91 | return self.__dict__['_data_provider']
92 |
93 |
94 | class Flowers102RunConfig(RunConfig):
95 |
96 | def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
97 | dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None,
98 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
99 | mixup_alpha=None,
100 | model_init='he_fout', validation_frequency=1, print_frequency=10,
101 | n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
102 | data_path='/mnt/datastore/Flowers102',
103 | **kwargs):
104 | super(Flowers102RunConfig, self).__init__(
105 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
106 | dataset, train_batch_size, test_batch_size, valid_size,
107 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
108 | mixup_alpha,
109 | model_init, validation_frequency, print_frequency
110 | )
111 |
112 | self.n_worker = n_worker
113 | self.resize_scale = resize_scale
114 | self.distort_color = distort_color
115 | self.image_size = image_size
116 | self.flowers102_data_path = data_path
117 |
118 | @property
119 | def data_provider(self):
120 | if self.__dict__.get('_data_provider', None) is None:
121 | if self.dataset == Flowers102DataProvider.name():
122 | DataProviderClass = Flowers102DataProvider
123 | else:
124 | raise NotImplementedError
125 | self.__dict__['_data_provider'] = DataProviderClass(
126 | save_path=self.flowers102_data_path,
127 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
128 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
129 | distort_color=self.distort_color, image_size=self.image_size,
130 | )
131 | return self.__dict__['_data_provider']
132 |
133 |
134 | class STL10RunConfig(RunConfig):
135 |
136 | def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
137 | dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None,
138 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
139 | mixup_alpha=None,
140 | model_init='he_fout', validation_frequency=1, print_frequency=10,
141 | n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
142 | data_path='/mnt/datastore/STL10',
143 | **kwargs):
144 | super(STL10RunConfig, self).__init__(
145 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
146 | dataset, train_batch_size, test_batch_size, valid_size,
147 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
148 | mixup_alpha,
149 | model_init, validation_frequency, print_frequency
150 | )
151 |
152 | self.n_worker = n_worker
153 | self.resize_scale = resize_scale
154 | self.distort_color = distort_color
155 | self.image_size = image_size
156 | self.stl10_data_path = data_path
157 |
158 | @property
159 | def data_provider(self):
160 | if self.__dict__.get('_data_provider', None) is None:
161 | if self.dataset == STL10DataProvider.name():
162 | DataProviderClass = STL10DataProvider
163 | else:
164 | raise NotImplementedError
165 | self.__dict__['_data_provider'] = DataProviderClass(
166 | save_path=self.stl10_data_path,
167 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
168 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
169 | distort_color=self.distort_color, image_size=self.image_size,
170 | )
171 | return self.__dict__['_data_provider']
172 |
173 |
174 | class DTDRunConfig(RunConfig):
175 |
176 | def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
177 | dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None,
178 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
179 | mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
180 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
181 | data_path='/mnt/datastore/dtd',
182 | **kwargs):
183 | super(DTDRunConfig, self).__init__(
184 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
185 | dataset, train_batch_size, test_batch_size, valid_size,
186 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
187 | mixup_alpha,
188 | model_init, validation_frequency, print_frequency
189 | )
190 | self.n_worker = n_worker
191 | self.resize_scale = resize_scale
192 | self.distort_color = distort_color
193 | self.image_size = image_size
194 | self.data_path = data_path
195 |
196 | @property
197 | def data_provider(self):
198 | if self.__dict__.get('_data_provider', None) is None:
199 | if self.dataset == DTDDataProvider.name():
200 | DataProviderClass = DTDDataProvider
201 | else:
202 | raise NotImplementedError
203 | self.__dict__['_data_provider'] = DataProviderClass(
204 | save_path=self.data_path,
205 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
206 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
207 | distort_color=self.distort_color, image_size=self.image_size,
208 | )
209 | return self.__dict__['_data_provider']
210 |
211 |
212 | class PetsRunConfig(RunConfig):
213 |
214 | def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
215 | dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None,
216 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
217 | mixup_alpha=None,
218 | model_init='he_fout', validation_frequency=1, print_frequency=10,
219 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
220 | data_path='/mnt/datastore/Oxford-IIITPets',
221 | **kwargs):
222 | super(PetsRunConfig, self).__init__(
223 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
224 | dataset, train_batch_size, test_batch_size, valid_size,
225 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
226 | mixup_alpha,
227 | model_init, validation_frequency, print_frequency
228 | )
229 | self.n_worker = n_worker
230 | self.resize_scale = resize_scale
231 | self.distort_color = distort_color
232 | self.image_size = image_size
233 |         self.pets_data_path = data_path
234 |
235 | @property
236 | def data_provider(self):
237 | if self.__dict__.get('_data_provider', None) is None:
238 | if self.dataset == OxfordIIITPetsDataProvider.name():
239 | DataProviderClass = OxfordIIITPetsDataProvider
240 | else:
241 | raise NotImplementedError
242 | self.__dict__['_data_provider'] = DataProviderClass(
243 |                 save_path=self.data_path,
244 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
245 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
246 | distort_color=self.distort_color, image_size=self.image_size,
247 | )
248 | return self.__dict__['_data_provider']
249 |
250 |
251 | class AircraftRunConfig(RunConfig):
252 |
253 | def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
254 | dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None,
255 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
256 | mixup_alpha=None,
257 | model_init='he_fout', validation_frequency=1, print_frequency=10,
258 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
259 | data_path='/mnt/datastore/Aircraft',
260 | **kwargs):
261 | super(AircraftRunConfig, self).__init__(
262 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
263 | dataset, train_batch_size, test_batch_size, valid_size,
264 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
265 | mixup_alpha,
266 | model_init, validation_frequency, print_frequency
267 | )
268 | self.n_worker = n_worker
269 | self.resize_scale = resize_scale
270 | self.distort_color = distort_color
271 | self.image_size = image_size
272 | self.data_path = data_path
273 |
274 | @property
275 | def data_provider(self):
276 | if self.__dict__.get('_data_provider', None) is None:
277 | if self.dataset == FGVCAircraftDataProvider.name():
278 | DataProviderClass = FGVCAircraftDataProvider
279 | else:
280 | raise NotImplementedError
281 | self.__dict__['_data_provider'] = DataProviderClass(
282 | save_path=self.data_path,
283 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
284 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
285 | distort_color=self.distort_color, image_size=self.image_size,
286 | )
287 | return self.__dict__['_data_provider']
288 |
289 |
290 | def get_run_config(**kwargs):
291 | if kwargs['dataset'] == 'imagenet':
292 | run_config = ImagenetRunConfig(**kwargs)
293 | elif kwargs['dataset'].startswith('cifar') or kwargs['dataset'].startswith('cinic'):
294 | run_config = CIFARRunConfig(**kwargs)
295 | elif kwargs['dataset'] == 'flowers102':
296 | run_config = Flowers102RunConfig(**kwargs)
297 | elif kwargs['dataset'] == 'stl10':
298 | run_config = STL10RunConfig(**kwargs)
299 | elif kwargs['dataset'] == 'dtd':
300 | run_config = DTDRunConfig(**kwargs)
301 | elif kwargs['dataset'] == 'pets':
302 | run_config = PetsRunConfig(**kwargs)
303 | elif kwargs['dataset'] == 'aircraft':
304 | run_config = AircraftRunConfig(**kwargs)
305 | else:
306 | raise NotImplementedError
307 |
308 | return run_config
309 |
310 |
311 |
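312 | # Usage sketch (illustrative only; the arguments mirror the RunConfig constructors above):
313 | #   run_config = get_run_config(dataset='aircraft', n_epochs=5, train_batch_size=32,
314 | #                               test_batch_size=200, n_worker=4, image_size=224)
315 | #   provider = run_config.data_provider  # lazily builds and caches the matching data provider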
--------------------------------------------------------------------------------
/msunas.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 | import argparse
5 | import subprocess
6 | import numpy as np
7 | from utils import get_correlation
8 | from evaluator import OFAEvaluator, get_net_info
9 |
10 | from pymoo.optimize import minimize
11 | from pymoo.model.problem import Problem
12 | from pymoo.factory import get_performance_indicator
13 | from pymoo.algorithms.so_genetic_algorithm import GA
14 | from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
15 | from pymoo.factory import get_algorithm, get_crossover, get_mutation
16 |
17 | from search_space.ofa import OFASearchSpace
18 | from acc_predictor.factory import get_acc_predictor
19 | from utils import prepare_eval_folder, MySampling, BinaryCrossover, MyMutation
20 |
21 | _DEBUG = False
22 | if _DEBUG: from pymoo.visualization.scatter import Scatter
23 |
24 |
25 | class MSuNAS:
26 | def __init__(self, kwargs):
27 | self.search_space = OFASearchSpace()
28 | self.save_path = kwargs.pop('save', '.tmp') # path to save results
29 | self.resume = kwargs.pop('resume', None) # resume search from a checkpoint
30 | self.sec_obj = kwargs.pop('sec_obj', 'flops') # second objective to optimize simultaneously
31 | self.iterations = kwargs.pop('iterations', 30) # number of iterations to run search
32 | self.n_doe = kwargs.pop('n_doe', 100) # number of architectures to train before fit surrogate model
33 | self.n_iter = kwargs.pop('n_iter', 8) # number of architectures to train in each iteration
34 | self.predictor = kwargs.pop('predictor', 'rbf') # which surrogate model to fit
35 | self.n_gpus = kwargs.pop('n_gpus', 1) # number of available gpus
36 | self.gpu = kwargs.pop('gpu', 1) # required number of gpus per evaluation job
37 | self.data = kwargs.pop('data', '../data') # location of the data files
38 | self.dataset = kwargs.pop('dataset', 'imagenet') # which dataset to run search on
39 | self.n_classes = kwargs.pop('n_classes', 1000) # number of classes of the given dataset
40 | self.n_workers = kwargs.pop('n_workers', 6) # number of threads for dataloader
41 | self.vld_size = kwargs.pop('vld_size', 10000) # number of images from train set to validate performance
42 | self.trn_batch_size = kwargs.pop('trn_batch_size', 96) # batch size for SGD training
43 | self.vld_batch_size = kwargs.pop('vld_batch_size', 250) # batch size for validation
44 |         self.n_epochs = kwargs.pop('n_epochs', 5)  # number of epochs for SGD training
45 | self.test = kwargs.pop('test', False) # evaluate performance on test set
46 | self.supernet_path = kwargs.pop(
47 | 'supernet_path', './data/ofa_mbv3_d234_e346_k357_w1.0') # supernet model path
48 | self.latency = self.sec_obj if "cpu" in self.sec_obj or "gpu" in self.sec_obj else None
49 |
50 | def search(self):
51 |
52 | if self.resume:
53 | archive = self._resume_from_dir()
54 | else:
55 |             # the following lines correspond to Algo 1 lines 1-7 in the paper
56 | archive = [] # initialize an empty archive to store all trained CNNs
57 |
58 | # Design Of Experiment
59 | if self.iterations < 1:
60 | arch_doe = self.search_space.sample(self.n_doe)
61 | else:
62 | arch_doe = self.search_space.initialize(self.n_doe)
63 |
64 | # parallel evaluation of arch_doe
65 | top1_err, complexity = self._evaluate(arch_doe, it=0)
66 |
67 | # store evaluated / trained architectures
68 | for member in zip(arch_doe, top1_err, complexity):
69 | archive.append(member)
70 |
71 | # reference point (nadir point) for calculating hypervolume
72 | ref_pt = np.array([np.max([x[1] for x in archive]), np.max([x[2] for x in archive])])
73 |
74 | # main loop of the search
75 | for it in range(1, self.iterations + 1):
76 |
77 | # construct accuracy predictor surrogate model from archive
78 | # Algo 1 line 9 / Fig. 3(a) in the paper
79 | acc_predictor, a_top1_err_pred = self._fit_acc_predictor(archive)
80 |
81 | # search for the next set of candidates for high-fidelity evaluation (lower level)
82 | # Algo 1 line 10-11 / Fig. 3(b)-(d) in the paper
83 | candidates, c_top1_err_pred = self._next(archive, acc_predictor, self.n_iter)
84 |
85 | # high-fidelity evaluation (lower level)
86 | # Algo 1 line 13-14 / Fig. 3(e) in the paper
87 | c_top1_err, complexity = self._evaluate(candidates, it=it)
88 |
89 |             # check the accuracy predictor's performance
90 | rmse, rho, tau = get_correlation(
91 | np.vstack((a_top1_err_pred, c_top1_err_pred)), np.array([x[1] for x in archive] + c_top1_err))
92 |
93 | # add to archive
94 | # Algo 1 line 15 / Fig. 3(e) in the paper
95 | for member in zip(candidates, c_top1_err, complexity):
96 | archive.append(member)
97 |
98 | # calculate hypervolume
99 | hv = self._calc_hv(
100 | ref_pt, np.column_stack(([x[1] for x in archive], [x[2] for x in archive])))
101 |
102 | # print iteration-wise statistics
103 | print("Iter {}: hv = {:.2f}".format(it, hv))
104 | print("fitting {}: RMSE = {:.4f}, Spearman's Rho = {:.4f}, Kendall’s Tau = {:.4f}".format(
105 | self.predictor, rmse, rho, tau))
106 |
107 | # dump the statistics
108 | with open(os.path.join(self.save_path, "iter_{}.stats".format(it)), "w") as handle:
109 | json.dump({'archive': archive, 'candidates': archive[-self.n_iter:], 'hv': hv,
110 | 'surrogate': {
111 | 'model': self.predictor, 'name': acc_predictor.name,
112 | 'winner': acc_predictor.winner if self.predictor == 'as' else acc_predictor.name,
113 | 'rmse': rmse, 'rho': rho, 'tau': tau}}, handle)
114 | if _DEBUG:
115 | # plot
116 | plot = Scatter(legend={'loc': 'lower right'})
117 | F = np.full((len(archive), 2), np.nan)
118 | F[:, 0] = np.array([x[2] for x in archive]) # second obj. (complexity)
119 | F[:, 1] = 100 - np.array([x[1] for x in archive]) # top-1 accuracy
120 | plot.add(F, s=15, facecolors='none', edgecolors='b', label='archive')
121 | F = np.full((len(candidates), 2), np.nan)
122 | F[:, 0] = np.array(complexity)
123 | F[:, 1] = 100 - np.array(c_top1_err)
124 | plot.add(F, s=30, color='r', label='candidates evaluated')
125 | F = np.full((len(candidates), 2), np.nan)
126 | F[:, 0] = np.array(complexity)
127 | F[:, 1] = 100 - c_top1_err_pred[:, 0]
128 | plot.add(F, s=20, facecolors='none', edgecolors='g', label='candidates predicted')
129 | plot.save(os.path.join(self.save_path, 'iter_{}.png'.format(it)))
130 |
131 | return
132 |
133 | def _resume_from_dir(self):
134 | """ resume search from a previous iteration """
135 | import glob
136 |
137 | archive = []
138 | for file in glob.glob(os.path.join(self.resume, "net_*.subnet")):
139 | arch = json.load(open(file))
140 | pre, ext = os.path.splitext(file)
141 | stats = json.load(open(pre + ".stats"))
142 | archive.append((arch, 100 - stats['top1'], stats[self.sec_obj]))
143 |
144 | return archive
145 |
146 | def _evaluate(self, archs, it):
147 | gen_dir = os.path.join(self.save_path, "iter_{}".format(it))
148 | prepare_eval_folder(
149 | gen_dir, archs, self.gpu, self.n_gpus, data=self.data, dataset=self.dataset,
150 | n_classes=self.n_classes, supernet_path=self.supernet_path,
151 | num_workers=self.n_workers, valid_size=self.vld_size,
152 | trn_batch_size=self.trn_batch_size, vld_batch_size=self.vld_batch_size,
153 | n_epochs=self.n_epochs, test=self.test, latency=self.latency, verbose=False)
154 |
155 | subprocess.call("sh {}/run_bash.sh".format(gen_dir), shell=True)
156 |
157 | top1_err, complexity = [], []
158 |
159 | for i in range(len(archs)):
160 | try:
161 | stats = json.load(open(os.path.join(gen_dir, "net_{}.stats".format(i))))
162 | except FileNotFoundError:
163 | # just in case the subprocess evaluation failed
164 | stats = {'top1': 0, self.sec_obj: 1000} # makes the solution artificially bad so it won't survive
165 |                 # store this architecture in a separate folder in case we want to revisit it after the search
166 | os.makedirs(os.path.join(self.save_path, "failed"), exist_ok=True)
167 | shutil.copy(os.path.join(gen_dir, "net_{}.subnet".format(i)),
168 | os.path.join(self.save_path, "failed", "it_{}_net_{}".format(it, i)))
169 |
170 | top1_err.append(100 - stats['top1'])
171 | complexity.append(stats[self.sec_obj])
172 |
173 | return top1_err, complexity
174 |
175 | def _fit_acc_predictor(self, archive):
176 | inputs = np.array([self.search_space.encode(x[0]) for x in archive])
177 | targets = np.array([x[1] for x in archive])
178 |         assert len(inputs) > len(inputs[0]), "# of training samples must be > # of dimensions"
179 |
180 | acc_predictor = get_acc_predictor(self.predictor, inputs, targets)
181 |
182 | return acc_predictor, acc_predictor.predict(inputs)
183 |
184 | def _next(self, archive, predictor, K):
185 |         """ search for the next K candidates for high-fidelity evaluation (lower level) """
186 |
187 |         # the following lines correspond to Algo 1 line 10 / Fig. 3(b) in the paper
188 | # get non-dominated architectures from archive
189 | F = np.column_stack(([x[1] for x in archive], [x[2] for x in archive]))
190 | front = NonDominatedSorting().do(F, only_non_dominated_front=True)
191 | # non-dominated arch bit-strings
192 | nd_X = np.array([self.search_space.encode(x[0]) for x in archive])[front]
193 |
194 | # initialize the candidate finding optimization problem
195 | problem = AuxiliarySingleLevelProblem(
196 | self.search_space, predictor, self.sec_obj,
197 | {'n_classes': self.n_classes, 'model_path': self.supernet_path})
198 |
199 | # initiate a multi-objective solver to optimize the problem
200 | method = get_algorithm(
201 | "nsga2", pop_size=40, sampling=nd_X, # initialize with current nd archs
202 | crossover=get_crossover("int_two_point", prob=0.9),
203 | mutation=get_mutation("int_pm", eta=1.0),
204 | eliminate_duplicates=True)
205 |
206 | # kick-off the search
207 | res = minimize(
208 | problem, method, termination=('n_gen', 20), save_history=True, verbose=True)
209 |
210 | # check for duplicates
211 | not_duplicate = np.logical_not(
212 | [x in [x[0] for x in archive] for x in [self.search_space.decode(x) for x in res.pop.get("X")]])
213 |
214 |         # the following lines correspond to Algo 1 line 11 / Fig. 3(c)-(d) in the paper
215 |         # form a subset selection problem to shortlist K candidates from pop_size
216 | indices = self._subset_selection(res.pop[not_duplicate], F[front, 1], K)
217 | pop = res.pop[not_duplicate][indices]
218 |
219 | candidates = []
220 | for x in pop.get("X"):
221 | candidates.append(self.search_space.decode(x))
222 |
223 | # decode integer bit-string to config and also return predicted top1_err
224 | return candidates, predictor.predict(pop.get("X"))
225 |
226 | @staticmethod
227 | def _subset_selection(pop, nd_F, K):
228 | problem = SubsetProblem(pop.get("F")[:, 1], nd_F, K)
229 | algorithm = GA(
230 | pop_size=100, sampling=MySampling(), crossover=BinaryCrossover(),
231 | mutation=MyMutation(), eliminate_duplicates=True)
232 |
233 | res = minimize(
234 | problem, algorithm, ('n_gen', 60), verbose=False)
235 |
236 | return res.X
237 |
238 | @staticmethod
239 | def _calc_hv(ref_pt, F, normalized=True):
240 | # calculate hypervolume on the non-dominated set of F
241 | front = NonDominatedSorting().do(F, only_non_dominated_front=True)
242 | nd_F = F[front, :]
243 | ref_point = 1.01 * ref_pt
244 | hv = get_performance_indicator("hv", ref_point=ref_point).calc(nd_F)
245 | if normalized:
246 | hv = hv / np.prod(ref_point)
247 | return hv
248 |
249 |
250 | class AuxiliarySingleLevelProblem(Problem):
251 | """ The optimization problem for finding the next N candidate architectures """
252 |
253 | def __init__(self, search_space, predictor, sec_obj='flops', supernet=None):
254 |         super().__init__(n_var=46, n_obj=2, n_constr=0, type_var=int)  # builtin int; np.int was removed in NumPy 1.24
255 |
256 | self.ss = search_space
257 | self.predictor = predictor
258 | self.xl = np.zeros(self.n_var)
259 | self.xu = 2 * np.ones(self.n_var)
260 | self.xu[-1] = int(len(self.ss.resolution) - 1)
261 | self.sec_obj = sec_obj
262 | self.lut = {'cpu': 'data/i7-8700K_lut.yaml'}
263 |
264 | # supernet engine for measuring complexity
265 | self.engine = OFAEvaluator(
266 | n_classes=supernet['n_classes'], model_path=supernet['model_path'])
267 |
268 | def _evaluate(self, x, out, *args, **kwargs):
269 | f = np.full((x.shape[0], self.n_obj), np.nan)
270 |
271 | top1_err = self.predictor.predict(x)[:, 0] # predicted top1 error
272 |
273 | for i, (_x, err) in enumerate(zip(x, top1_err)):
274 | config = self.ss.decode(_x)
275 | subnet, _ = self.engine.sample({'ks': config['ks'], 'e': config['e'], 'd': config['d']})
276 | info = get_net_info(subnet, (3, config['r'], config['r']),
277 | measure_latency=self.sec_obj, print_info=False, clean=True, lut=self.lut)
278 | f[i, 0] = err
279 | f[i, 1] = info[self.sec_obj]
280 |
281 | out["F"] = f
282 |
283 |
284 | class SubsetProblem(Problem):
285 | """ select a subset to diversify the pareto front """
286 | def __init__(self, candidates, archive, K):
287 | super().__init__(n_var=len(candidates), n_obj=1,
288 |                          n_constr=1, xl=0, xu=1, type_var=bool)  # builtin bool; np.bool was removed in NumPy 1.24
289 | self.archive = archive
290 | self.candidates = candidates
291 | self.n_max = K
292 |
293 | def _evaluate(self, x, out, *args, **kwargs):
294 | f = np.full((x.shape[0], 1), np.nan)
295 | g = np.full((x.shape[0], 1), np.nan)
296 |
297 | for i, _x in enumerate(x):
298 | # append selected candidates to archive then sort
299 | tmp = np.sort(np.concatenate((self.archive, self.candidates[_x])))
300 | f[i, 0] = np.std(np.diff(tmp))
301 | # we penalize if the number of selected candidates is not exactly K
302 | g[i, 0] = (self.n_max - np.sum(_x)) ** 2
303 |
304 | out["F"] = f
305 | out["G"] = g
306 |
307 |
308 | def main(args):
309 | engine = MSuNAS(vars(args))
310 | engine.search()
311 | return
312 |
313 |
314 | if __name__ == '__main__':
315 | parser = argparse.ArgumentParser()
316 | parser.add_argument('--save', type=str, default='.tmp',
317 | help='location of dir to save')
318 | parser.add_argument('--resume', type=str, default=None,
319 | help='resume search from a checkpoint')
320 | parser.add_argument('--sec_obj', type=str, default='flops',
321 | help='second objective to optimize simultaneously')
322 | parser.add_argument('--iterations', type=int, default=30,
323 | help='number of search iterations')
324 | parser.add_argument('--n_doe', type=int, default=100,
325 | help='initial sample size for DOE')
326 | parser.add_argument('--n_iter', type=int, default=8,
327 |                         help='number of architectures for high-fidelity evaluation (lower level) in each iteration')
328 | parser.add_argument('--predictor', type=str, default='rbf',
329 | help='which accuracy predictor model to fit (rbf/gp/cart/mlp/as)')
330 | parser.add_argument('--n_gpus', type=int, default=8,
331 | help='total number of available gpus')
332 | parser.add_argument('--gpu', type=int, default=1,
333 | help='number of gpus per evaluation job')
334 | parser.add_argument('--data', type=str, default='/mnt/datastore/ILSVRC2012',
335 | help='location of the data corpus')
336 | parser.add_argument('--dataset', type=str, default='imagenet',
337 | help='name of the dataset (imagenet, cifar10, cifar100, ...)')
338 | parser.add_argument('--n_classes', type=int, default=1000,
339 | help='number of classes of the given dataset')
340 | parser.add_argument('--supernet_path', type=str, default='./data/ofa_mbv3_d234_e346_k357_w1.0',
341 | help='file path to supernet weights')
342 | parser.add_argument('--n_workers', type=int, default=4,
343 | help='number of workers for dataloader per evaluation job')
344 | parser.add_argument('--vld_size', type=int, default=None,
345 | help='validation set size, randomly sampled from training set')
346 | parser.add_argument('--trn_batch_size', type=int, default=128,
347 | help='train batch size for training')
348 | parser.add_argument('--vld_batch_size', type=int, default=200,
349 | help='test batch size for inference')
350 | parser.add_argument('--n_epochs', type=int, default=5,
351 | help='number of epochs for CNN training')
352 | parser.add_argument('--test', action='store_true', default=False,
353 |                         help='evaluate performance on the test set')
354 | cfgs = parser.parse_args()
355 | main(cfgs)
356 |
357 |
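358 | # Example invocation (values are illustrative, not prescriptive; every flag is defined in the parser above):
359 | #   python msunas.py --sec_obj flops --iterations 30 --n_doe 100 --n_iter 8 \
360 | #       --dataset cifar10 --n_classes 10 --data ../data --vld_size 5000 \
361 | #       --n_gpus 8 --gpu 1 --save .tmp/cifar10-flops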
--------------------------------------------------------------------------------
/codebase/data_providers/aircraft.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import math
5 | import warnings
6 | import numpy as np
7 |
8 | from timm.data.transforms import str_to_pil_interp
9 | from timm.data.auto_augment import rand_augment_transform
10 |
11 | import torch.utils.data
12 | import torchvision.transforms as transforms
13 | from torchvision.datasets.folder import default_loader
14 |
15 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
16 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider
17 |
18 |
19 | def make_dataset(dir, image_ids, targets):
20 | assert(len(image_ids) == len(targets))
21 | images = []
22 | dir = os.path.expanduser(dir)
23 | for i in range(len(image_ids)):
24 | item = (os.path.join(dir, 'data', 'images',
25 | '%s.jpg' % image_ids[i]), targets[i])
26 | images.append(item)
27 | return images
28 |
29 |
30 | def find_classes(classes_file):
31 | # read classes file, separating out image IDs and class names
32 | image_ids = []
33 | targets = []
34 | f = open(classes_file, 'r')
35 | for line in f:
36 | split_line = line.split(' ')
37 | image_ids.append(split_line[0])
38 | targets.append(' '.join(split_line[1:]))
39 | f.close()
40 |
41 | # index class names
42 | classes = np.unique(targets)
43 | class_to_idx = {classes[i]: i for i in range(len(classes))}
44 | targets = [class_to_idx[c] for c in targets]
45 |
46 | return (image_ids, targets, classes, class_to_idx)
47 |
48 |
49 | class FGVCAircraft(torch.utils.data.Dataset):
50 |     """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
51 | Args:
52 | root (string): Root directory path to dataset.
53 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
54 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
55 | transform (callable, optional): A function/transform that takes in a PIL image
56 | and returns a transformed version. E.g. ``transforms.RandomCrop``
57 | target_transform (callable, optional): A function/transform that takes in the
58 | target and transforms it.
59 | loader (callable, optional): A function to load an image given its path.
60 | download (bool, optional): If true, downloads the dataset from the internet and
61 | puts it in the root directory. If dataset is already downloaded, it is not
62 | downloaded again.
63 | """
64 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
65 | class_types = ('variant', 'family', 'manufacturer')
66 | splits = ('train', 'val', 'trainval', 'test')
67 |
68 | def __init__(self, root, class_type='variant', split='train', transform=None,
69 | target_transform=None, loader=default_loader, download=False):
70 | if split not in self.splits:
71 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
72 | split, ', '.join(self.splits),
73 | ))
74 | if class_type not in self.class_types:
75 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
76 | class_type, ', '.join(self.class_types),
77 | ))
78 | self.root = os.path.expanduser(root)
79 | self.class_type = class_type
80 | self.split = split
81 | self.classes_file = os.path.join(self.root, 'data',
82 | 'images_%s_%s.txt' % (self.class_type, self.split))
83 |
84 | if download:
85 | self.download()
86 |
87 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
88 | samples = make_dataset(self.root, image_ids, targets)
89 |
90 | self.transform = transform
91 | self.target_transform = target_transform
92 | self.loader = loader
93 |
94 | self.samples = samples
95 | self.classes = classes
96 | self.class_to_idx = class_to_idx
97 |
98 | def __getitem__(self, index):
99 | """
100 | Args:
101 | index (int): Index
102 | Returns:
103 | tuple: (sample, target) where target is class_index of the target class.
104 | """
105 |
106 | path, target = self.samples[index]
107 | sample = self.loader(path)
108 | if self.transform is not None:
109 | sample = self.transform(sample)
110 | if self.target_transform is not None:
111 | target = self.target_transform(target)
112 |
113 | return sample, target
114 |
115 | def __len__(self):
116 | return len(self.samples)
117 |
118 | def __repr__(self):
119 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
120 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
121 | fmt_str += ' Root Location: {}\n'.format(self.root)
122 | tmp = ' Transforms (if any): '
123 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
124 | tmp = ' Target Transforms (if any): '
125 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
126 | return fmt_str
127 |
128 | def _check_exists(self):
129 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
130 | os.path.exists(self.classes_file)
131 |
132 | def download(self):
133 | """Download the FGVC-Aircraft data if it doesn't exist already."""
134 | from six.moves import urllib
135 | import tarfile
136 |
137 | if self._check_exists():
138 | return
139 |
140 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
141 | print('Downloading %s ... (may take a few minutes)' % self.url)
142 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
143 | tar_name = self.url.rpartition('/')[-1]
144 | tar_path = os.path.join(parent_dir, tar_name)
145 | data = urllib.request.urlopen(self.url)
146 |
147 | # download .tar.gz file
148 | with open(tar_path, 'wb') as f:
149 | f.write(data.read())
150 |
151 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
152 |         data_folder = tar_path[:-len('.tar.gz')]  # remove the suffix (str.strip() strips characters, not a substring)
153 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
154 | tar = tarfile.open(tar_path)
155 | tar.extractall(parent_dir)
156 |
157 | # if necessary, rename data folder to self.root
158 | if not os.path.samefile(data_folder, self.root):
159 | print('Renaming %s to %s ...' % (data_folder, self.root))
160 | os.rename(data_folder, self.root)
161 |
162 | # delete .tar.gz file
163 | print('Deleting %s ...' % tar_path)
164 | os.remove(tar_path)
165 |
166 | print('Done!')
167 |
168 |
169 | class FGVCAircraftDataProvider(DataProvider):
170 |
171 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
172 | resize_scale=0.08, distort_color=None, image_size=224,
173 | num_replicas=None, rank=None):
174 |
175 | warnings.filterwarnings('ignore')
176 | self._save_path = save_path
177 |
178 | self.image_size = image_size # int or list of int
179 | self.distort_color = distort_color
180 | self.resize_scale = resize_scale
181 |
182 | self._valid_transform_dict = {}
183 | if not isinstance(self.image_size, int):
184 | assert isinstance(self.image_size, list)
185 |             from ofa.utils.my_dataloader import MyDataLoader  # same ofa layout as the module-level imports above
186 | self.image_size.sort() # e.g., 160 -> 224
187 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
188 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)
189 |
190 | for img_size in self.image_size:
191 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
192 | self.active_img_size = max(self.image_size)
193 | valid_transforms = self._valid_transform_dict[self.active_img_size]
194 | train_loader_class = MyDataLoader # randomly sample image size for each batch of training image
195 | else:
196 | self.active_img_size = self.image_size
197 | valid_transforms = self.build_valid_transform()
198 | train_loader_class = torch.utils.data.DataLoader
199 |
200 | train_transforms = self.build_train_transform()
201 | train_dataset = self.train_dataset(train_transforms)
202 |
203 | if valid_size is not None:
204 | if not isinstance(valid_size, int):
205 | assert isinstance(valid_size, float) and 0 < valid_size < 1
206 | valid_size = int(len(train_dataset.samples) * valid_size)
207 |
208 | valid_dataset = self.train_dataset(valid_transforms)
209 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)
210 |
211 | if num_replicas is not None:
212 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
213 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
214 | else:
215 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
216 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
217 |
218 | self.train = train_loader_class(
219 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
220 | num_workers=n_worker, pin_memory=True,
221 | )
222 | self.valid = torch.utils.data.DataLoader(
223 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
224 | num_workers=n_worker, pin_memory=True,
225 | )
226 | else:
227 | if num_replicas is not None:
228 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
229 | self.train = train_loader_class(
230 | train_dataset, batch_size=train_batch_size, sampler=train_sampler,
231 | num_workers=n_worker, pin_memory=True
232 | )
233 | else:
234 | self.train = train_loader_class(
235 | train_dataset, batch_size=train_batch_size, shuffle=True,
236 | num_workers=n_worker, pin_memory=True,
237 | )
238 | self.valid = None
239 |
240 | test_dataset = self.test_dataset(valid_transforms)
241 | if num_replicas is not None:
242 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
243 | self.test = torch.utils.data.DataLoader(
244 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
245 | )
246 | else:
247 | self.test = torch.utils.data.DataLoader(
248 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
249 | )
250 |
251 | if self.valid is None:
252 | self.valid = self.test
253 |
254 | @staticmethod
255 | def name():
256 | return 'aircraft'
257 |
258 | @property
259 | def data_shape(self):
260 | return 3, self.active_img_size, self.active_img_size # C, H, W
261 |
262 | @property
263 | def n_classes(self):
264 | return 100
265 |
266 | @property
267 | def save_path(self):
268 | if self._save_path is None:
269 | self._save_path = '/mnt/datastore/Aircraft' # home server
270 |
271 | if not os.path.exists(self._save_path):
272 | self._save_path = '/mnt/datastore/Aircraft' # home server
273 | return self._save_path
274 |
275 | @property
276 | def data_url(self):
277 | raise ValueError('unable to download %s' % self.name())
278 |
279 | def train_dataset(self, _transforms):
280 | # dataset = datasets.ImageFolder(self.train_path, _transforms)
281 | dataset = FGVCAircraft(
282 | root=self.train_path, split='trainval', download=True, transform=_transforms)
283 | return dataset
284 |
285 | def test_dataset(self, _transforms):
286 | # dataset = datasets.ImageFolder(self.valid_path, _transforms)
287 | dataset = FGVCAircraft(
288 | root=self.valid_path, split='test', download=True, transform=_transforms)
289 | return dataset
290 |
291 | @property
292 | def train_path(self):
293 | return self.save_path
294 |
295 | @property
296 | def valid_path(self):
297 | return self.save_path
298 |
299 | @property
300 | def normalize(self):
301 | return transforms.Normalize(
302 | mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
303 | std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115])
304 |
305 | def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
306 | if image_size is None:
307 | image_size = self.image_size
308 | # if print_log:
309 | # print('Color jitter: %s, resize_scale: %s, img_size: %s' %
310 | # (self.distort_color, self.resize_scale, image_size))
311 |
312 | # if self.distort_color == 'torch':
313 | # color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
314 | # elif self.distort_color == 'tf':
315 | # color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
316 | # else:
317 | # color_transform = None
318 |
319 | if isinstance(image_size, list):
320 | resize_transform_class = MyRandomResizedCrop
321 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
322 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
323 | img_size_min = min(image_size)
324 | else:
325 | resize_transform_class = transforms.RandomResizedCrop
326 | img_size_min = image_size
327 |
328 | train_transforms = [
329 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
330 | transforms.RandomHorizontalFlip(),
331 | ]
332 |
333 | aa_params = dict(
334 | translate_const=int(img_size_min * 0.45),
335 | img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618,
336 | 0.5387914411673883]]),
337 | )
338 | aa_params['interpolation'] = str_to_pil_interp('bicubic')
339 | train_transforms += [rand_augment_transform(auto_augment, aa_params)]
340 |
341 | # if color_transform is not None:
342 | # train_transforms.append(color_transform)
343 | train_transforms += [
344 | transforms.ToTensor(),
345 | self.normalize,
346 | ]
347 |
348 | train_transforms = transforms.Compose(train_transforms)
349 | return train_transforms
350 |
351 | def build_valid_transform(self, image_size=None):
352 | if image_size is None:
353 | image_size = self.active_img_size
354 | return transforms.Compose([
355 | transforms.Resize(int(math.ceil(image_size / 0.875))),
356 | transforms.CenterCrop(image_size),
357 | transforms.ToTensor(),
358 | self.normalize,
359 | ])
360 |
361 | def assign_active_img_size(self, new_img_size):
362 | self.active_img_size = new_img_size
363 | if self.active_img_size not in self._valid_transform_dict:
364 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
365 | # change the transform of the valid and test set
366 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
367 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]
368 |
369 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
370 | # used for resetting running statistics
371 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
372 | if num_worker is None:
373 | num_worker = self.train.num_workers
374 |
375 | n_samples = len(self.train.dataset.samples)
376 | g = torch.Generator()
377 | g.manual_seed(DataProvider.SUB_SEED)
378 | rand_indexes = torch.randperm(n_samples, generator=g).tolist()
379 |
380 | new_train_dataset = self.train_dataset(
381 | self.build_train_transform(image_size=self.active_img_size, print_log=False))
382 | chosen_indexes = rand_indexes[:n_images]
383 | if num_replicas is not None:
384 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
385 | else:
386 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
387 | sub_data_loader = torch.utils.data.DataLoader(
388 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
389 | num_workers=num_worker, pin_memory=True,
390 | )
391 | self.__dict__['sub_train_%d' % self.active_img_size] = []
392 | for images, labels in sub_data_loader:
393 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
394 | return self.__dict__['sub_train_%d' % self.active_img_size]
395 |
396 |
397 | if __name__ == '__main__':
398 | data = FGVCAircraft(root='/mnt/datastore/Aircraft',
399 | split='trainval', download=True)
400 | print(len(data.classes))
401 | print(len(data.samples))
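402 | 
403 |     # Provider sketch (illustrative; assumes the dataset is available under the root above):
404 |     #   provider = FGVCAircraftDataProvider(save_path='/mnt/datastore/Aircraft',
405 |     #                                       train_batch_size=32, test_batch_size=200,
406 |     #                                       valid_size=0.2, n_worker=4, image_size=224)
407 |     #   images, labels = next(iter(provider.train))  # images: [32, 3, 224, 224]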
--------------------------------------------------------------------------------