├── codebase ├── __init__.py ├── networks │ ├── __init__.py │ └── nsganetv2.py ├── data_providers │ ├── stl10.py │ ├── flowers102.py │ ├── imagenet.py │ ├── pets.py │ ├── dtd.py │ ├── autoaugment.py │ └── aircraft.py └── run_manager │ └── __init__.py ├── assets ├── c10.gif ├── c10.png ├── c100.png ├── dtd.png ├── pets.png ├── cinic10.png ├── stl-10.png ├── aircraft.png ├── imagenet.gif ├── imagenet.png ├── overview.png └── flowers102.png ├── scripts ├── distributed_train.sh └── search.sh ├── acc_predictor ├── factory.py ├── rbf.py ├── gp.py ├── adaptive_switching.py ├── carts.py └── mlp.py ├── .gitignore ├── search_space └── ofa.py ├── validation.py ├── post_search.py ├── LICENSE ├── utils.py ├── evaluator.py ├── README.md ├── train_cifar.py └── msunas.py /codebase/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/c10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/c10.gif -------------------------------------------------------------------------------- /assets/c10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/c10.png -------------------------------------------------------------------------------- /assets/c100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/c100.png -------------------------------------------------------------------------------- /assets/dtd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/dtd.png -------------------------------------------------------------------------------- /assets/pets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/pets.png -------------------------------------------------------------------------------- /assets/cinic10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/cinic10.png -------------------------------------------------------------------------------- /assets/stl-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/stl-10.png -------------------------------------------------------------------------------- /assets/aircraft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/aircraft.png -------------------------------------------------------------------------------- /assets/imagenet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/imagenet.gif -------------------------------------------------------------------------------- /assets/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/imagenet.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/overview.png -------------------------------------------------------------------------------- /assets/flowers102.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikelzc1990/nsganetv2/HEAD/assets/flowers102.png -------------------------------------------------------------------------------- /scripts/distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python -m torch.distributed.launch --nproc_per_node=$NUM_PROC train_imagenet.py "$@" -------------------------------------------------------------------------------- /codebase/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from ofa.imagenet_classification.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2 2 | from ofa.imagenet_classification.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large 3 | from codebase.networks.nsganetv2 import NSGANetV2 4 | 5 | -------------------------------------------------------------------------------- /acc_predictor/factory.py: -------------------------------------------------------------------------------- 1 | def get_acc_predictor(model, inputs, targets): 2 | 3 | if model == 'rbf': 4 | from acc_predictor.rbf import RBF 5 | acc_predictor = RBF() 6 | acc_predictor.fit(inputs, targets) 7 | 8 | elif model == 'carts': 9 | from acc_predictor.carts import CART 10 | acc_predictor = CART(n_tree=5000) 11 | acc_predictor.fit(inputs, targets) 12 | 13 | elif model == 'gp': 14 | from acc_predictor.gp import GP 15 | acc_predictor = GP() 16 | acc_predictor.fit(inputs, targets) 17 | 18 | elif model == 'mlp': 19 | from acc_predictor.mlp import MLP 20 | acc_predictor = MLP(n_feature=inputs.shape[1]) 21 | acc_predictor.fit(x=inputs, y=targets) 22 | 23 | elif model == 'as': 24 | from acc_predictor.adaptive_switching import AdaptiveSwitching 25 | acc_predictor = AdaptiveSwitching() 26 | acc_predictor.fit(inputs, targets) 27 | 28 | else: 29 | raise NotImplementedError 30 | 31 | return acc_predictor 32 | 33 | -------------------------------------------------------------------------------- /acc_predictor/rbf.py: -------------------------------------------------------------------------------- 1 | from pySOT.surrogate import RBFInterpolant, CubicKernel, TPSKernel, LinearTail, ConstantTail 2 | 3 | 4 | class RBF: 5 | """ Radial Basis Function """ 6 | 7 | def __init__(self, kernel='cubic', tail='linear'): 8 | self.kernel = kernel 9 | self.tail = tail 10 | self.name = 'rbf' 11 | self.model = None 12 | 13 | def fit(self, train_data, train_label): 14 | if self.kernel == 'cubic': 15 | kernel = CubicKernel 16 | elif self.kernel == 'tps': 17 | kernel = TPSKernel 18 | else: 19 | raise NotImplementedError("unknown RBF kernel") 20 | 21 | if self.tail == 'linear': 22 | tail = LinearTail 23 | elif self.tail == 'constant': 24 | tail = ConstantTail 25 | else: 26 | raise NotImplementedError("unknown RBF tail") 27 | 28 | self.model = RBFInterpolant( 29 | dim=train_data.shape[1], kernel=kernel(), tail=tail(train_data.shape[1])) 30 | 31 | for i in range(len(train_data)): 32 | self.model.add_points(train_data[i, :], train_label[i]) 33 | 34 | def predict(self, test_data): 35 | assert self.model is not None, "RBF model does not exist, call fit to obtain rbf model first" 36 | return self.model.predict(test_data) 37 | -------------------------------------------------------------------------------- /acc_predictor/gp.py: -------------------------------------------------------------------------------- 1 | from pydacefit.regr import regr_constant 2 | from pydacefit.dace import DACE, regr_linear, regr_quadratic 3 | from pydacefit.corr import corr_gauss, corr_cubic, corr_exp, corr_expg, corr_spline, corr_spherical 4 | 5 | 6 | class GP: 7 | """ Gaussian Process (Kriging) """ 8 | def __init__(self, regr='linear', corr='gauss'): 9 | self.regr = regr 10 | self.corr = corr 11 | self.name = 'gp' 12 | self.model = None 13 | 14 | def fit(self, train_data, train_label): 15 | if self.regr == 'linear': 16 | regr = regr_linear 17 | elif self.regr == 'constant': 18 | regr = regr_constant 19 | elif self.regr == 'quadratic': 20 | regr = regr_quadratic 21 | else: 22 | raise NotImplementedError("unknown GP regression") 23 | 24 | if self.corr == 'gauss': 25 | corr = corr_gauss 26 | elif self.corr == 'cubic': 27 | corr = corr_cubic 28 | elif self.corr == 'exp': 29 | corr = corr_exp 30 | elif self.corr == 'expg': 31 | corr = corr_expg 32 | elif self.corr == 'spline': 33 | corr = corr_spline 34 | elif self.corr == 'spherical': 35 | corr = corr_spherical 36 | else: 37 | raise NotImplementedError("unknown GP correlation") 38 | 39 | self.model = DACE( 40 | regr=regr, corr=corr, theta=1.0, thetaL=0.00001, thetaU=100) 41 | self.model.fit(train_data, train_label) 42 | 43 | def predict(self, test_data): 44 | assert self.model is not None, "GP does not exist, call fit to obtain GP first" 45 | return self.model.predict(test_data) 46 | -------------------------------------------------------------------------------- /acc_predictor/adaptive_switching.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import numpy as np 3 | from acc_predictor.factory import get_acc_predictor 4 | 5 | 6 | class AdaptiveSwitching: 7 | """ ensemble surrogate model """ 8 | """ try all available models, pick one based on 10-fold crx vld """ 9 | def __init__(self, n_fold=10): 10 | # self.model_pool = ['rbf', 'gp', 'mlp', 'carts'] 11 | self.model_pool = ['rbf', 'gp', 'carts'] 12 | self.n_fold = n_fold 13 | self.name = 'adaptive switching' 14 | self.model = None 15 | 16 | def fit(self, train_data, train_target): 17 | self._n_fold_validation(train_data, train_target, n=self.n_fold) 18 | 19 | def _n_fold_validation(self, train_data, train_target, n=10): 20 | 21 | n_samples = len(train_data) 22 | perm = np.random.permutation(n_samples) 23 | 24 | kendall_tau = np.full((n, len(self.model_pool)), np.nan) 25 | 26 | for i, tst_split in enumerate(np.array_split(perm, n)): 27 | trn_split = np.setdiff1d(perm, tst_split, assume_unique=True) 28 | 29 | # loop over all considered surrogate model in pool 30 | for j, model in enumerate(self.model_pool): 31 | 32 | acc_predictor = get_acc_predictor(model, train_data[trn_split], train_target[trn_split]) 33 | 34 | rmse, rho, tau = utils.get_correlation( 35 | acc_predictor.predict(train_data[tst_split]), train_target[tst_split]) 36 | 37 | kendall_tau[i, j] = tau 38 | 39 | winner = int(np.argmax(np.mean(kendall_tau, axis=0) - np.std(kendall_tau, axis=0))) 40 | print("winner model = {}, tau = {}".format(self.model_pool[winner], 41 | np.mean(kendall_tau, axis=0)[winner])) 42 | self.winner = self.model_pool[winner] 43 | # re-fit the winner model with entire data 44 | acc_predictor = get_acc_predictor(self.model_pool[winner], train_data, train_target) 45 | self.model = acc_predictor 46 | 47 | def predict(self, test_data): 48 | return self.model.predict(test_data) 49 | -------------------------------------------------------------------------------- /acc_predictor/carts.py: -------------------------------------------------------------------------------- 1 | # implementation based on 2 | # https://github.com/yn-sun/e2epp/blob/master/build_predict_model.py 3 | # and https://github.com/HandingWang/RF-CMOCO 4 | import numpy as np 5 | from sklearn.tree import DecisionTreeRegressor 6 | 7 | 8 | class CART: 9 | """ Classification and Regression Tree """ 10 | def __init__(self, n_tree=1000): 11 | self.n_tree = n_tree 12 | self.name = 'carts' 13 | self.model = None 14 | 15 | @staticmethod 16 | def _make_decision_trees(train_data, train_label, n_tree): 17 | feature_record = [] 18 | tree_record = [] 19 | 20 | for i in range(n_tree): 21 | sample_idx = np.arange(train_data.shape[0]) 22 | np.random.shuffle(sample_idx) 23 | train_data = train_data[sample_idx, :] 24 | train_label = train_label[sample_idx] 25 | 26 | feature_idx = np.arange(train_data.shape[1]) 27 | np.random.shuffle(feature_idx) 28 | n_feature = np.random.randint(1, train_data.shape[1] + 1) 29 | selected_feature_ids = feature_idx[0:n_feature] 30 | feature_record.append(selected_feature_ids) 31 | 32 | dt = DecisionTreeRegressor() 33 | dt.fit(train_data[:, selected_feature_ids], train_label) 34 | tree_record.append(dt) 35 | 36 | return tree_record, feature_record 37 | 38 | def fit(self, train_data, train_label): 39 | self.model = self._make_decision_trees(train_data, train_label, self.n_tree) 40 | 41 | def predict(self, test_data): 42 | assert self.model is not None, "carts does not exist, call fit to obtain cart first" 43 | 44 | # redundant variable device 45 | trees, features = self.model[0], self.model[1] 46 | test_num, n_tree = len(test_data), len(trees) 47 | 48 | predict_labels = np.zeros((test_num, 1)) 49 | for i in range(test_num): 50 | this_test_data = test_data[i, :] 51 | predict_this_list = np.zeros(n_tree) 52 | 53 | for j, (tree, feature) in enumerate(zip(trees, features)): 54 | predict_this_list[j] = tree.predict([this_test_data[feature]])[0] 55 | 56 | # find the top 100 prediction 57 | predict_this_list = np.sort(predict_this_list) 58 | predict_this_list = predict_this_list[::-1] 59 | this_predict = np.mean(predict_this_list) 60 | predict_labels[i, 0] = this_predict 61 | 62 | return predict_labels 63 | 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .tmp/ 6 | .idea 7 | ipynb/ 8 | #subnets/ 9 | # C extensions 10 | *.so 11 | .DS_Store 12 | subnets/ 13 | data/ 14 | search-*/ 15 | sample_run/ 16 | # Distribution / packaging 17 | .Python 18 | eccv20/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | *.pdf 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | *.pptx 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | experiments/ 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | *.o 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | /pretrained/ 141 | .idea 142 | -------------------------------------------------------------------------------- /search_space/ofa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class OFASearchSpace: 5 | def __init__(self): 6 | self.num_blocks = 5 # number of blocks 7 | self.kernel_size = [3, 5, 7] # depth-wise conv kernel size 8 | self.exp_ratio = [3, 4, 6] # expansion rate 9 | self.depth = [2, 3, 4] # number of Inverted Residual Bottleneck layers repetition 10 | self.resolution = list(range(192, 257, 4)) # input image resolutions 11 | 12 | def sample(self, n_samples=1, nb=None, ks=None, e=None, d=None, r=None): 13 | """ randomly sample a architecture""" 14 | nb = self.num_blocks if nb is None else nb 15 | ks = self.kernel_size if ks is None else ks 16 | e = self.exp_ratio if e is None else e 17 | d = self.depth if d is None else d 18 | r = self.resolution if r is None else r 19 | 20 | data = [] 21 | for n in range(n_samples): 22 | # first sample layers 23 | depth = np.random.choice(d, nb, replace=True).tolist() 24 | # then sample kernel size, expansion rate and resolution 25 | kernel_size = np.random.choice(ks, size=int(np.sum(depth)), replace=True).tolist() 26 | exp_ratio = np.random.choice(e, size=int(np.sum(depth)), replace=True).tolist() 27 | resolution = int(np.random.choice(r)) 28 | 29 | data.append({'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'r': resolution}) 30 | return data 31 | 32 | def initialize(self, n_doe): 33 | # sample one arch with least (lb of hyperparameters) and most complexity (ub of hyperparameters) 34 | data = [ 35 | self.sample(1, ks=[min(self.kernel_size)], e=[min(self.exp_ratio)], 36 | d=[min(self.depth)], r=[min(self.resolution)])[0], 37 | self.sample(1, ks=[max(self.kernel_size)], e=[max(self.exp_ratio)], 38 | d=[max(self.depth)], r=[max(self.resolution)])[0] 39 | ] 40 | data.extend(self.sample(n_samples=n_doe - 2)) 41 | return data 42 | 43 | def pad_zero(self, x, depth): 44 | # pad zeros to make bit-string of equal length 45 | new_x, counter = [], 0 46 | for d in depth: 47 | for _ in range(d): 48 | new_x.append(x[counter]) 49 | counter += 1 50 | if d < max(self.depth): 51 | new_x += [0] * (max(self.depth) - d) 52 | return new_x 53 | 54 | def encode(self, config): 55 | # encode config ({'ks': , 'd': , etc}) to integer bit-string [1, 0, 2, 1, ...] 56 | x = [] 57 | depth = [np.argwhere(_x == np.array(self.depth))[0, 0] for _x in config['d']] 58 | kernel_size = [np.argwhere(_x == np.array(self.kernel_size))[0, 0] for _x in config['ks']] 59 | exp_ratio = [np.argwhere(_x == np.array(self.exp_ratio))[0, 0] for _x in config['e']] 60 | 61 | kernel_size = self.pad_zero(kernel_size, config['d']) 62 | exp_ratio = self.pad_zero(exp_ratio, config['d']) 63 | 64 | for i in range(len(depth)): 65 | x = x + [depth[i]] + kernel_size[i * max(self.depth):i * max(self.depth) + max(self.depth)] \ 66 | + exp_ratio[i * max(self.depth):i * max(self.depth) + max(self.depth)] 67 | x.append(np.argwhere(config['r'] == np.array(self.resolution))[0, 0]) 68 | 69 | return x 70 | 71 | def decode(self, x): 72 | """ 73 | remove un-expressed part of the chromosome 74 | assumes x = [block1, block2, ..., block5, resolution, width_mult]; 75 | block_i = [depth, kernel_size, exp_rate] 76 | """ 77 | depth, kernel_size, exp_rate = [], [], [] 78 | for i in range(0, len(x) - 2, 9): 79 | depth.append(self.depth[x[i]]) 80 | kernel_size.extend(np.array(self.kernel_size)[x[i + 1:i + 1 + self.depth[x[i]]]].tolist()) 81 | exp_rate.extend(np.array(self.exp_ratio)[x[i + 5:i + 5 + self.depth[x[i]]]].tolist()) 82 | return {'ks': kernel_size, 'e': exp_rate, 'd': depth, 'r': self.resolution[x[-1]]} 83 | 84 | -------------------------------------------------------------------------------- /scripts/search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Search Examples 4 | ## In general, set `--n_iter` to `--n_gpus`, which is the # of available gpus you have 5 | 6 | # Maximize Top-1 Accuracy and Minimize #FLOPs on ImageNet 7 | #python msunas.py --sec_obj flops \ 8 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \ 9 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \ 10 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 11 | # --save search-imagenet-flops-w1.0 --iterations 30 --vld_size 10000 12 | 13 | # Maximize Top-1 Accuracy and Minimize #Params on ImageNet 14 | #python msunas.py --sec_obj params \ 15 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \ 16 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \ 17 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 18 | # --save search-imagenet-params-w1.0 --iterations 30 --vld_size 10000 19 | 20 | # Maximize Top-1 Accuracy and Minimize CPU Latency on ImageNet 21 | #python msunas.py --sec_obj cpu \ 22 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \ 23 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \ 24 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 25 | # --save search-imagenet-cpu-w1.0 --iterations 30 --vld_size 10000 26 | 27 | # Maximize Top-1 Accuracy and Minimize #FLOPs on CIFAR-10 28 | #python msunas.py --sec_obj flops \ 29 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \ 30 | # --dataset cifar10 --n_classes 10 \ 31 | # --data /usr/local/soft/temp-datastore/CIFAR/ \ 32 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 33 | # --save search-cifar10-flops-w1.0 --iterations 30 --vld_size 5000 34 | 35 | # Maximize Top-1 Accuracy and Minimize #FLOPs on CIFAR-100 36 | #python msunas.py --sec_obj flops \ 37 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \ 38 | # --dataset cifar100 --n_classes 100 \ 39 | # --data /usr/local/soft/temp-datastore/CIFAR/ \ 40 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 41 | # --save search-cifar100-flops-w1.0 --iterations 30 --vld_size 5000 42 | 43 | # Maximize Top-1 Accuracy and Minimize #FLOPs on CINIC10 44 | #python msunas.py --sec_obj flops \ 45 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \ 46 | # --dataset cinic10 --n_classes 10 \ 47 | # --data /usr/local/soft/temp-datastore/CINIC/ \ 48 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 49 | # --save search-cinic10-flops-w1.0 --iterations 30 --vld_size 10000 50 | 51 | # Maximize Top-1 Accuracy and Minimize #FLOPs on STL-10 52 | #python msunas.py --sec_obj flops \ 53 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \ 54 | # --dataset stl10 --n_classes 10 \ 55 | # --data /usr/local/soft/temp-datastore/STL10/ \ 56 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 57 | # --save search-stl10-flops-w1.0 --iterations 30 --vld_size 500 58 | 59 | # Maximize Top-1 Accuracy and Minimize #FLOPs on Aircraft 60 | #python msunas.py --sec_obj flops \ 61 | # --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 5 \ 62 | # --dataset aircraft --n_classes 100 \ 63 | # --data /usr/local/soft/temp-datastore/Aircraft/ \ 64 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 65 | # --save search-aircraft-flops-w1.0 --iterations 30 --vld_size 500 66 | 67 | # Random Search on ImageNet (set `--n_doe` to the total # of archs you want to sample) 68 | #python msunas.py --n_doe 350 --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \ 69 | # --data /usr/local/soft/temp-datastore/ILSVRC2012/ \ 70 | # --predictor rbf --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 71 | # --save random-imagenet-w1.0 --iterations 0 --vld_size 10000 72 | 73 | # Random Search on CIFAR-10 (set `--n_doe` to the total # of archs you want to sample) 74 | #python msunas.py --n_doe 350 --n_gpus 8 --gpu 1 --n_workers 4 --n_epochs 0 \ 75 | # --dataset cifar10 --n_classes 10 \ 76 | # --data /usr/local/soft/temp-datastore/CIFAR/ \ 77 | # --predictor as --supernet_path data/ofa_mbv3_d234_e346_k357_w1.0 \ 78 | # --save random-cifar10-w1.0 --iterations 0 --vld_size 5000 79 | 80 | -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import torch 4 | import logging 5 | import argparse 6 | from collections import OrderedDict 7 | 8 | from timm.utils import accuracy, AverageMeter, setup_default_logging 9 | 10 | from codebase.run_manager import get_run_config 11 | from codebase.networks.nsganetv2 import NSGANetV2 12 | 13 | 14 | def validate(model, loader, criterion, log_freq=50): 15 | batch_time = AverageMeter() 16 | losses = AverageMeter() 17 | top1 = AverageMeter() 18 | top5 = AverageMeter() 19 | 20 | model.eval() 21 | end = time.time() 22 | with torch.no_grad(): 23 | for i, (input, target) in enumerate(loader): 24 | target = target.cuda() 25 | input = input.cuda() 26 | 27 | # compute output 28 | output = model(input) 29 | loss = criterion(output, target) 30 | 31 | # measure accuracy and record loss 32 | acc1, acc5 = accuracy(output.data, target, topk=(1, 5)) 33 | losses.update(loss.item(), input.size(0)) 34 | top1.update(acc1.item(), input.size(0)) 35 | top5.update(acc5.item(), input.size(0)) 36 | 37 | # measure elapsed time 38 | batch_time.update(time.time() - end) 39 | end = time.time() 40 | 41 | if i % log_freq == 0: 42 | logging.info( 43 | 'Test: [{0:>4d}/{1}] ' 44 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 45 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 46 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 47 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( 48 | i, len(loader), batch_time=batch_time, 49 | rate_avg=input.size(0) / batch_time.avg, 50 | loss=losses, top1=top1, top5=top5)) 51 | 52 | results = OrderedDict( 53 | top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4), 54 | top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4)) 55 | 56 | logging.info(' * Acc@1 {:.1f} ({:.3f}) Acc@5 {:.1f} ({:.3f})'.format( 57 | results['top1'], results['top1_err'], results['top5'], results['top5_err'])) 58 | 59 | 60 | def main(args): 61 | setup_default_logging() 62 | 63 | logging.info('Running validation on {}'.format(args.dataset)) 64 | 65 | net_config = json.load(open(args.model)) 66 | if 'img_size' in net_config: 67 | img_size = net_config['img_size'] 68 | else: 69 | img_size = args.img_size 70 | 71 | run_config = get_run_config( 72 | dataset=args.dataset, data_path=args.data, image_size=img_size, n_epochs=0, 73 | train_batch_size=args.batch_size, test_batch_size=args.batch_size, 74 | n_worker=args.workers, valid_size=None) 75 | 76 | model = NSGANetV2.build_from_config(net_config) 77 | try: 78 | model.load_state_dict(torch.load(args.pretrained, map_location='cpu')) 79 | except: 80 | model.load_state_dict(torch.load(args.pretrained, map_location='cpu')['state_dict']) 81 | 82 | param_count = sum([m.numel() for m in model.parameters()]) 83 | logging.info('Model created, param count: %d' % param_count) 84 | 85 | model = model.cuda() 86 | criterion = torch.nn.CrossEntropyLoss().cuda() 87 | 88 | validate(model, run_config.test_loader, criterion) 89 | 90 | return 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser() 95 | # data related settings 96 | parser.add_argument('--data', type=str, default='/mnt/datastore/ILSVRC2012', 97 | help='location of the data corpus') 98 | parser.add_argument('--dataset', type=str, default='imagenet', 99 | help='name of the dataset (imagenet, cifar10, cifar100, ...)') 100 | parser.add_argument('-j', '--workers', type=int, default=6, 101 | help='number of workers for data loading') 102 | parser.add_argument('-b', '--batch-size', type=int, default=256, 103 | help='test batch size for inference') 104 | parser.add_argument('--img-size', type=int, default=224, 105 | help='input resolution (192 -> 256)') 106 | # model related settings 107 | parser.add_argument('--model', '-m', metavar='MODEL', default='', type=str, 108 | help='model configuration file') 109 | parser.add_argument('--pretrained', type=str, default='', 110 | help='path to pretrained weights') 111 | cfgs = parser.parse_args() 112 | 113 | main(cfgs) 114 | -------------------------------------------------------------------------------- /post_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | from pymoo.factory import get_decomposition 6 | from pymoo.visualization.scatter import Scatter 7 | from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting 8 | from pymoo.model.decision_making import DecisionMaking, normalize, find_outliers_upper_tail, NeighborFinder 9 | 10 | _DEBUG = False 11 | 12 | 13 | class HighTradeoffPoints(DecisionMaking): 14 | 15 | def __init__(self, epsilon=0.125, n_survive=None, **kwargs) -> None: 16 | super().__init__(**kwargs) 17 | self.epsilon = epsilon 18 | self.n_survive = n_survive # number of points to be selected 19 | 20 | def _do(self, F, **kwargs): 21 | n, m = F.shape 22 | 23 | if self.normalize: 24 | F = normalize(F, self.ideal_point, self.nadir_point, estimate_bounds_if_none=True) 25 | 26 | neighbors_finder = NeighborFinder(F, epsilon=0.125, n_min_neigbors="auto", consider_2d=False) 27 | 28 | mu = np.full(n, - np.inf) 29 | 30 | # for each solution in the set calculate the least amount of improvement per unit deterioration 31 | for i in range(n): 32 | 33 | # for each neighbour in a specific radius of that solution 34 | neighbors = neighbors_finder.find(i) 35 | 36 | # calculate the trade-off to all neighbours 37 | diff = F[neighbors] - F[i] 38 | 39 | # calculate sacrifice and gain 40 | sacrifice = np.maximum(0, diff).sum(axis=1) 41 | gain = np.maximum(0, -diff).sum(axis=1) 42 | 43 | np.warnings.filterwarnings('ignore') 44 | tradeoff = sacrifice / gain 45 | 46 | # otherwise find the one with the smalled one 47 | mu[i] = np.nanmin(tradeoff) 48 | if self.n_survive is not None: 49 | return np.argsort(mu)[-self.n_survive:] 50 | else: 51 | return find_outliers_upper_tail(mu) # return points with trade-off > 2*sigma 52 | 53 | 54 | def main(args): 55 | # preferences 56 | if args.prefer is not None: 57 | preferences = {} 58 | for p in args.prefer.split("+"): 59 | k, v = p.split("#") 60 | if k == 'top1': 61 | preferences[k] = 100 - float(v) # assuming top-1 accuracy 62 | else: 63 | preferences[k] = float(v) 64 | weights = np.fromiter(preferences.values(), dtype=float) 65 | 66 | archive = json.load(open(args.expr))['archive'] 67 | subnets, top1, sec_obj = [v[0] for v in archive], [v[1] for v in archive], [v[2] for v in archive] 68 | sort_idx = np.argsort(top1) 69 | F = np.column_stack((top1, sec_obj))[sort_idx, :] 70 | front = NonDominatedSorting().do(F, only_non_dominated_front=True) 71 | pf = F[front, :] 72 | ps = np.array(subnets)[sort_idx][front] 73 | 74 | if args.prefer is not None: 75 | # choose the architectures thats closest to the preferences 76 | I = get_decomposition("asf").do(pf, weights).argsort()[:args.n] 77 | else: 78 | # choose the architectures with highest trade-off 79 | dm = HighTradeoffPoints(n_survive=args.n) 80 | I = dm.do(pf) 81 | 82 | # always add most accurate architectures 83 | I = np.append(I, 0) 84 | 85 | # create the supernet 86 | from evaluator import OFAEvaluator 87 | supernet = OFAEvaluator(model_path=args.supernet_path) 88 | 89 | for idx in I: 90 | save = os.path.join(args.save, "net-flops@{:.0f}".format(pf[idx, 1])) 91 | os.makedirs(save, exist_ok=True) 92 | subnet, _ = supernet.sample({'ks': ps[idx]['ks'], 'e': ps[idx]['e'], 'd': ps[idx]['d']}) 93 | with open(os.path.join(save, "net.subnet"), 'w') as handle: 94 | json.dump(ps[idx], handle) 95 | supernet.save_net_config(save, subnet, "net.config") 96 | supernet.save_net(save, subnet, "net.inherited") 97 | 98 | if _DEBUG: 99 | print(ps[I]) 100 | plot = Scatter() 101 | plot.add(pf, alpha=0.2) 102 | plot.add(pf[I, :], color="red", s=100) 103 | plot.show() 104 | 105 | return 106 | 107 | 108 | if __name__ == '__main__': 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('--save', type=str, default='.tmp', 111 | help='location of dir to save') 112 | parser.add_argument('--expr', type=str, default='', 113 | help='location of search experiment dir') 114 | parser.add_argument('--prefer', type=str, default=None, 115 | help='preferences in choosing architectures (top1#80+flops#150)') 116 | parser.add_argument('-n', type=int, default=1, 117 | help='number of architectures desired') 118 | parser.add_argument('--supernet_path', type=str, default='./data/ofa_mbv3_d234_e346_k357_w1.0', 119 | help='file path to supernet weights') 120 | 121 | cfgs = parser.parse_args() 122 | main(cfgs) 123 | -------------------------------------------------------------------------------- /acc_predictor/mlp.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from utils import get_correlation 6 | 7 | 8 | class Net(nn.Module): 9 | # N-layer MLP 10 | def __init__(self, n_feature, n_layers=2, n_hidden=300, n_output=1, drop=0.2): 11 | super(Net, self).__init__() 12 | 13 | self.stem = nn.Sequential(nn.Linear(n_feature, n_hidden), nn.ReLU()) 14 | 15 | hidden_layers = [] 16 | for _ in range(n_layers): 17 | hidden_layers.append(nn.Linear(n_hidden, n_hidden)) 18 | hidden_layers.append(nn.ReLU()) 19 | self.hidden = nn.Sequential(*hidden_layers) 20 | 21 | self.regressor = nn.Linear(n_hidden, n_output) # output layer 22 | self.drop = nn.Dropout(p=drop) 23 | 24 | def forward(self, x): 25 | x = self.stem(x) 26 | x = self.hidden(x) 27 | x = self.drop(x) 28 | x = self.regressor(x) # linear output 29 | return x 30 | 31 | @staticmethod 32 | def init_weights(m): 33 | if type(m) == nn.Linear: 34 | n = m.in_features 35 | y = 1.0 / np.sqrt(n) 36 | m.weight.data.uniform_(-y, y) 37 | m.bias.data.fill_(0) 38 | 39 | 40 | class MLP: 41 | """ Multi Layer Perceptron """ 42 | def __init__(self, **kwargs): 43 | self.model = Net(**kwargs) 44 | self.name = 'mlp' 45 | 46 | def fit(self, **kwargs): 47 | self.model = train(self.model, **kwargs) 48 | 49 | def predict(self, test_data, device='cpu'): 50 | return predict(self.model, test_data, device=device) 51 | 52 | 53 | def train(net, x, y, trn_split=0.8, pretrained=None, device='cpu', 54 | lr=8e-4, epochs=2000, verbose=False): 55 | 56 | n_samples = x.shape[0] 57 | target = torch.zeros(n_samples, 1) 58 | perm = torch.randperm(target.size(0)) 59 | trn_idx = perm[:int(n_samples * trn_split)] 60 | vld_idx = perm[int(n_samples * trn_split):] 61 | 62 | inputs = torch.from_numpy(x).float() 63 | target[:, 0] = torch.from_numpy(y).float() 64 | 65 | # back-propagation training of a NN 66 | if pretrained is not None: 67 | print("Constructing MLP surrogate model with pre-trained weights") 68 | init = torch.load(pretrained, map_location='cpu') 69 | net.load_state_dict(init) 70 | best_net = copy.deepcopy(net) 71 | else: 72 | # print("Constructing MLP surrogate model with " 73 | # "sample size = {}, epochs = {}".format(x.shape[0], epochs)) 74 | 75 | # initialize the weights 76 | # net.apply(Net.init_weights) 77 | net = net.to(device) 78 | optimizer = torch.optim.Adam(net.parameters(), lr=lr) 79 | criterion = nn.SmoothL1Loss() 80 | # criterion = nn.MSELoss() 81 | 82 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(epochs), eta_min=0) 83 | 84 | best_loss = 1e33 85 | for epoch in range(epochs): 86 | trn_inputs = inputs[trn_idx] 87 | trn_labels = target[trn_idx] 88 | loss_trn = train_one_epoch(net, trn_inputs, trn_labels, criterion, optimizer, device) 89 | loss_vld = infer(net, inputs[vld_idx], target[vld_idx], criterion, device) 90 | scheduler.step() 91 | 92 | # if epoch % 500 == 0 and verbose: 93 | # print("Epoch {:4d}: trn loss = {:.4E}, vld loss = {:.4E}".format(epoch, loss_trn, loss_vld)) 94 | 95 | if loss_vld < best_loss: 96 | best_loss = loss_vld 97 | best_net = copy.deepcopy(net) 98 | 99 | validate(best_net, inputs, target, device=device) 100 | 101 | return best_net.to('cpu') 102 | 103 | 104 | def train_one_epoch(net, data, target, criterion, optimizer, device): 105 | net.train() 106 | optimizer.zero_grad() 107 | 108 | data, target = data.to(device), target.to(device) 109 | pred = net(data) 110 | loss = criterion(pred, target) 111 | loss.backward() 112 | optimizer.step() 113 | 114 | return loss.item() 115 | 116 | 117 | def infer(net, data, target, criterion, device): 118 | net.eval() 119 | 120 | with torch.no_grad(): 121 | data, target = data.to(device), target.to(device) 122 | pred = net(data) 123 | loss = criterion(pred, target) 124 | 125 | return loss.item() 126 | 127 | 128 | def validate(net, data, target, device): 129 | net.eval() 130 | 131 | with torch.no_grad(): 132 | data, target = data.to(device), target.to(device) 133 | pred = net(data) 134 | pred, target = pred.cpu().detach().numpy(), target.cpu().detach().numpy() 135 | 136 | rmse, rho, tau = get_correlation(pred, target) 137 | 138 | # print("Validation RMSE = {:.4f}, Spearman's Rho = {:.4f}, Kendall’s Tau = {:.4f}".format(rmse, rho, tau)) 139 | return rmse, rho, tau, pred, target 140 | 141 | 142 | def predict(net, query, device): 143 | 144 | if query.ndim < 2: 145 | data = torch.zeros(1, query.shape[0]) 146 | data[0, :] = torch.from_numpy(query).float() 147 | else: 148 | data = torch.from_numpy(query).float() 149 | 150 | net = net.to(device) 151 | net.eval() 152 | with torch.no_grad(): 153 | data = data.to(device) 154 | pred = net(data) 155 | 156 | return pred.cpu().detach().numpy() -------------------------------------------------------------------------------- /codebase/networks/nsganetv2.py: -------------------------------------------------------------------------------- 1 | from timm.models.layers import drop_path 2 | from ofa.utils.layers import * 3 | from ofa.utils import MyModule 4 | from ofa.imagenet_classification.networks import MobileNetV3 5 | 6 | 7 | class MobileInvertedResidualBlock(MyModule): 8 | """ 9 | Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/ 10 | imagenet_codebase/networks/proxyless_nets.py to include drop path in training 11 | 12 | """ 13 | def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0): 14 | super(MobileInvertedResidualBlock, self).__init__() 15 | 16 | self.mobile_inverted_conv = mobile_inverted_conv 17 | self.shortcut = shortcut 18 | self.drop_connect_rate = drop_connect_rate 19 | 20 | def forward(self, x): 21 | if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer): 22 | res = x 23 | elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer): 24 | res = self.mobile_inverted_conv(x) 25 | else: 26 | # res = self.mobile_inverted_conv(x) + self.shortcut(x) 27 | res = self.mobile_inverted_conv(x) 28 | 29 | if self.drop_connect_rate > 0.: 30 | res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training) 31 | 32 | res += self.shortcut(x) 33 | 34 | return res 35 | 36 | @property 37 | def module_str(self): 38 | return '(%s, %s)' % ( 39 | self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None, 40 | self.shortcut.module_str if self.shortcut is not None else None 41 | ) 42 | 43 | @property 44 | def config(self): 45 | return { 46 | 'name': MobileInvertedResidualBlock.__name__, 47 | 'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None, 48 | 'shortcut': self.shortcut.config if self.shortcut is not None else None, 49 | } 50 | 51 | @staticmethod 52 | def build_from_config(config): 53 | mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv']) 54 | shortcut = set_layer_from_config(config['shortcut']) 55 | return MobileInvertedResidualBlock( 56 | mobile_inverted_conv, shortcut, drop_connect_rate=config['drop_connect_rate']) 57 | 58 | 59 | class NSGANetV2(MobileNetV3): 60 | """ 61 | Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/ 62 | imagenet_codebase/networks/mobilenet_v3.py to include drop path in training 63 | and option to reset classification layer 64 | """ 65 | @staticmethod 66 | def build_from_config(config, drop_connect_rate=0.0): 67 | first_conv = set_layer_from_config(config['first_conv']) 68 | final_expand_layer = set_layer_from_config(config['final_expand_layer']) 69 | feature_mix_layer = set_layer_from_config(config['feature_mix_layer']) 70 | classifier = set_layer_from_config(config['classifier']) 71 | 72 | blocks = [] 73 | for block_idx, block_config in enumerate(config['blocks']): 74 | block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks']) 75 | blocks.append(MobileInvertedResidualBlock.build_from_config(block_config)) 76 | 77 | net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier) 78 | if 'bn' in config: 79 | net.set_bn_param(**config['bn']) 80 | else: 81 | net.set_bn_param(momentum=0.1, eps=1e-3) 82 | 83 | return net 84 | 85 | def zero_last_gamma(self): 86 | for m in self.modules(): 87 | if isinstance(m, MobileInvertedResidualBlock): 88 | if isinstance(m.mobile_inverted_conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer): 89 | m.mobile_inverted_conv.point_linear.bn.weight.data.zero_() 90 | 91 | @staticmethod 92 | def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate): 93 | # first conv layer 94 | first_conv = ConvLayer( 95 | 3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act' 96 | ) 97 | # build mobile blocks 98 | feature_dim = input_channel 99 | blocks = [] 100 | for stage_id, block_config_list in cfg.items(): 101 | for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list: 102 | mb_conv = MBConvLayer( 103 | feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se 104 | ) 105 | if stride == 1 and out_channel == feature_dim: 106 | shortcut = IdentityLayer(out_channel, out_channel) 107 | else: 108 | shortcut = None 109 | blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut)) 110 | feature_dim = out_channel 111 | # final expand layer 112 | final_expand_layer = ConvLayer( 113 | feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act', 114 | ) 115 | feature_dim = feature_dim * 6 116 | # feature mix layer 117 | feature_mix_layer = ConvLayer( 118 | feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish', 119 | ) 120 | # classifier 121 | classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) 122 | 123 | return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier 124 | 125 | @staticmethod 126 | def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0): 127 | model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate) -------------------------------------------------------------------------------- /codebase/data_providers/stl10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | 5 | import torchvision 6 | import torch.utils.data 7 | import torchvision.transforms as transforms 8 | 9 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler 10 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider 11 | 12 | 13 | class STL10DataProvider(DataProvider): 14 | 15 | def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None, 16 | n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None): 17 | 18 | self._save_path = save_path 19 | 20 | self.image_size = image_size # int or list of int 21 | self.distort_color = distort_color 22 | self.resize_scale = resize_scale 23 | 24 | self._valid_transform_dict = {} 25 | if not isinstance(self.image_size, int): 26 | assert isinstance(self.image_size, list) 27 | from ofa.utils.my_dataloader import MyDataLoader 28 | self.image_size.sort() # e.g., 160 -> 224 29 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() 30 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) 31 | 32 | for img_size in self.image_size: 33 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size) 34 | self.active_img_size = max(self.image_size) 35 | valid_transforms = self._valid_transform_dict[self.active_img_size] 36 | train_loader_class = MyDataLoader # randomly sample image size for each batch of training image 37 | else: 38 | self.active_img_size = self.image_size 39 | valid_transforms = self.build_valid_transform() 40 | train_loader_class = torch.utils.data.DataLoader 41 | 42 | train_transforms = self.build_train_transform() 43 | train_dataset = self.train_dataset(train_transforms) 44 | 45 | if valid_size is not None: 46 | if not isinstance(valid_size, int): 47 | assert isinstance(valid_size, float) and 0 < valid_size < 1 48 | valid_size = int(len(train_dataset.data) * valid_size) 49 | 50 | valid_dataset = self.train_dataset(valid_transforms) 51 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size) 52 | 53 | if num_replicas is not None: 54 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes)) 55 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes)) 56 | else: 57 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes) 58 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes) 59 | 60 | self.train = train_loader_class( 61 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 62 | num_workers=n_worker, pin_memory=True, 63 | ) 64 | self.valid = torch.utils.data.DataLoader( 65 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler, 66 | num_workers=n_worker, pin_memory=True, 67 | ) 68 | else: 69 | if num_replicas is not None: 70 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank) 71 | self.train = train_loader_class( 72 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 73 | num_workers=n_worker, pin_memory=True 74 | ) 75 | else: 76 | self.train = train_loader_class( 77 | train_dataset, batch_size=train_batch_size, shuffle=True, 78 | num_workers=n_worker, pin_memory=True, 79 | ) 80 | self.valid = None 81 | 82 | test_dataset = self.test_dataset(valid_transforms) 83 | if num_replicas is not None: 84 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank) 85 | self.test = torch.utils.data.DataLoader( 86 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True, 87 | ) 88 | else: 89 | self.test = torch.utils.data.DataLoader( 90 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True, 91 | ) 92 | 93 | if self.valid is None: 94 | self.valid = self.test 95 | 96 | @staticmethod 97 | def name(): 98 | return 'stl10' 99 | 100 | @property 101 | def data_shape(self): 102 | return 3, self.active_img_size, self.active_img_size # C, H, W 103 | 104 | @property 105 | def n_classes(self): 106 | return 10 107 | 108 | @property 109 | def save_path(self): 110 | if self._save_path is None: 111 | self._save_path = '/mnt/datastore/STL10' # home server 112 | 113 | if not os.path.exists(self._save_path): 114 | self._save_path = '/mnt/datastore/STL10' # home server 115 | return self._save_path 116 | 117 | @property 118 | def data_url(self): 119 | raise ValueError('unable to download %s' % self.name()) 120 | 121 | def train_dataset(self, _transforms): 122 | # dataset = datasets.ImageFolder(self.train_path, _transforms) 123 | dataset = torchvision.datasets.STL10( 124 | root=self.valid_path, split='train', download=False, transform=_transforms) 125 | return dataset 126 | 127 | def test_dataset(self, _transforms): 128 | # dataset = datasets.ImageFolder(self.valid_path, _transforms) 129 | dataset = torchvision.datasets.STL10( 130 | root=self.valid_path, split='test', download=False, transform=_transforms) 131 | return dataset 132 | 133 | @property 134 | def train_path(self): 135 | # return os.path.join(self.save_path, 'train') 136 | return self.save_path 137 | 138 | @property 139 | def valid_path(self): 140 | # return os.path.join(self.save_path, 'val') 141 | return self.save_path 142 | 143 | @property 144 | def normalize(self): 145 | return transforms.Normalize( 146 | mean=[0.44671097, 0.4398105, 0.4066468], 147 | std=[0.2603405, 0.25657743, 0.27126738]) 148 | 149 | def build_train_transform(self, image_size=None, print_log=True): 150 | if image_size is None: 151 | image_size = self.image_size 152 | if print_log: 153 | print('Color jitter: %s, resize_scale: %s, img_size: %s' % 154 | (self.distort_color, self.resize_scale, image_size)) 155 | 156 | if self.distort_color == 'torch': 157 | color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1) 158 | elif self.distort_color == 'tf': 159 | color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5) 160 | else: 161 | color_transform = None 162 | 163 | if isinstance(image_size, list): 164 | resize_transform_class = MyRandomResizedCrop 165 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(), 166 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS)) 167 | else: 168 | resize_transform_class = transforms.RandomResizedCrop 169 | 170 | train_transforms = [ 171 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)), 172 | transforms.RandomHorizontalFlip(), 173 | ] 174 | if color_transform is not None: 175 | train_transforms.append(color_transform) 176 | train_transforms += [ 177 | transforms.ToTensor(), 178 | self.normalize, 179 | ] 180 | 181 | train_transforms = transforms.Compose(train_transforms) 182 | return train_transforms 183 | 184 | def build_valid_transform(self, image_size=None): 185 | if image_size is None: 186 | image_size = self.active_img_size 187 | return transforms.Compose([ 188 | transforms.Resize(int(math.ceil(image_size / 0.875))), 189 | transforms.CenterCrop(image_size), 190 | transforms.ToTensor(), 191 | self.normalize, 192 | ]) 193 | 194 | def assign_active_img_size(self, new_img_size): 195 | self.active_img_size = new_img_size 196 | if self.active_img_size not in self._valid_transform_dict: 197 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform() 198 | # change the transform of the valid and test set 199 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] 200 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] 201 | 202 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None): 203 | # used for resetting running statistics 204 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None: 205 | if num_worker is None: 206 | num_worker = self.train.num_workers 207 | 208 | n_samples = len(self.train.dataset.data) 209 | g = torch.Generator() 210 | g.manual_seed(DataProvider.SUB_SEED) 211 | rand_indexes = torch.randperm(n_samples, generator=g).tolist() 212 | 213 | new_train_dataset = self.train_dataset( 214 | self.build_train_transform(image_size=self.active_img_size, print_log=False)) 215 | chosen_indexes = rand_indexes[:n_images] 216 | if num_replicas is not None: 217 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes)) 218 | else: 219 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes) 220 | sub_data_loader = torch.utils.data.DataLoader( 221 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler, 222 | num_workers=num_worker, pin_memory=True, 223 | ) 224 | self.__dict__['sub_train_%d' % self.active_img_size] = [] 225 | for images, labels in sub_data_loader: 226 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels)) 227 | return self.__dict__['sub_train_%d' % self.active_img_size] 228 | -------------------------------------------------------------------------------- /codebase/data_providers/flowers102.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import os 3 | import math 4 | import numpy as np 5 | 6 | import PIL 7 | 8 | import torch.utils.data 9 | import torchvision.transforms as transforms 10 | import torchvision.datasets as datasets 11 | 12 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler 13 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider 14 | 15 | 16 | class Flowers102DataProvider(DataProvider): 17 | 18 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32, 19 | resize_scale=0.08, distort_color=None, image_size=224, 20 | num_replicas=None, rank=None): 21 | 22 | # warnings.filterwarnings('ignore') 23 | self._save_path = save_path 24 | 25 | self.image_size = image_size # int or list of int 26 | self.distort_color = distort_color 27 | self.resize_scale = resize_scale 28 | 29 | self._valid_transform_dict = {} 30 | if not isinstance(self.image_size, int): 31 | assert isinstance(self.image_size, list) 32 | from ofa.utils.my_dataloader import MyDataLoader 33 | self.image_size.sort() # e.g., 160 -> 224 34 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() 35 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) 36 | 37 | for img_size in self.image_size: 38 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size) 39 | self.active_img_size = max(self.image_size) 40 | valid_transforms = self._valid_transform_dict[self.active_img_size] 41 | train_loader_class = MyDataLoader # randomly sample image size for each batch of training image 42 | else: 43 | self.active_img_size = self.image_size 44 | valid_transforms = self.build_valid_transform() 45 | train_loader_class = torch.utils.data.DataLoader 46 | 47 | train_transforms = self.build_train_transform() 48 | train_dataset = self.train_dataset(train_transforms) 49 | 50 | weights = self.make_weights_for_balanced_classes( 51 | train_dataset.imgs, self.n_classes) 52 | weights = torch.DoubleTensor(weights) 53 | train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights)) 54 | 55 | if valid_size is not None: 56 | raise NotImplementedError("validation dataset not yet implemented") 57 | # valid_dataset = self.valid_dataset(valid_transforms) 58 | 59 | # self.train = train_loader_class( 60 | # train_dataset, batch_size=train_batch_size, sampler=train_sampler, 61 | # num_workers=n_worker, pin_memory=True) 62 | # self.valid = torch.utils.data.DataLoader( 63 | # valid_dataset, batch_size=test_batch_size, 64 | # num_workers=n_worker, pin_memory=True) 65 | else: 66 | self.train = train_loader_class( 67 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 68 | num_workers=n_worker, pin_memory=True, 69 | ) 70 | self.valid = None 71 | 72 | test_dataset = self.test_dataset(valid_transforms) 73 | self.test = torch.utils.data.DataLoader( 74 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True, 75 | ) 76 | 77 | if self.valid is None: 78 | self.valid = self.test 79 | 80 | @staticmethod 81 | def name(): 82 | return 'flowers102' 83 | 84 | @property 85 | def data_shape(self): 86 | return 3, self.active_img_size, self.active_img_size # C, H, W 87 | 88 | @property 89 | def n_classes(self): 90 | return 102 91 | 92 | @property 93 | def save_path(self): 94 | if self._save_path is None: 95 | # self._save_path = '/mnt/datastore/Oxford102Flowers' # home server 96 | self._save_path = '/mnt/datastore/Flowers102' # home server 97 | 98 | if not os.path.exists(self._save_path): 99 | # self._save_path = '/mnt/datastore/Oxford102Flowers' # home server 100 | self._save_path = '/mnt/datastore/Flowers102' # home server 101 | return self._save_path 102 | 103 | @property 104 | def data_url(self): 105 | raise ValueError('unable to download %s' % self.name()) 106 | 107 | def train_dataset(self, _transforms): 108 | dataset = datasets.ImageFolder(self.train_path, _transforms) 109 | return dataset 110 | 111 | # def valid_dataset(self, _transforms): 112 | # dataset = datasets.ImageFolder(self.valid_path, _transforms) 113 | # return dataset 114 | 115 | def test_dataset(self, _transforms): 116 | dataset = datasets.ImageFolder(self.test_path, _transforms) 117 | return dataset 118 | 119 | @property 120 | def train_path(self): 121 | return os.path.join(self.save_path, 'train') 122 | 123 | # @property 124 | # def valid_path(self): 125 | # return os.path.join(self.save_path, 'train') 126 | 127 | @property 128 | def test_path(self): 129 | return os.path.join(self.save_path, 'test') 130 | 131 | @property 132 | def normalize(self): 133 | return transforms.Normalize( 134 | mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803], 135 | std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665]) 136 | 137 | @staticmethod 138 | def make_weights_for_balanced_classes(images, nclasses): 139 | count = [0] * nclasses 140 | 141 | # Counts per label 142 | for item in images: 143 | count[item[1]] += 1 144 | 145 | weight_per_class = [0.] * nclasses 146 | 147 | # Total number of images. 148 | N = float(sum(count)) 149 | 150 | # super-sample the smaller classes. 151 | for i in range(nclasses): 152 | weight_per_class[i] = N / float(count[i]) 153 | 154 | weight = [0] * len(images) 155 | 156 | # Calculate a weight per image. 157 | for idx, val in enumerate(images): 158 | weight[idx] = weight_per_class[val[1]] 159 | 160 | return weight 161 | 162 | def build_train_transform(self, image_size=None, print_log=True): 163 | if image_size is None: 164 | image_size = self.image_size 165 | if print_log: 166 | print('Color jitter: %s, resize_scale: %s, img_size: %s' % 167 | (self.distort_color, self.resize_scale, image_size)) 168 | 169 | if self.distort_color == 'torch': 170 | color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1) 171 | elif self.distort_color == 'tf': 172 | color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5) 173 | else: 174 | color_transform = None 175 | 176 | if isinstance(image_size, list): 177 | resize_transform_class = MyRandomResizedCrop 178 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(), 179 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS)) 180 | else: 181 | resize_transform_class = transforms.RandomResizedCrop 182 | 183 | train_transforms = [ 184 | transforms.RandomAffine( 185 | 45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0), 186 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)), 187 | # transforms.RandomHorizontalFlip(), 188 | ] 189 | if color_transform is not None: 190 | train_transforms.append(color_transform) 191 | train_transforms += [ 192 | transforms.ToTensor(), 193 | self.normalize, 194 | ] 195 | 196 | train_transforms = transforms.Compose(train_transforms) 197 | return train_transforms 198 | 199 | def build_valid_transform(self, image_size=None): 200 | if image_size is None: 201 | image_size = self.active_img_size 202 | return transforms.Compose([ 203 | transforms.Resize(int(math.ceil(image_size / 0.875))), 204 | transforms.CenterCrop(image_size), 205 | transforms.ToTensor(), 206 | self.normalize, 207 | ]) 208 | 209 | def assign_active_img_size(self, new_img_size): 210 | self.active_img_size = new_img_size 211 | if self.active_img_size not in self._valid_transform_dict: 212 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform() 213 | # change the transform of the valid and test set 214 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] 215 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] 216 | 217 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None): 218 | # used for resetting running statistics 219 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None: 220 | if num_worker is None: 221 | num_worker = self.train.num_workers 222 | 223 | n_samples = len(self.train.dataset.samples) 224 | g = torch.Generator() 225 | g.manual_seed(DataProvider.SUB_SEED) 226 | rand_indexes = torch.randperm(n_samples, generator=g).tolist() 227 | 228 | new_train_dataset = self.train_dataset( 229 | self.build_train_transform(image_size=self.active_img_size, print_log=False)) 230 | chosen_indexes = rand_indexes[:n_images] 231 | if num_replicas is not None: 232 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes)) 233 | else: 234 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes) 235 | sub_data_loader = torch.utils.data.DataLoader( 236 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler, 237 | num_workers=num_worker, pin_memory=True, 238 | ) 239 | self.__dict__['sub_train_%d' % self.active_img_size] = [] 240 | for images, labels in sub_data_loader: 241 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels)) 242 | return self.__dict__['sub_train_%d' % self.active_img_size] 243 | -------------------------------------------------------------------------------- /codebase/data_providers/imagenet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import os 3 | import math 4 | import numpy as np 5 | 6 | import torch.utils.data 7 | import torchvision.transforms as transforms 8 | import torchvision.datasets as datasets 9 | 10 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler 11 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider 12 | 13 | 14 | class ImagenetDataProvider(DataProvider): 15 | 16 | def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32, 17 | resize_scale=0.08, distort_color=None, image_size=224, 18 | num_replicas=None, rank=None): 19 | 20 | warnings.filterwarnings('ignore') 21 | self._save_path = save_path 22 | 23 | self.image_size = image_size # int or list of int 24 | self.distort_color = distort_color 25 | self.resize_scale = resize_scale 26 | 27 | self._valid_transform_dict = {} 28 | if not isinstance(self.image_size, int): 29 | assert isinstance(self.image_size, list) 30 | from ofa.utils.my_dataloader import MyDataLoader 31 | self.image_size.sort() # e.g., 160 -> 224 32 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() 33 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) 34 | 35 | for img_size in self.image_size: 36 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size) 37 | self.active_img_size = max(self.image_size) 38 | valid_transforms = self._valid_transform_dict[self.active_img_size] 39 | train_loader_class = MyDataLoader # randomly sample image size for each batch of training image 40 | else: 41 | self.active_img_size = self.image_size 42 | valid_transforms = self.build_valid_transform() 43 | train_loader_class = torch.utils.data.DataLoader 44 | 45 | train_transforms = self.build_train_transform() 46 | train_dataset = self.train_dataset(train_transforms) 47 | 48 | if valid_size is not None: 49 | if not isinstance(valid_size, int): 50 | assert isinstance(valid_size, float) and 0 < valid_size < 1 51 | valid_size = int(len(train_dataset.samples) * valid_size) 52 | 53 | valid_dataset = self.train_dataset(valid_transforms) 54 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size) 55 | 56 | if num_replicas is not None: 57 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes)) 58 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes)) 59 | else: 60 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes) 61 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes) 62 | 63 | self.train = train_loader_class( 64 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 65 | num_workers=n_worker, pin_memory=True, 66 | ) 67 | self.valid = torch.utils.data.DataLoader( 68 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler, 69 | num_workers=n_worker, pin_memory=True, 70 | ) 71 | else: 72 | if num_replicas is not None: 73 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank) 74 | self.train = train_loader_class( 75 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 76 | num_workers=n_worker, pin_memory=True 77 | ) 78 | else: 79 | self.train = train_loader_class( 80 | train_dataset, batch_size=train_batch_size, shuffle=True, 81 | num_workers=n_worker, pin_memory=True, 82 | ) 83 | self.valid = None 84 | 85 | test_dataset = self.test_dataset(valid_transforms) 86 | if num_replicas is not None: 87 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank) 88 | self.test = torch.utils.data.DataLoader( 89 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True, 90 | ) 91 | else: 92 | self.test = torch.utils.data.DataLoader( 93 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True, 94 | ) 95 | 96 | if self.valid is None: 97 | self.valid = self.test 98 | 99 | @staticmethod 100 | def name(): 101 | return 'imagenet' 102 | 103 | @property 104 | def data_shape(self): 105 | return 3, self.active_img_size, self.active_img_size # C, H, W 106 | 107 | @property 108 | def n_classes(self): 109 | return 1000 110 | 111 | @property 112 | def save_path(self): 113 | if self._save_path is None: 114 | # self._save_path = '/dataset/imagenet' 115 | # self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012' # servers 116 | self._save_path = '/mnt/datastore/ILSVRC2012' # home server 117 | 118 | if not os.path.exists(self._save_path): 119 | # self._save_path = os.path.expanduser('~/dataset/imagenet') 120 | # self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012') 121 | self._save_path = '/mnt/datastore/ILSVRC2012' # home server 122 | return self._save_path 123 | 124 | @property 125 | def data_url(self): 126 | raise ValueError('unable to download %s' % self.name()) 127 | 128 | def train_dataset(self, _transforms): 129 | dataset = datasets.ImageFolder(self.train_path, _transforms) 130 | return dataset 131 | 132 | def test_dataset(self, _transforms): 133 | dataset = datasets.ImageFolder(self.valid_path, _transforms) 134 | return dataset 135 | 136 | @property 137 | def train_path(self): 138 | return os.path.join(self.save_path, 'train') 139 | 140 | @property 141 | def valid_path(self): 142 | return os.path.join(self.save_path, 'val') 143 | 144 | @property 145 | def normalize(self): 146 | return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 147 | 148 | def build_train_transform(self, image_size=None, print_log=True): 149 | if image_size is None: 150 | image_size = self.image_size 151 | if print_log: 152 | print('Color jitter: %s, resize_scale: %s, img_size: %s' % 153 | (self.distort_color, self.resize_scale, image_size)) 154 | 155 | if self.distort_color == 'torch': 156 | color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1) 157 | elif self.distort_color == 'tf': 158 | color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5) 159 | else: 160 | color_transform = None 161 | 162 | if isinstance(image_size, list): 163 | resize_transform_class = MyRandomResizedCrop 164 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(), 165 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS)) 166 | else: 167 | resize_transform_class = transforms.RandomResizedCrop 168 | 169 | train_transforms = [ 170 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)), 171 | transforms.RandomHorizontalFlip(), 172 | ] 173 | if color_transform is not None: 174 | train_transforms.append(color_transform) 175 | train_transforms += [ 176 | transforms.ToTensor(), 177 | self.normalize, 178 | ] 179 | 180 | train_transforms = transforms.Compose(train_transforms) 181 | return train_transforms 182 | 183 | def build_valid_transform(self, image_size=None): 184 | if image_size is None: 185 | image_size = self.active_img_size 186 | return transforms.Compose([ 187 | transforms.Resize(int(math.ceil(image_size / 0.875))), 188 | transforms.CenterCrop(image_size), 189 | transforms.ToTensor(), 190 | self.normalize, 191 | ]) 192 | 193 | def assign_active_img_size(self, new_img_size): 194 | self.active_img_size = new_img_size 195 | if self.active_img_size not in self._valid_transform_dict: 196 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform() 197 | # change the transform of the valid and test set 198 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] 199 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] 200 | 201 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None): 202 | # used for resetting running statistics 203 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None: 204 | if num_worker is None: 205 | num_worker = self.train.num_workers 206 | 207 | n_samples = len(self.train.dataset.samples) 208 | g = torch.Generator() 209 | g.manual_seed(DataProvider.SUB_SEED) 210 | rand_indexes = torch.randperm(n_samples, generator=g).tolist() 211 | 212 | new_train_dataset = self.train_dataset( 213 | self.build_train_transform(image_size=self.active_img_size, print_log=False)) 214 | chosen_indexes = rand_indexes[:n_images] 215 | if num_replicas is not None: 216 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes)) 217 | else: 218 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes) 219 | sub_data_loader = torch.utils.data.DataLoader( 220 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler, 221 | num_workers=num_worker, pin_memory=True, 222 | ) 223 | self.__dict__['sub_train_%d' % self.active_img_size] = [] 224 | for images, labels in sub_data_loader: 225 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels)) 226 | return self.__dict__['sub_train_%d' % self.active_img_size] 227 | -------------------------------------------------------------------------------- /codebase/data_providers/pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import warnings 4 | import numpy as np 5 | 6 | from timm.data.transforms import str_to_pil_interp 7 | from timm.data.auto_augment import rand_augment_transform 8 | 9 | import torch.utils.data 10 | import torchvision.transforms as transforms 11 | import torchvision.datasets as datasets 12 | 13 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler 14 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider 15 | 16 | 17 | class OxfordIIITPetsDataProvider(DataProvider): 18 | 19 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32, 20 | resize_scale=0.08, distort_color=None, image_size=224, 21 | num_replicas=None, rank=None): 22 | 23 | warnings.filterwarnings('ignore') 24 | self._save_path = save_path 25 | 26 | self.image_size = image_size # int or list of int 27 | self.distort_color = distort_color 28 | self.resize_scale = resize_scale 29 | 30 | self._valid_transform_dict = {} 31 | if not isinstance(self.image_size, int): 32 | assert isinstance(self.image_size, list) 33 | from ofa.utils.my_dataloader import MyDataLoader 34 | self.image_size.sort() # e.g., 160 -> 224 35 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() 36 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) 37 | 38 | for img_size in self.image_size: 39 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size) 40 | self.active_img_size = max(self.image_size) 41 | valid_transforms = self._valid_transform_dict[self.active_img_size] 42 | train_loader_class = MyDataLoader # randomly sample image size for each batch of training image 43 | else: 44 | self.active_img_size = self.image_size 45 | valid_transforms = self.build_valid_transform() 46 | train_loader_class = torch.utils.data.DataLoader 47 | 48 | train_transforms = self.build_train_transform() 49 | train_dataset = self.train_dataset(train_transforms) 50 | 51 | if valid_size is not None: 52 | if not isinstance(valid_size, int): 53 | assert isinstance(valid_size, float) and 0 < valid_size < 1 54 | valid_size = int(len(train_dataset.samples) * valid_size) 55 | 56 | valid_dataset = self.train_dataset(valid_transforms) 57 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size) 58 | 59 | if num_replicas is not None: 60 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes)) 61 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes)) 62 | else: 63 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes) 64 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes) 65 | 66 | self.train = train_loader_class( 67 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 68 | num_workers=n_worker, pin_memory=True, 69 | ) 70 | self.valid = torch.utils.data.DataLoader( 71 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler, 72 | num_workers=n_worker, pin_memory=True, 73 | ) 74 | else: 75 | if num_replicas is not None: 76 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank) 77 | self.train = train_loader_class( 78 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 79 | num_workers=n_worker, pin_memory=True 80 | ) 81 | else: 82 | self.train = train_loader_class( 83 | train_dataset, batch_size=train_batch_size, shuffle=True, 84 | num_workers=n_worker, pin_memory=True, 85 | ) 86 | self.valid = None 87 | 88 | test_dataset = self.test_dataset(valid_transforms) 89 | if num_replicas is not None: 90 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank) 91 | self.test = torch.utils.data.DataLoader( 92 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True, 93 | ) 94 | else: 95 | self.test = torch.utils.data.DataLoader( 96 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True, 97 | ) 98 | 99 | if self.valid is None: 100 | self.valid = self.test 101 | 102 | @staticmethod 103 | def name(): 104 | return 'pets' 105 | 106 | @property 107 | def data_shape(self): 108 | return 3, self.active_img_size, self.active_img_size # C, H, W 109 | 110 | @property 111 | def n_classes(self): 112 | return 37 113 | 114 | @property 115 | def save_path(self): 116 | if self._save_path is None: 117 | self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server 118 | 119 | if not os.path.exists(self._save_path): 120 | self._save_path = '/mnt/datastore/Oxford-IIITPets' # home server 121 | return self._save_path 122 | 123 | @property 124 | def data_url(self): 125 | raise ValueError('unable to download %s' % self.name()) 126 | 127 | def train_dataset(self, _transforms): 128 | dataset = datasets.ImageFolder(self.train_path, _transforms) 129 | return dataset 130 | 131 | def test_dataset(self, _transforms): 132 | dataset = datasets.ImageFolder(self.valid_path, _transforms) 133 | return dataset 134 | 135 | @property 136 | def train_path(self): 137 | return os.path.join(self.save_path, 'train') 138 | 139 | @property 140 | def valid_path(self): 141 | return os.path.join(self.save_path, 'valid') 142 | 143 | @property 144 | def normalize(self): 145 | return transforms.Normalize( 146 | mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783], 147 | std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917]) 148 | 149 | def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'): 150 | if image_size is None: 151 | image_size = self.image_size 152 | # if print_log: 153 | # print('Color jitter: %s, resize_scale: %s, img_size: %s' % 154 | # (self.distort_color, self.resize_scale, image_size)) 155 | 156 | # if self.distort_color == 'torch': 157 | # color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1) 158 | # elif self.distort_color == 'tf': 159 | # color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5) 160 | # else: 161 | # color_transform = None 162 | 163 | if isinstance(image_size, list): 164 | resize_transform_class = MyRandomResizedCrop 165 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(), 166 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS)) 167 | img_size_min = min(image_size) 168 | else: 169 | resize_transform_class = transforms.RandomResizedCrop 170 | img_size_min = image_size 171 | 172 | train_transforms = [ 173 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)), 174 | transforms.RandomHorizontalFlip(), 175 | ] 176 | 177 | aa_params = dict( 178 | translate_const=int(img_size_min * 0.45), 179 | img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807, 180 | 0.39566558230789783]]), 181 | ) 182 | aa_params['interpolation'] = str_to_pil_interp('bicubic') 183 | train_transforms += [rand_augment_transform(auto_augment, aa_params)] 184 | 185 | # if color_transform is not None: 186 | # train_transforms.append(color_transform) 187 | train_transforms += [ 188 | transforms.ToTensor(), 189 | self.normalize, 190 | ] 191 | 192 | train_transforms = transforms.Compose(train_transforms) 193 | return train_transforms 194 | 195 | def build_valid_transform(self, image_size=None): 196 | if image_size is None: 197 | image_size = self.active_img_size 198 | return transforms.Compose([ 199 | transforms.Resize(int(math.ceil(image_size / 0.875))), 200 | transforms.CenterCrop(image_size), 201 | transforms.ToTensor(), 202 | self.normalize, 203 | ]) 204 | 205 | def assign_active_img_size(self, new_img_size): 206 | self.active_img_size = new_img_size 207 | if self.active_img_size not in self._valid_transform_dict: 208 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform() 209 | # change the transform of the valid and test set 210 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] 211 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] 212 | 213 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None): 214 | # used for resetting running statistics 215 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None: 216 | if num_worker is None: 217 | num_worker = self.train.num_workers 218 | 219 | n_samples = len(self.train.dataset.samples) 220 | g = torch.Generator() 221 | g.manual_seed(DataProvider.SUB_SEED) 222 | rand_indexes = torch.randperm(n_samples, generator=g).tolist() 223 | 224 | new_train_dataset = self.train_dataset( 225 | self.build_train_transform(image_size=self.active_img_size, print_log=False)) 226 | chosen_indexes = rand_indexes[:n_images] 227 | if num_replicas is not None: 228 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes)) 229 | else: 230 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes) 231 | sub_data_loader = torch.utils.data.DataLoader( 232 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler, 233 | num_workers=num_worker, pin_memory=True, 234 | ) 235 | self.__dict__['sub_train_%d' % self.active_img_size] = [] 236 | for images, labels in sub_data_loader: 237 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels)) 238 | return self.__dict__['sub_train_%d' % self.active_img_size] -------------------------------------------------------------------------------- /codebase/data_providers/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import numpy as np 4 | 5 | from timm.data.transforms import str_to_pil_interp 6 | from timm.data.auto_augment import rand_augment_transform 7 | 8 | import torch.utils.data 9 | import torchvision.transforms as transforms 10 | import torchvision.datasets as datasets 11 | 12 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler 13 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider 14 | 15 | 16 | class DTDDataProvider(DataProvider): 17 | 18 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32, 19 | resize_scale=0.08, distort_color=None, image_size=224, 20 | num_replicas=None, rank=None): 21 | 22 | warnings.filterwarnings('ignore') 23 | self._save_path = save_path 24 | 25 | self.image_size = image_size # int or list of int 26 | self.distort_color = distort_color 27 | self.resize_scale = resize_scale 28 | 29 | self._valid_transform_dict = {} 30 | if not isinstance(self.image_size, int): 31 | assert isinstance(self.image_size, list) 32 | from ofa.utils.my_dataloader import MyDataLoader 33 | self.image_size.sort() # e.g., 160 -> 224 34 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() 35 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) 36 | 37 | for img_size in self.image_size: 38 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size) 39 | self.active_img_size = max(self.image_size) 40 | valid_transforms = self._valid_transform_dict[self.active_img_size] 41 | train_loader_class = MyDataLoader # randomly sample image size for each batch of training image 42 | else: 43 | self.active_img_size = self.image_size 44 | valid_transforms = self.build_valid_transform() 45 | train_loader_class = torch.utils.data.DataLoader 46 | 47 | train_transforms = self.build_train_transform() 48 | train_dataset = self.train_dataset(train_transforms) 49 | 50 | if valid_size is not None: 51 | if not isinstance(valid_size, int): 52 | assert isinstance(valid_size, float) and 0 < valid_size < 1 53 | valid_size = int(len(train_dataset.samples) * valid_size) 54 | 55 | valid_dataset = self.train_dataset(valid_transforms) 56 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size) 57 | 58 | if num_replicas is not None: 59 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes)) 60 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes)) 61 | else: 62 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes) 63 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes) 64 | 65 | self.train = train_loader_class( 66 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 67 | num_workers=n_worker, pin_memory=True, 68 | ) 69 | self.valid = torch.utils.data.DataLoader( 70 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler, 71 | num_workers=n_worker, pin_memory=True, 72 | ) 73 | else: 74 | if num_replicas is not None: 75 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank) 76 | self.train = train_loader_class( 77 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 78 | num_workers=n_worker, pin_memory=True 79 | ) 80 | else: 81 | self.train = train_loader_class( 82 | train_dataset, batch_size=train_batch_size, shuffle=True, 83 | num_workers=n_worker, pin_memory=True, 84 | ) 85 | self.valid = None 86 | 87 | test_dataset = self.test_dataset(valid_transforms) 88 | if num_replicas is not None: 89 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank) 90 | self.test = torch.utils.data.DataLoader( 91 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True, 92 | ) 93 | else: 94 | self.test = torch.utils.data.DataLoader( 95 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True, 96 | ) 97 | 98 | if self.valid is None: 99 | self.valid = self.test 100 | 101 | @staticmethod 102 | def name(): 103 | return 'dtd' 104 | 105 | @property 106 | def data_shape(self): 107 | return 3, self.active_img_size, self.active_img_size # C, H, W 108 | 109 | @property 110 | def n_classes(self): 111 | return 47 112 | 113 | @property 114 | def save_path(self): 115 | if self._save_path is None: 116 | self._save_path = '/mnt/datastore/dtd' # home server 117 | 118 | if not os.path.exists(self._save_path): 119 | self._save_path = '/mnt/datastore/dtd' # home server 120 | return self._save_path 121 | 122 | @property 123 | def data_url(self): 124 | raise ValueError('unable to download %s' % self.name()) 125 | 126 | def train_dataset(self, _transforms): 127 | dataset = datasets.ImageFolder(self.train_path, _transforms) 128 | return dataset 129 | 130 | def test_dataset(self, _transforms): 131 | dataset = datasets.ImageFolder(self.valid_path, _transforms) 132 | return dataset 133 | 134 | @property 135 | def train_path(self): 136 | return os.path.join(self.save_path, 'train') 137 | 138 | @property 139 | def valid_path(self): 140 | return os.path.join(self.save_path, 'valid') 141 | 142 | @property 143 | def normalize(self): 144 | return transforms.Normalize( 145 | mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676], 146 | std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832]) 147 | 148 | def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'): 149 | if image_size is None: 150 | image_size = self.image_size 151 | # if print_log: 152 | # print('Color jitter: %s, resize_scale: %s, img_size: %s' % 153 | # (self.distort_color, self.resize_scale, image_size)) 154 | 155 | # if self.distort_color == 'torch': 156 | # color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1) 157 | # elif self.distort_color == 'tf': 158 | # color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5) 159 | # else: 160 | # color_transform = None 161 | 162 | if isinstance(image_size, list): 163 | resize_transform_class = MyRandomResizedCrop 164 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(), 165 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS)) 166 | img_size_min = min(image_size) 167 | else: 168 | resize_transform_class = transforms.RandomResizedCrop 169 | img_size_min = image_size 170 | 171 | train_transforms = [ 172 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)), 173 | transforms.RandomHorizontalFlip(), 174 | ] 175 | 176 | aa_params = dict( 177 | translate_const=int(img_size_min * 0.45), 178 | img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454, 179 | 0.42627281899380676]]), 180 | ) 181 | aa_params['interpolation'] = str_to_pil_interp('bicubic') 182 | train_transforms += [rand_augment_transform(auto_augment, aa_params)] 183 | 184 | # if color_transform is not None: 185 | # train_transforms.append(color_transform) 186 | train_transforms += [ 187 | transforms.ToTensor(), 188 | self.normalize, 189 | ] 190 | 191 | train_transforms = transforms.Compose(train_transforms) 192 | return train_transforms 193 | 194 | def build_valid_transform(self, image_size=None): 195 | if image_size is None: 196 | image_size = self.active_img_size 197 | return transforms.Compose([ 198 | # transforms.Resize(int(math.ceil(image_size / 0.875))), 199 | transforms.Resize((image_size, image_size), interpolation=3), 200 | transforms.CenterCrop(image_size), 201 | transforms.ToTensor(), 202 | self.normalize, 203 | ]) 204 | 205 | def assign_active_img_size(self, new_img_size): 206 | self.active_img_size = new_img_size 207 | if self.active_img_size not in self._valid_transform_dict: 208 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform() 209 | # change the transform of the valid and test set 210 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] 211 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] 212 | 213 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None): 214 | # used for resetting running statistics 215 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None: 216 | if num_worker is None: 217 | num_worker = self.train.num_workers 218 | 219 | n_samples = len(self.train.dataset.samples) 220 | g = torch.Generator() 221 | g.manual_seed(DataProvider.SUB_SEED) 222 | rand_indexes = torch.randperm(n_samples, generator=g).tolist() 223 | 224 | new_train_dataset = self.train_dataset( 225 | self.build_train_transform(image_size=self.active_img_size, print_log=False)) 226 | chosen_indexes = rand_indexes[:n_images] 227 | if num_replicas is not None: 228 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes)) 229 | else: 230 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes) 231 | sub_data_loader = torch.utils.data.DataLoader( 232 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler, 233 | num_workers=num_worker, pin_memory=True, 234 | ) 235 | self.__dict__['sub_train_%d' % self.active_img_size] = [] 236 | for images, labels in sub_data_loader: 237 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels)) 238 | return self.__dict__['sub_train_%d' % self.active_img_size] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /codebase/data_providers/autoaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py 3 | """ 4 | 5 | from PIL import Image, ImageEnhance, ImageOps 6 | import numpy as np 7 | import random 8 | 9 | 10 | class ImageNetPolicy(object): 11 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 12 | 13 | Example: 14 | >>> policy = ImageNetPolicy() 15 | >>> transformed = policy(image) 16 | 17 | Example as a PyTorch Transform: 18 | >>> transform=transforms.Compose([ 19 | >>> transforms.Resize(256), 20 | >>> ImageNetPolicy(), 21 | >>> transforms.ToTensor()]) 22 | """ 23 | def __init__(self, fillcolor=(128, 128, 128)): 24 | self.policies = [ 25 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 26 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 27 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 28 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 29 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 30 | 31 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 32 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 33 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 34 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 35 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 36 | 37 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 38 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 39 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 40 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 41 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 42 | 43 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 44 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 45 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 46 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 47 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 48 | 49 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 50 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 51 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 52 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 53 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 54 | ] 55 | 56 | 57 | def __call__(self, img): 58 | policy_idx = random.randint(0, len(self.policies) - 1) 59 | return self.policies[policy_idx](img) 60 | 61 | def __repr__(self): 62 | return "AutoAugment ImageNet Policy" 63 | 64 | 65 | class CIFAR10Policy(object): 66 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 67 | 68 | Example: 69 | >>> policy = CIFAR10Policy() 70 | >>> transformed = policy(image) 71 | 72 | Example as a PyTorch Transform: 73 | >>> transform=transforms.Compose([ 74 | >>> transforms.Resize(256), 75 | >>> CIFAR10Policy(), 76 | >>> transforms.ToTensor()]) 77 | """ 78 | def __init__(self, fillcolor=(128, 128, 128)): 79 | self.policies = [ 80 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 81 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 82 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 83 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 84 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 85 | 86 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 87 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 88 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 89 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 90 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 91 | 92 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 93 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 94 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 95 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 96 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 97 | 98 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 99 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), 100 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 101 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 102 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 103 | 104 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 105 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 106 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 107 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 108 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 109 | ] 110 | 111 | 112 | def __call__(self, img): 113 | policy_idx = random.randint(0, len(self.policies) - 1) 114 | return self.policies[policy_idx](img) 115 | 116 | def __repr__(self): 117 | return "AutoAugment CIFAR10 Policy" 118 | 119 | 120 | class SVHNPolicy(object): 121 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 122 | 123 | Example: 124 | >>> policy = SVHNPolicy() 125 | >>> transformed = policy(image) 126 | 127 | Example as a PyTorch Transform: 128 | >>> transform=transforms.Compose([ 129 | >>> transforms.Resize(256), 130 | >>> SVHNPolicy(), 131 | >>> transforms.ToTensor()]) 132 | """ 133 | def __init__(self, fillcolor=(128, 128, 128)): 134 | self.policies = [ 135 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 136 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 137 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 138 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 139 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 140 | 141 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 142 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 143 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 144 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 145 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 146 | 147 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 148 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 149 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 150 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 151 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 152 | 153 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 154 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 155 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 156 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 157 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 158 | 159 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 160 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 161 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 162 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 163 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 164 | ] 165 | 166 | 167 | def __call__(self, img): 168 | policy_idx = random.randint(0, len(self.policies) - 1) 169 | return self.policies[policy_idx](img) 170 | 171 | def __repr__(self): 172 | return "AutoAugment SVHN Policy" 173 | 174 | 175 | class SubPolicy(object): 176 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 177 | ranges = { 178 | "shearX": np.linspace(0, 0.3, 10), 179 | "shearY": np.linspace(0, 0.3, 10), 180 | "translateX": np.linspace(0, 150 / 331, 10), 181 | "translateY": np.linspace(0, 150 / 331, 10), 182 | "rotate": np.linspace(0, 30, 10), 183 | "color": np.linspace(0.0, 0.9, 10), 184 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 185 | "solarize": np.linspace(256, 0, 10), 186 | "contrast": np.linspace(0.0, 0.9, 10), 187 | "sharpness": np.linspace(0.0, 0.9, 10), 188 | "brightness": np.linspace(0.0, 0.9, 10), 189 | "autocontrast": [0] * 10, 190 | "equalize": [0] * 10, 191 | "invert": [0] * 10 192 | } 193 | 194 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 195 | def rotate_with_fill(img, magnitude): 196 | rot = img.convert("RGBA").rotate(magnitude) 197 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 198 | 199 | func = { 200 | "shearX": lambda img, magnitude: img.transform( 201 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 202 | Image.BICUBIC, fillcolor=fillcolor), 203 | "shearY": lambda img, magnitude: img.transform( 204 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 205 | Image.BICUBIC, fillcolor=fillcolor), 206 | "translateX": lambda img, magnitude: img.transform( 207 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 208 | fillcolor=fillcolor), 209 | "translateY": lambda img, magnitude: img.transform( 210 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 211 | fillcolor=fillcolor), 212 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 213 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 214 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 215 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 216 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 217 | 1 + magnitude * random.choice([-1, 1])), 218 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 219 | 1 + magnitude * random.choice([-1, 1])), 220 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 221 | 1 + magnitude * random.choice([-1, 1])), 222 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 223 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 224 | "invert": lambda img, magnitude: ImageOps.invert(img) 225 | } 226 | 227 | self.p1 = p1 228 | self.operation1 = func[operation1] 229 | self.magnitude1 = ranges[operation1][magnitude_idx1] 230 | self.p2 = p2 231 | self.operation2 = func[operation2] 232 | self.magnitude2 = ranges[operation2][magnitude_idx2] 233 | 234 | 235 | def __call__(self, img): 236 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 237 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 238 | return img -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import yaml 5 | import numpy as np 6 | from collections import OrderedDict 7 | from torchprofile import profile_macs 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.backends.cudnn as cudnn 12 | 13 | from pymoo.model.mutation import Mutation 14 | from pymoo.model.sampling import Sampling 15 | from pymoo.model.crossover import Crossover 16 | 17 | DEFAULT_CFG = { 18 | 'gpus': '0', 'config': None, 'init': None, 'trn_batch_size': 128, 'vld_batch_size': 250, 'num_workers': 4, 19 | 'n_epochs': 0, 'save': None, 'resolution': 224, 'valid_size': 10000, 'test': True, 'latency': None, 20 | 'verbose': False, 'classifier_only': False, 'reset_running_statistics': True, 21 | } 22 | 23 | 24 | def get_correlation(prediction, target): 25 | import scipy.stats as stats 26 | 27 | rmse = np.sqrt(((prediction - target) ** 2).mean()) 28 | rho, _ = stats.spearmanr(prediction, target) 29 | tau, _ = stats.kendalltau(prediction, target) 30 | 31 | return rmse, rho, tau 32 | 33 | 34 | def bash_command_template(**kwargs): 35 | gpus = kwargs.pop('gpus', DEFAULT_CFG['gpus']) 36 | cfg = OrderedDict() 37 | cfg['subnet'] = kwargs['subnet'] 38 | cfg['data'] = kwargs['data'] 39 | cfg['dataset'] = kwargs['dataset'] 40 | cfg['n_classes'] = kwargs['n_classes'] 41 | cfg['supernet_path'] = kwargs['supernet_path'] 42 | cfg['config'] = kwargs.pop('config', DEFAULT_CFG['config']) 43 | cfg['init'] = kwargs.pop('init', DEFAULT_CFG['init']) 44 | cfg['save'] = kwargs.pop('save', DEFAULT_CFG['save']) 45 | cfg['trn_batch_size'] = kwargs.pop('trn_batch_size', DEFAULT_CFG['trn_batch_size']) 46 | cfg['vld_batch_size'] = kwargs.pop('vld_batch_size', DEFAULT_CFG['vld_batch_size']) 47 | cfg['num_workers'] = kwargs.pop('num_workers', DEFAULT_CFG['num_workers']) 48 | cfg['n_epochs'] = kwargs.pop('n_epochs', DEFAULT_CFG['n_epochs']) 49 | cfg['resolution'] = kwargs.pop('resolution', DEFAULT_CFG['resolution']) 50 | cfg['valid_size'] = kwargs.pop('valid_size', DEFAULT_CFG['valid_size']) 51 | cfg['test'] = kwargs.pop('test', DEFAULT_CFG['test']) 52 | cfg['latency'] = kwargs.pop('latency', DEFAULT_CFG['latency']) 53 | cfg['verbose'] = kwargs.pop('verbose', DEFAULT_CFG['verbose']) 54 | cfg['classifier_only'] = kwargs.pop('classifier_only', DEFAULT_CFG['classifier_only']) 55 | cfg['reset_running_statistics'] = kwargs.pop( 56 | 'reset_running_statistics', DEFAULT_CFG['reset_running_statistics']) 57 | 58 | execution_line = "CUDA_VISIBLE_DEVICES={} python evaluator.py".format(gpus) 59 | for k, v in cfg.items(): 60 | if v is not None: 61 | if isinstance(v, bool): 62 | if v: 63 | execution_line += " --{}".format(k) 64 | else: 65 | execution_line += " --{} {}".format(k, v) 66 | execution_line += ' &' 67 | return execution_line 68 | 69 | 70 | def prepare_eval_folder(path, configs, gpu=2, n_gpus=8, **kwargs): 71 | """ create a folder for parallel evaluation of a population of architectures """ 72 | os.makedirs(path, exist_ok=True) 73 | gpu_template = ','.join(['{}'] * gpu) 74 | gpus = [gpu_template.format(i, i + 1) for i in range(0, n_gpus, gpu)] 75 | bash_file = ['#!/bin/bash'] 76 | for i in range(0, len(configs), n_gpus//gpu): 77 | for j in range(n_gpus//gpu): 78 | if i + j < len(configs): 79 | job = os.path.join(path, "net_{}.subnet".format(i + j)) 80 | with open(job, 'w') as handle: 81 | json.dump(configs[i + j], handle) 82 | bash_file.append(bash_command_template( 83 | gpus=gpus[j], subnet=job, save=os.path.join( 84 | path, "net_{}.stats".format(i + j)), **kwargs)) 85 | bash_file.append('wait') 86 | 87 | with open(os.path.join(path, 'run_bash.sh'), 'w') as handle: 88 | for line in bash_file: 89 | handle.write(line + os.linesep) 90 | 91 | 92 | class MySampling(Sampling): 93 | 94 | def _do(self, problem, n_samples, **kwargs): 95 | X = np.full((n_samples, problem.n_var), False, dtype=np.bool) 96 | 97 | for k in range(n_samples): 98 | I = np.random.permutation(problem.n_var)[:problem.n_max] 99 | X[k, I] = True 100 | 101 | return X 102 | 103 | 104 | class BinaryCrossover(Crossover): 105 | def __init__(self): 106 | super().__init__(2, 1) 107 | 108 | def _do(self, problem, X, **kwargs): 109 | n_parents, n_matings, n_var = X.shape 110 | 111 | _X = np.full((self.n_offsprings, n_matings, problem.n_var), False) 112 | 113 | for k in range(n_matings): 114 | p1, p2 = X[0, k], X[1, k] 115 | 116 | both_are_true = np.logical_and(p1, p2) 117 | _X[0, k, both_are_true] = True 118 | 119 | n_remaining = problem.n_max - np.sum(both_are_true) 120 | 121 | I = np.where(np.logical_xor(p1, p2))[0] 122 | 123 | S = I[np.random.permutation(len(I))][:n_remaining] 124 | _X[0, k, S] = True 125 | 126 | return _X 127 | 128 | 129 | class MyMutation(Mutation): 130 | def _do(self, problem, X, **kwargs): 131 | for i in range(X.shape[0]): 132 | X[i, :] = X[i, :] 133 | is_false = np.where(np.logical_not(X[i, :]))[0] 134 | is_true = np.where(X[i, :])[0] 135 | try: 136 | X[i, np.random.choice(is_false)] = True 137 | X[i, np.random.choice(is_true)] = False 138 | except ValueError: 139 | pass 140 | 141 | return X 142 | 143 | 144 | class LatencyEstimator(object): 145 | """ 146 | Modified from https://github.com/mit-han-lab/proxylessnas/blob/ 147 | f273683a77c4df082dd11cc963b07fc3613079a0/search/utils/latency_estimator.py#L29 148 | """ 149 | def __init__(self, fname): 150 | # fname = download_url(url, overwrite=True) 151 | 152 | with open(fname, 'r') as fp: 153 | self.lut = yaml.load(fp, yaml.SafeLoader) 154 | 155 | @staticmethod 156 | def repr_shape(shape): 157 | if isinstance(shape, (list, tuple)): 158 | return 'x'.join(str(_) for _ in shape) 159 | elif isinstance(shape, str): 160 | return shape 161 | else: 162 | return TypeError 163 | 164 | def predict(self, ltype: str, _input, output, expand=None, 165 | kernel=None, stride=None, idskip=None, se=None): 166 | """ 167 | :param ltype: 168 | Layer type must be one of the followings 169 | 1. `first_conv`: The initial stem 3x3 conv with stride 2 170 | 2. `final_expand_layer`: (Only for MobileNet-V3) 171 | The upsample 1x1 conv that increases num_filters by 6 times + GAP. 172 | 3. 'feature_mix_layer': 173 | The upsample 1x1 conv that increase num_filters to num_features + torch.squeeze 174 | 3. `classifier`: fully connected linear layer (num_features to num_classes) 175 | 4. `MBConv`: MobileInvertedResidual 176 | :param _input: input shape (h, w, #channels) 177 | :param output: output shape (h, w, #channels) 178 | :param expand: expansion ratio 179 | :param kernel: kernel size 180 | :param stride: 181 | :param idskip: indicate whether has the residual connection 182 | :param se: indicate whether has squeeze-and-excitation 183 | """ 184 | infos = [ltype, 'input:%s' % self.repr_shape(_input), 185 | 'output:%s' % self.repr_shape(output), ] 186 | if ltype in ('MBConv',): 187 | assert None not in (expand, kernel, stride, idskip, se) 188 | infos += ['expand:%d' % expand, 'kernel:%d' % kernel, 189 | 'stride:%d' % stride, 'idskip:%d' % idskip, 'se:%d' % se] 190 | key = '-'.join(infos) 191 | return self.lut[key]['mean'] 192 | 193 | 194 | def look_up_latency(net, lut, resolution=224): 195 | def _half(x, times=1): 196 | for _ in range(times): 197 | x = np.ceil(x / 2) 198 | return int(x) 199 | 200 | predicted_latency = 0 201 | 202 | # first_conv 203 | predicted_latency += lut.predict( 204 | 'first_conv', [resolution, resolution, 3], 205 | [resolution // 2, resolution // 2, net.first_conv.out_channels]) 206 | 207 | # final_expand_layer (only for MobileNet V3 models) 208 | input_resolution = _half(resolution, times=5) 209 | predicted_latency += lut.predict( 210 | 'final_expand_layer', 211 | [input_resolution, input_resolution, net.final_expand_layer.in_channels], 212 | [input_resolution, input_resolution, net.final_expand_layer.out_channels] 213 | ) 214 | 215 | # feature_mix_layer 216 | predicted_latency += lut.predict( 217 | 'feature_mix_layer', 218 | [1, 1, net.feature_mix_layer.in_channels], 219 | [1, 1, net.feature_mix_layer.out_channels] 220 | ) 221 | 222 | # classifier 223 | predicted_latency += lut.predict( 224 | 'classifier', 225 | [net.classifier.in_features], 226 | [net.classifier.out_features] 227 | ) 228 | 229 | # blocks 230 | fsize = _half(resolution) 231 | for block in net.blocks: 232 | idskip = 0 if block.config['shortcut'] is None else 1 233 | se = 1 if block.config['mobile_inverted_conv']['use_se'] else 0 234 | stride = block.config['mobile_inverted_conv']['stride'] 235 | out_fz = _half(fsize) if stride > 1 else fsize 236 | block_latency = lut.predict( 237 | 'MBConv', 238 | [fsize, fsize, block.config['mobile_inverted_conv']['in_channels']], 239 | [out_fz, out_fz, block.config['mobile_inverted_conv']['out_channels']], 240 | expand=block.config['mobile_inverted_conv']['expand_ratio'], 241 | kernel=block.config['mobile_inverted_conv']['kernel_size'], 242 | stride=stride, idskip=idskip, se=se 243 | ) 244 | predicted_latency += block_latency 245 | fsize = out_fz 246 | 247 | return predicted_latency 248 | 249 | 250 | def get_net_info(net, input_shape=(3, 224, 224), measure_latency=None, print_info=True, clean=False, lut=None): 251 | """ 252 | Modified from https://github.com/mit-han-lab/once-for-all/blob/ 253 | 35ddcb9ca30905829480770a6a282d49685aa282/ofa/imagenet_codebase/utils/pytorch_utils.py#L139 254 | """ 255 | from ofa.imagenet_codebase.utils.pytorch_utils import count_parameters, measure_net_latency 256 | 257 | # artificial input data 258 | inputs = torch.randn(1, 3, input_shape[-2], input_shape[-1]) 259 | 260 | # move network to GPU if available 261 | if torch.cuda.is_available(): 262 | device = torch.device('cuda:0') 263 | net = net.to(device) 264 | cudnn.benchmark = True 265 | inputs = inputs.to(device) 266 | 267 | net_info = {} 268 | if isinstance(net, nn.DataParallel): 269 | net = net.module 270 | 271 | # parameters 272 | net_info['params'] = count_parameters(net) 273 | 274 | # flops 275 | net_info['flops'] = int(profile_macs(copy.deepcopy(net), inputs)) 276 | 277 | # latencies 278 | latency_types = [] if measure_latency is None else measure_latency.split('#') 279 | 280 | # print(latency_types) 281 | for l_type in latency_types: 282 | if lut is not None and l_type in lut: 283 | latency_estimator = LatencyEstimator(lut[l_type]) 284 | latency = look_up_latency(net, latency_estimator, input_shape[2]) 285 | measured_latency = None 286 | else: 287 | latency, measured_latency = measure_net_latency( 288 | net, l_type, fast=False, input_shape=input_shape, clean=clean) 289 | net_info['%s latency' % l_type] = { 290 | 'val': latency, 291 | 'hist': measured_latency 292 | } 293 | 294 | if print_info: 295 | # print(net) 296 | print('Total training params: %.2fM' % (net_info['params'] / 1e6)) 297 | print('Total FLOPs: %.2fM' % (net_info['flops'] / 1e6)) 298 | for l_type in latency_types: 299 | print('Estimated %s latency: %.3fms' % (l_type, net_info['%s latency' % l_type]['val'])) 300 | 301 | return net_info 302 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | import numpy as np 6 | 7 | import utils 8 | from codebase.networks import NSGANetV2 9 | from codebase.run_manager import get_run_config 10 | from ofa.elastic_nn.networks import OFAMobileNetV3 11 | from ofa.imagenet_codebase.run_manager import RunManager 12 | from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d 13 | 14 | import warnings 15 | warnings.simplefilter("ignore") 16 | 17 | DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1 18 | 19 | 20 | def parse_string_list(string): 21 | if isinstance(string, str): 22 | # convert '[5 5 5 7 7 7 3 3 7 7 7 3 3]' to [5, 5, 5, 7, 7, 7, 3, 3, 7, 7, 7, 3, 3] 23 | return list(map(int, string[1:-1].split())) 24 | else: 25 | return string 26 | 27 | 28 | def pad_none(x, depth, max_depth): 29 | new_x, counter = [], 0 30 | for d in depth: 31 | for _ in range(d): 32 | new_x.append(x[counter]) 33 | counter += 1 34 | if d < max_depth: 35 | new_x += [None] * (max_depth - d) 36 | return new_x 37 | 38 | 39 | def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None): 40 | 41 | net_info = utils.get_net_info( 42 | net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut) 43 | 44 | gpu_latency, cpu_latency = None, None 45 | for k in net_info.keys(): 46 | if 'gpu' in k: 47 | gpu_latency = np.round(net_info[k]['val'], 2) 48 | if 'cpu' in k: 49 | cpu_latency = np.round(net_info[k]['val'], 2) 50 | 51 | return { 52 | 'params': np.round(net_info['params'] / 1e6, 2), 53 | 'flops': np.round(net_info['flops'] / 1e6, 2), 54 | 'gpu': gpu_latency, 'cpu': cpu_latency 55 | } 56 | 57 | 58 | def validate_config(config, max_depth=4): 59 | kernel_size, exp_ratio, depth = config['ks'], config['e'], config['d'] 60 | 61 | if isinstance(kernel_size, str): kernel_size = parse_string_list(kernel_size) 62 | if isinstance(exp_ratio, str): exp_ratio = parse_string_list(exp_ratio) 63 | if isinstance(depth, str): depth = parse_string_list(depth) 64 | 65 | assert (isinstance(kernel_size, list) or isinstance(kernel_size, int)) 66 | assert (isinstance(exp_ratio, list) or isinstance(exp_ratio, int)) 67 | assert isinstance(depth, list) 68 | 69 | if len(kernel_size) < len(depth) * max_depth: 70 | kernel_size = pad_none(kernel_size, depth, max_depth) 71 | if len(exp_ratio) < len(depth) * max_depth: 72 | exp_ratio = pad_none(exp_ratio, depth, max_depth) 73 | 74 | # return {'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'w': config['w']} 75 | return {'ks': kernel_size, 'e': exp_ratio, 'd': depth} 76 | 77 | 78 | class OFAEvaluator: 79 | """ based on OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """ 80 | def __init__(self, 81 | n_classes=1000, 82 | model_path='./data/ofa_mbv3_d234_e346_k357_w1.0', 83 | kernel_size=None, exp_ratio=None, depth=None): 84 | # default configurations 85 | self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size # depth-wise conv kernel size 86 | self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio # expansion rate 87 | self.depth = [2, 3, 4] if depth is None else depth # number of MB block repetition 88 | 89 | if 'w1.0' in model_path: 90 | self.width_mult = 1.0 91 | elif 'w1.2' in model_path: 92 | self.width_mult = 1.2 93 | else: 94 | raise ValueError 95 | 96 | self.engine = OFAMobileNetV3( 97 | n_classes=n_classes, 98 | dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size, 99 | expand_ratio_list=self.exp_ratio, depth_list=self.depth) 100 | 101 | init = torch.load(model_path, map_location='cpu')['state_dict'] 102 | self.engine.load_weights_from_net(init) 103 | 104 | def sample(self, config=None): 105 | """ randomly sample a sub-network """ 106 | if config is not None: 107 | config = validate_config(config) 108 | self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d']) 109 | else: 110 | config = self.engine.sample_active_subnet() 111 | 112 | subnet = self.engine.get_active_subnet(preserve_weight=True) 113 | return subnet, config 114 | 115 | @staticmethod 116 | def save_net_config(path, net, config_name='net.config'): 117 | """ dump run_config and net_config to the model_folder """ 118 | net_save_path = os.path.join(path, config_name) 119 | json.dump(net.config, open(net_save_path, 'w'), indent=4) 120 | print('Network configs dump to %s' % net_save_path) 121 | 122 | @staticmethod 123 | def save_net(path, net, model_name): 124 | """ dump net weight as checkpoint """ 125 | if isinstance(net, torch.nn.DataParallel): 126 | checkpoint = {'state_dict': net.module.state_dict()} 127 | else: 128 | checkpoint = {'state_dict': net.state_dict()} 129 | model_path = os.path.join(path, model_name) 130 | torch.save(checkpoint, model_path) 131 | print('Network model dump to %s' % model_path) 132 | 133 | @staticmethod 134 | def eval(subnet, data_path, dataset='imagenet', n_epochs=0, resolution=224, trn_batch_size=128, vld_batch_size=250, 135 | num_workers=4, valid_size=None, is_test=True, log_dir='.tmp/eval', measure_latency=None, no_logs=False, 136 | reset_running_statistics=True): 137 | 138 | lut = {'cpu': 'data/i7-8700K_lut.yaml'} 139 | 140 | info = get_net_info( 141 | subnet, (3, resolution, resolution), measure_latency=measure_latency, 142 | print_info=False, clean=True, lut=lut) 143 | 144 | run_config = get_run_config( 145 | dataset=dataset, data_path=data_path, image_size=resolution, n_epochs=n_epochs, 146 | train_batch_size=trn_batch_size, test_batch_size=vld_batch_size, 147 | n_worker=num_workers, valid_size=valid_size) 148 | 149 | # set the image size. You can set any image size from 192 to 256 here 150 | run_config.data_provider.assign_active_img_size(resolution) 151 | 152 | if n_epochs > 0: 153 | # for datasets other than the one supernet was trained on (ImageNet) 154 | # a few epochs of training need to be applied 155 | subnet.reset_classifier( 156 | last_channel=subnet.classifier.in_features, 157 | n_classes=run_config.data_provider.n_classes, dropout_rate=cfgs.drop_rate) 158 | 159 | run_manager = RunManager(log_dir, subnet, run_config, init=False) 160 | if reset_running_statistics: 161 | # run_manager.reset_running_statistics(net=subnet, batch_size=vld_batch_size) 162 | run_manager.reset_running_statistics(net=subnet) 163 | 164 | if n_epochs > 0: 165 | subnet = run_manager.train(cfgs) 166 | 167 | loss, top1, top5 = run_manager.validate(net=subnet, is_test=is_test, no_logs=no_logs) 168 | 169 | info['loss'], info['top1'], info['top5'] = loss, top1, top5 170 | 171 | save_path = os.path.join(log_dir, 'net.stats') if cfgs.save is None else cfgs.save 172 | if cfgs.save_config: 173 | OFAEvaluator.save_net_config(log_dir, subnet, "net.config") 174 | OFAEvaluator.save_net(log_dir, subnet, "net.init") 175 | with open(save_path, 'w') as handle: 176 | json.dump(info, handle) 177 | 178 | print(info) 179 | 180 | 181 | def main(args): 182 | """ one evaluation of a subnet or a config from a file """ 183 | mode = 'subnet' 184 | if args.config is not None: 185 | if args.init is not None: 186 | mode = 'config' 187 | 188 | print('Evaluation mode: {}'.format(mode)) 189 | if mode == 'config': 190 | net_config = json.load(open(args.config)) 191 | subnet = NSGANetV2.build_from_config(net_config, drop_connect_rate=args.drop_connect_rate) 192 | init = torch.load(args.init, map_location='cpu')['state_dict'] 193 | subnet.load_state_dict(init) 194 | subnet.classifier.dropout_rate = args.drop_rate 195 | try: 196 | resolution = net_config['resolution'] 197 | except KeyError: 198 | resolution = args.resolution 199 | 200 | elif mode == 'subnet': 201 | config = json.load(open(args.subnet)) 202 | evaluator = OFAEvaluator(n_classes=args.n_classes, model_path=args.supernet_path) 203 | subnet, _ = evaluator.sample({'ks': config['ks'], 'e': config['e'], 'd': config['d']}) 204 | resolution = config['r'] 205 | 206 | else: 207 | raise NotImplementedError 208 | 209 | OFAEvaluator.eval( 210 | subnet, log_dir=args.log_dir, data_path=args.data, dataset=args.dataset, n_epochs=args.n_epochs, 211 | resolution=resolution, trn_batch_size=args.trn_batch_size, vld_batch_size=args.vld_batch_size, 212 | num_workers=args.num_workers, valid_size=args.valid_size, is_test=args.test, measure_latency=args.latency, 213 | no_logs=(not args.verbose), reset_running_statistics=args.reset_running_statistics) 214 | 215 | 216 | if __name__ == '__main__': 217 | parser = argparse.ArgumentParser() 218 | parser.add_argument('--data', type=str, default='/mnt/datastore/ILSVRC2012', 219 | help='location of the data corpus') 220 | parser.add_argument('--log_dir', type=str, default='.tmp', 221 | help='directory for logging') 222 | parser.add_argument('--dataset', type=str, default='imagenet', 223 | help='name of the dataset (imagenet, cifar10, cifar100, ...)') 224 | parser.add_argument('--n_classes', type=int, default=1000, 225 | help='number of classes for the given dataset') 226 | parser.add_argument('--supernet_path', type=str, default='./data/ofa_mbv3_d234_e346_k357_w1.0', 227 | help='file path to supernet weights') 228 | parser.add_argument('--subnet', type=str, default=None, 229 | help='location of a json file of ks, e, d, and e') 230 | parser.add_argument('--config', type=str, default=None, 231 | help='location of a json file of specific model declaration') 232 | parser.add_argument('--init', type=str, default=None, 233 | help='location of initial weight to load') 234 | parser.add_argument('--trn_batch_size', type=int, default=128, 235 | help='test batch size for inference') 236 | parser.add_argument('--vld_batch_size', type=int, default=256, 237 | help='test batch size for inference') 238 | parser.add_argument('--num_workers', type=int, default=6, 239 | help='number of workers for data loading') 240 | parser.add_argument('--n_epochs', type=int, default=0, 241 | help='number of training epochs') 242 | parser.add_argument('--save', type=str, default=None, 243 | help='location to save the evaluated metrics') 244 | parser.add_argument('--resolution', type=int, default=224, 245 | help='input resolution (192 -> 256)') 246 | parser.add_argument('--valid_size', type=int, default=None, 247 | help='validation set size, randomly sampled from training set') 248 | parser.add_argument('--test', action='store_true', default=False, 249 | help='evaluation performance on testing set') 250 | parser.add_argument('--latency', type=str, default=None, 251 | help='latency measurement settings (gpu64#cpu)') 252 | parser.add_argument('--verbose', action='store_true', default=False, 253 | help='whether to display evaluation progress') 254 | parser.add_argument('--reset_running_statistics', action='store_true', default=False, 255 | help='reset the running mean / std of BN') 256 | parser.add_argument('--drop_rate', type=float, default=0.2, 257 | help='dropout rate') 258 | parser.add_argument('--drop_connect_rate', type=float, default=0.0, 259 | help='connection dropout rate') 260 | parser.add_argument('--save_config', action='store_true', default=False, 261 | help='save config file') 262 | cfgs = parser.parse_args() 263 | 264 | cfgs.teacher_model = None 265 | 266 | main(cfgs) 267 | 268 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NSGANetV2: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search [[slides]](https://www.zhichaolu.com/assets/nsganetv2/0343-long.pdf)[[arXiv]](https://arxiv.org/abs/2007.10396) 2 | ```BibTex 3 | @inproceedings{ 4 | lu2020nsganetv2, 5 | title={{NSGANetV2}: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search}, 6 | author={Zhichao Lu and Kalyanmoy Deb and Erik Goodman and Wolfgang Banzhaf and Vishnu Naresh Boddeti}, 7 | booktitle={European Conference on Computer Vision (ECCV)}, 8 | year={2020} 9 | } 10 | ``` 11 | 12 | ## Overview 13 | 14 | NSGANetV2 is an efficient NAS algorithm for generating task-specific models that are competitive under multiple competing objectives. It comprises of two surrogates, one at the architecture level to improve sample efficiency and one at the weights level, through a supernet, to improve gradient descent training efficiency. 15 | 16 | ## Datasets 17 | Download the datasets from the links embedded in the names. Datasets with * can be automatically downloaded. 18 | 19 | | Dataset | Type | Train Size | Test Size | #Classes | 20 | |:-:|:-:|:-:|:-:|:-:| 21 | | [ImageNet](http://www.image-net.org/) | multi-class | 1,281,167 | 50,000 | 1,000 | 22 | | [CINIC-10](https://github.com/BayesWatch/cinic-10) | | 180,000 | 9,000 | 10 | 23 | | [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* | | 50,000 | 10,000 | 10 | 24 | | [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html)* | | 50,000 | 10,000 | 10 | 25 | | [STL-10](https://ai.stanford.edu/~acoates/stl10/)* | | 5,000 | 8,000 | 10 | 26 | | [FGVC Aircraft](http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)* | fine-grained | 6,667 | 3,333 | 100 | 27 | | [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) | | 3,760 | 1,880 | 47 | 28 | | [Oxford-IIIT Pets](https://www.robots.ox.ac.uk/~vgg/data/pets/) | | 3,680 | 3,369 | 37 | 29 | | [Oxford Flowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) | | 2,040 | 6,149 | 102 | 30 | 31 | ## How to evalute NSGANetV2 models 32 | Download the models (`net.config`) and weights (`net.init`) from [[Google Drive]](https://drive.google.com/drive/folders/1owwmRNYQc8hIKOOFCYFCl2dukF5y1o-d?usp=sharing) or [[Baidu Yun]](https://pan.baidu.com/s/126FysPlOVTtDb6GQgx7eaA)(提取码:4isq). 33 | ```python 34 | """ NSGANetV2 pretrained models 35 | Syntax: python validation.py \ 36 | --dataset [imagenet/cifar10/...] --data /path/to/data \ 37 | --model /path/to/model/config/file --pretrained /path/to/model/weights 38 | """ 39 | ``` 40 | ImageNet | CIFAR-10 | CINIC10 41 | :-------------------------:|:-------------------------:|:-------------------------: 42 | ![](assets/imagenet.png) | ![](assets/c10.png) | ![](assets/cinic10.png) 43 | FLOPs@225: [[Google Drive]](https://drive.google.com/drive/folders/1jNvM_JgWBcnIR8mLDmW_0Q6H_EMlZ_97?usp=sharing)
FLOPs@312: [[Google Drive]](https://drive.google.com/drive/folders/1w2bRfAGCcDWjVzIPwn9Olqa8sbC5UpjZ?usp=sharing)
FLOPs@400: [[Google Drive]](https://drive.google.com/drive/folders/1vwOUzO6iNfyVNi3hgWSx4TANebrCkhri?usp=sharing)
FLOPs@593: [[Google Drive]](https://drive.google.com/drive/folders/1r6jk3ausee2PqHsOqessnmmS_3RZYvhY?usp=sharing) | FLOPs@232: [[Google Drive]](https://drive.google.com/drive/folders/1-bdVtZdzTCTPWYS4B-03qaoQycYs266Z?usp=sharing)
FLOPs@291: [[Google Drive]](https://drive.google.com/drive/folders/1SFuM5zSeqJjSxrNxLEzmixzaU4MvM19t?usp=sharing)
FLOPs@392: [[Google Drive]](https://drive.google.com/drive/folders/1sHrTlAB6fQVnAzOIOLUhKWcOQNusTfot?usp=sharing)
FLOPs@468: [[Google Drive]](https://drive.google.com/drive/folders/1F5Ar8h3DbpV1WmDfhrtPKmpuQYLFVXq-?usp=sharing) | FLOPs@317: [[Google Drive]](https://drive.google.com/drive/folders/1J_vOo1ym9LlATtAQXSehCQt2aJx2Hfm-?usp=sharing)
FLOPs@411: [[Google Drive]](https://drive.google.com/drive/folders/1tzjSn_wF9AwiArV7rusCOylu6IxZnCXq?usp=sharing)
FLOPs@501: [[Google Drive]](https://drive.google.com/drive/folders/1_u7ZCqMr_3HeOQJfbsq-5oW93ZqNpNpS?usp=sharing)
FLOPs@710: [[Google Drive]](https://drive.google.com/drive/folders/1BBnPx-nEPtGNMeixSYkGb2omW-QHO6o_?usp=sharing) 44 | 45 | Flowers102 | Aircraft | Oxford-IIIT Pets 46 | :-------------------------:|:-------------------------:|:-------------------------: 47 | ![](assets/flowers102.png) | ![](assets/aircraft.png) | ![](assets/pets.png) 48 | FLOPs@151: [[Google Drive]](https://drive.google.com/drive/folders/12yaZcl_wEYyUkRbLsrTotYtFjlwjIY7g?usp=sharing)
FLOPs@218: [[Google Drive]](https://drive.google.com/drive/folders/1i2nWhz0rMeRSLjZ-JcWylssdEQCX_EfT?usp=sharing)
FLOPs@249: [[Google Drive]](https://drive.google.com/drive/folders/1maqXm9yRv69tbElGKqHbZVIZXuSQ0uRy?usp=sharing)
FLOPs@317: [[Google Drive]](https://drive.google.com/drive/folders/14HK-zEYKr5sySmKbFRbEaQ66IX88e4uC?usp=sharing) | FLOPs@176: [[Google Drive]](https://drive.google.com/drive/folders/1q8AcI-zU_z3PIMjbqUEgszvAwL1p-n-5?usp=sharing)
FLOPs@271: [[Google Drive]](https://drive.google.com/drive/folders/1VIeOJXaVk6NXrNkGkcShIKP_fEfaZv3Y?usp=sharing)
FLOPs@331: [[Google Drive]](https://drive.google.com/drive/folders/1QQ4Lnpoe2obCnf0r6O5yfF3Rh4llZmTg?usp=sharing)
FLOPs@502: [[Google Drive]](https://drive.google.com/drive/folders/1yM4nmZMSz7EcB3xHX4QxKV7Ifcpk_KRS?usp=sharing) | FLOPs@137: [[Google Drive]](https://drive.google.com/drive/folders/1TWNkSBUQTrpq8IU6eRTHVoVtBtdjThoQ?usp=sharing)
FLOPs@189: [[Google Drive]](https://drive.google.com/drive/folders/1Xu-Dh6rDBS608a2p0ebti6mBs33GUSsc?usp=sharing)
FLOPs@284: [[Google Drive]](https://drive.google.com/drive/folders/1sNV4FquZifPw-VonjWmogS3LojmwPPd1?usp=sharing)
FLOPs@391: [[Google Drive]](https://drive.google.com/drive/folders/1E5PYuI-kb7Gr308kAzIwpo9fzFO4TXVQ?usp=sharing) 49 | 50 | CIFAR-100 | DTD | STL-10 51 | :-------------------------:|:-------------------------:|:-------------------------: 52 | ![](assets/c100.png) | ![](assets/dtd.png) | ![](assets/stl-10.png) 53 | FLOPs@261: [[Google Drive]](https://drive.google.com/drive/folders/1RlnOUhKKaexpwG2pSWiOGSZVGedbf9jd?usp=sharing)
FLOPs@398: [[Google Drive]](https://drive.google.com/drive/folders/1aqbQZvX-Zr6OgNi1-k_x-YzWBnzZyL92?usp=sharing)
FLOPs@492: [[Google Drive]](https://drive.google.com/drive/folders/1aqbQZvX-Zr6OgNi1-k_x-YzWBnzZyL92?usp=sharing)
FLOPs@796: [[Google Drive]](https://drive.google.com/drive/folders/1PJE6rtoJoKChXhw40PHLrT6tjy276ohJ?usp=sharing) | FLOPs@123: [[Google Drive]](https://drive.google.com/drive/folders/1isUPK0pHjVmXqDUqXgR3JOvP3zmq-PgJ?usp=sharing)
FLOPs@164: [[Google Drive]](https://drive.google.com/drive/folders/1pRGHy-bbw9Gz2g2QHau4z7dNFOdKT_w2?usp=sharing)
FLOPs@202: [[Google Drive]](https://drive.google.com/drive/folders/1CsGdcIv79rTDQbXSEJEgTOYZ6ViRwz1S?usp=sharing)
FLOPs@213: [[Google Drive]](https://drive.google.com/drive/folders/1MJVapbO7V5fT5Vjf1G0wjkgCOgkLKWEa?usp=sharing) | FLOPs@240: [[Google Drive]](https://drive.google.com/drive/folders/1tHOiEJhOpplOPnNqaDkHGPGfD5YRZrsJ?usp=sharing)
FLOPs@303: [[Google Drive]](https://drive.google.com/drive/folders/1IfFeg0CkPEFNN3iseOyeR1u-sozgSaya?usp=sharing)
FLOPs@436: [[Google Drive]](https://drive.google.com/drive/folders/13naG888dNouXj8Bd-gL4pMzrskDwdBC8?usp=sharing)
FLOPs@573: [[Google Drive]](https://drive.google.com/drive/folders/1iig97a1Xr-K40xbgwgBvMOPe8aLaqW2a?usp=sharing) 54 | 55 | ## How to use MSuNAS to search 56 | ```python 57 | """ Bi-objective search 58 | Syntax: python msunas.py \ 59 | --dataset [imagenet/cifar10/...] --data /path/to/dataset/images \ 60 | --save search-xxx \ # dir to save search results 61 | --sec_obj [params/flops/cpu] \ # objective (in addition to top-1 acc) 62 | --n_gpus 8 \ # number of available gpus 63 | --supernet_path /path/to/supernet/weights \ 64 | --vld_size [10000/5000/...] \ # number of subset images from training set to guide search 65 | --n_epochs [0/5] 66 | """ 67 | ``` 68 | - Download the pre-trained (on ImageNet) supernet from [here](https://hanlab.mit.edu/files/OnceForAll/ofa_nets/). 69 | - It supports searching for *FLOPs*, *Params*, and *Latency* as the second objective. 70 | - To optimize latency on your own device, you need to first construct a `look-up-table` for your own device, like [this](https://github.com/mikelzc1990/nsganetv2/blob/master/data/i7-8700K_lut.yaml). 71 | - Choose an appropriate `--vld_size` to guide the search, e.g. 10,000 for ImageNet, 5,000 for CIFAR-10/100. 72 | - Set `--n_epochs` to `0` for ImageNet and `5` for all other datasets. 73 | - See [here](https://github.com/mikelzc1990/nsganetv2/blob/master/scripts/search.sh) for some examples. 74 | - Output file structure: 75 | - Every architecture sampled during search has `net_x.subnet` and `net_x.stats` stored in the corresponding iteration dir. 76 | - A stats file is generated by the end of each iteration, `iter_x.stats`; it stores every architectures evaluated so far in `["archive"]`, and iteration-wise statistics, e.g. hypervolume in `["hv"]`, accuracy predictor related in `["surrogate"]`. 77 | - In case any architectures failed to evaluate during search, you may re-visit them in `failed` sub-dir under experiment dir. 78 | 79 | ImageNet | CIFAR-10 80 | :-------------------------:|:-------------------------: 81 | ![](assets/imagenet.gif) | ![](assets/c10.gif) 82 | 83 | ## How to choose architectures 84 | Once the search is completed, you can choose suitable architectures by: 85 | - You have preferences, e.g. architectures with xx.x% top-1 acc. and xxxM FLOPs, etc. 86 | ```python 87 | """ Find architectures with objectives close to your preferences 88 | Syntax: python post_search.py \ 89 | -n 3 \ # number of desired architectures you want, the most accurate archecture will always be selected 90 | --save search-imagenet/final \ # path to the dir to store the selected architectures 91 | --expr search-imagenet/iter_30.stats \ # path to last iteration stats file in experiment dir 92 | --prefer top1#80+flops#150 \ # your preferences, i.e. you want an architecture with 80% top-1 acc. and 150M FLOPs 93 | --supernet_path /path/to/imagenet/supernet/weights \ 94 | """ 95 | ``` 96 | - If you do not have preferences, pass `None` to argument `--prefer`, architectures will then be selected based on trade-offs. 97 | - All selected architectures should have three files created: 98 | - `net.subnet`: use to sample the architecture from the supernet 99 | - `net.config`: configuration file that defines the full architectural components 100 | - `net.inherited`: the inherited weights from supernet 101 | 102 | ## How to validate architectures 103 | To realize the full potential of the searched architectures, we further fine-tune from the inherited weights. Assuming that you have both `net.config` and `net.inherited` files. 104 | ```python 105 | """ Fine-tune on ImageNet from inherited weights 106 | Syntax: sh scripts/distributed_train.sh 8 \ # of available gpus 107 | /path/to/imagenet/data/ \ 108 | --model [nsganetv2_s/nsganetv2_m/...] \ # just for naming the output dir 109 | --model-config /path/to/model/.config/file \ 110 | --initial-checkpoint /path/to/model/.inherited/file \ 111 | --img-size [192, ..., 224, ..., 256] \ # image resolution, check "r" in net.subnet 112 | -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 \ 113 | --opt rmsproptf --opt-eps .001 -j 6 --warmup-lr 1e-6 \ 114 | --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 \ 115 | --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .024 \ 116 | --teacher /path/to/supernet/weights \ 117 | """ 118 | ``` 119 | - Adjust learning rate as `(batch_size_per_gpu * #GPUs / 256) * 0.006` depending on your system config. 120 | ```python 121 | """ Fine-tune on CIFAR-10 from inherited weights 122 | Syntax: python train_cifar.py \ 123 | --data /path/to/CIFAR-10/data/ \ 124 | --model [nsganetv2_s/nsganetv2_m/...] \ # just for naming the output dir 125 | --model-config /path/to/model/.config/file \ 126 | --img-size [192, ..., 224, ..., 256] \ # image resolution, check "r" in net.subnet 127 | --drop 0.2 --drop-path 0.2 \ 128 | --cutout --autoaugment --save 129 | """ 130 | ``` 131 | 132 | ## More Use Cases (coming soon) 133 | - [ ] With a different supernet (search space). 134 | - [ ] NASBench 101/201. 135 | - [ ] Architecture visualization. 136 | 137 | ## Requirements 138 | - Python 3.7 139 | - Cython 0.29 (optional; makes `non_dominated_sorting` faster in pymoo) 140 | - PyTorch 1.5.1 141 | - [pymoo](https://github.com/msu-coinlab/pymoo) 0.4.1 142 | - [torchprofile](https://github.com/zhijian-liu/torchprofile) 0.0.1 (for FLOPs calculation) 143 | - [OnceForAll](https://github.com/mit-han-lab/once-for-all) 0.0.4 (lower level supernet) 144 | - [timm](https://github.com/rwightman/pytorch-image-models) 0.1.30 145 | - [pySOT](https://github.com/dme65/pySOT) 0.2.3 (RBF surrogate model) 146 | - [pydacefit](https://github.com/msu-coinlab/pydacefit) 1.0.1 (GP surrogate model) 147 | -------------------------------------------------------------------------------- /train_cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import copy 5 | import logging 6 | import argparse 7 | import numpy as np 8 | from datetime import datetime 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.utils 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | import torchvision.transforms as transforms 17 | 18 | from codebase.data_providers.autoaugment import CIFAR10Policy 19 | 20 | from evaluator import OFAEvaluator 21 | from torchprofile import profile_macs 22 | from codebase.networks import NSGANetV2 23 | 24 | 25 | torch.backends.cudnn.benchmark = True 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch CIFAR Training') 28 | parser.add_argument('--seed', type=int, default=None, help='random seed') 29 | parser.add_argument('--data', type=str, default='../data', help='location of the data corpus') 30 | parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, cifar100, or cinic10') 31 | parser.add_argument('--batch-size', type=int, default=96, help='batch size') 32 | parser.add_argument('--num_workers', type=int, default=2, help='number of workers for data loading') 33 | parser.add_argument('--n_gpus', type=int, default=1, help='number of available gpus for training') 34 | parser.add_argument('--lr', type=float, default=0.01, help='init learning rate') 35 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 36 | parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay') 37 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency') 38 | parser.add_argument('--epochs', type=int, default=150, help='num of training epochs') 39 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping') 40 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout') 41 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') 42 | parser.add_argument('--autoaugment', action='store_true', default=False, help='use auto augmentation') 43 | parser.add_argument('--save', action='store_true', default=False, help='dump output') 44 | parser.add_argument('--topk', type=int, default=10, help='top k checkpoints to save') 45 | parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate a pretrained model') 46 | # model related 47 | parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', 48 | help='Name of model to train (default: "countception"') 49 | parser.add_argument('--model-config', type=str, default=None, 50 | help='location of a json file of specific model declaration') 51 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 52 | help='Initialize model from this checkpoint (default: none)') 53 | parser.add_argument('--drop', type=float, default=0.2, 54 | help='dropout rate') 55 | parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT', 56 | help='Drop path rate (default: None)') 57 | parser.add_argument('--img-size', type=int, default=224, 58 | help='input resolution (192 -> 256)') 59 | args = parser.parse_args() 60 | 61 | dataset = args.dataset 62 | 63 | log_format = '%(asctime)s %(message)s' 64 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 65 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 66 | 67 | if args.save: 68 | args.save = '-'.join([ 69 | datetime.now().strftime("%Y%m%d-%H%M%S"), 70 | args.dataset, 71 | args.model, 72 | str(args.img_size) 73 | ]) 74 | 75 | if not os.path.exists(args.save): 76 | os.makedirs(args.save, exist_ok=True) 77 | print('Experiment dir : {}'.format(args.save)) 78 | 79 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 80 | fh.setFormatter(logging.Formatter(log_format)) 81 | logging.getLogger().addHandler(fh) 82 | 83 | device = 'cuda' 84 | 85 | NUM_CLASSES = 100 if 'cifar100' in dataset else 10 86 | 87 | 88 | def main(): 89 | if not torch.cuda.is_available(): 90 | logging.info('no gpu device available') 91 | sys.exit(1) 92 | 93 | logging.info("args = %s", args) 94 | 95 | if args.seed is not None: 96 | np.random.seed(args.seed) 97 | torch.manual_seed(args.seed) 98 | 99 | best_acc = 0 # initiate a artificial best accuracy so far 100 | top_checkpoints = [] # initiate a list to keep track of 101 | 102 | # Data 103 | train_transform, valid_transform = _data_transforms(args) 104 | if dataset == 'cifar100': 105 | train_data = torchvision.datasets.CIFAR100( 106 | root=args.data, train=True, download=True, transform=train_transform) 107 | valid_data = torchvision.datasets.CIFAR100( 108 | root=args.data, train=False, download=True, transform=valid_transform) 109 | elif dataset == 'cifar10': 110 | train_data = torchvision.datasets.CIFAR10( 111 | root=args.data, train=True, download=True, transform=train_transform) 112 | valid_data = torchvision.datasets.CIFAR10( 113 | root=args.data, train=False, download=True, transform=valid_transform) 114 | elif dataset == 'cinic10': 115 | train_data = torchvision.datasets.ImageFolder( 116 | args.data + 'train_and_valid', transform=train_transform) 117 | valid_data = torchvision.datasets.ImageFolder( 118 | args.data + 'test', transform=valid_transform) 119 | else: 120 | raise KeyError 121 | 122 | train_queue = torch.utils.data.DataLoader( 123 | train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.num_workers) 124 | 125 | valid_queue = torch.utils.data.DataLoader( 126 | valid_data, batch_size=200, shuffle=False, pin_memory=True, num_workers=args.num_workers) 127 | 128 | net_config = json.load(open(args.model_config)) 129 | net = NSGANetV2.build_from_config(net_config, drop_connect_rate=args.drop_path) 130 | init = torch.load(args.initial_checkpoint, map_location='cpu')['state_dict'] 131 | net.load_state_dict(init) 132 | 133 | NSGANetV2.reset_classifier( 134 | net, last_channel=net.classifier.in_features, 135 | n_classes=NUM_CLASSES, dropout_rate=args.drop) 136 | 137 | # calculate #Paramaters and #FLOPS 138 | inputs = torch.randn(1, 3, args.img_size, args.img_size) 139 | flops = profile_macs(copy.deepcopy(net), inputs) / 1e6 140 | params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6 141 | net_name = "net_flops@{:.0f}".format(flops) 142 | logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops)) 143 | 144 | if args.n_gpus > 1: 145 | net = nn.DataParallel(net) # data parallel in case more than 1 gpu available 146 | 147 | net = net.to(device) 148 | 149 | n_epochs = args.epochs 150 | 151 | parameters = filter(lambda p: p.requires_grad, net.parameters()) 152 | 153 | criterion = nn.CrossEntropyLoss().to(device) 154 | 155 | optimizer = optim.SGD(parameters, 156 | lr=args.lr, 157 | momentum=args.momentum, 158 | weight_decay=args.weight_decay) 159 | 160 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs) 161 | 162 | if args.evaluate: 163 | infer(valid_queue, net, criterion) 164 | sys.exit(0) 165 | 166 | for epoch in range(n_epochs): 167 | 168 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0]) 169 | 170 | train(train_queue, net, criterion, optimizer) 171 | _, valid_acc = infer(valid_queue, net, criterion) 172 | 173 | # checkpoint saving 174 | if args.save: 175 | if len(top_checkpoints) < args.topk: 176 | OFAEvaluator.save_net(args.save, net, net_name+'.ckpt{}'.format(epoch)) 177 | top_checkpoints.append((os.path.join(args.save, net_name+'.ckpt{}'.format(epoch)), valid_acc)) 178 | else: 179 | idx = np.argmin([x[1] for x in top_checkpoints]) 180 | if valid_acc > top_checkpoints[idx][1]: 181 | OFAEvaluator.save_net(args.save, net, net_name + '.ckpt{}'.format(epoch)) 182 | top_checkpoints.append((os.path.join(args.save, net_name+'.ckpt{}'.format(epoch)), valid_acc)) 183 | # remove the idx 184 | os.remove(top_checkpoints[idx][0]) 185 | top_checkpoints.pop(idx) 186 | print(top_checkpoints) 187 | 188 | if valid_acc > best_acc: 189 | OFAEvaluator.save_net(args.save, net, net_name + '.best') 190 | best_acc = valid_acc 191 | 192 | scheduler.step() 193 | 194 | OFAEvaluator.save_net_config(args.save, net, net_name+'.config') 195 | 196 | 197 | # Training 198 | def train(train_queue, net, criterion, optimizer): 199 | net.train() 200 | train_loss = 0 201 | correct = 0 202 | total = 0 203 | 204 | for step, (inputs, targets) in enumerate(train_queue): 205 | # upsample by bicubic to match imagenet training size 206 | inputs = F.interpolate(inputs, size=args.img_size, mode='bicubic', align_corners=False) 207 | inputs, targets = inputs.to(device), targets.to(device) 208 | optimizer.zero_grad() 209 | outputs = net(inputs) 210 | loss = criterion(outputs, targets) 211 | 212 | loss.backward() 213 | nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip) 214 | optimizer.step() 215 | 216 | train_loss += loss.item() 217 | _, predicted = outputs.max(1) 218 | total += targets.size(0) 219 | correct += predicted.eq(targets).sum().item() 220 | 221 | if step % args.report_freq == 0: 222 | logging.info('train %03d %e %f', step, train_loss/total, 100.*correct/total) 223 | 224 | logging.info('train acc %f', 100. * correct / total) 225 | 226 | return train_loss/total, 100.*correct/total 227 | 228 | 229 | def infer(valid_queue, net, criterion): 230 | net.eval() 231 | test_loss = 0 232 | correct = 0 233 | total = 0 234 | 235 | with torch.no_grad(): 236 | for step, (inputs, targets) in enumerate(valid_queue): 237 | inputs, targets = inputs.to(device), targets.to(device) 238 | outputs = net(inputs) 239 | loss = criterion(outputs, targets) 240 | 241 | test_loss += loss.item() 242 | _, predicted = outputs.max(1) 243 | total += targets.size(0) 244 | correct += predicted.eq(targets).sum().item() 245 | 246 | if step % args.report_freq == 0: 247 | logging.info('valid %03d %e %f', step, test_loss/total, 100.*correct/total) 248 | 249 | acc = 100.*correct/total 250 | logging.info('valid acc %f', 100. * correct / total) 251 | 252 | return test_loss/total, acc 253 | 254 | 255 | class Cutout(object): 256 | def __init__(self, length): 257 | self.length = length 258 | 259 | def __call__(self, img): 260 | h, w = img.size(1), img.size(2) 261 | mask = np.ones((h, w), np.float32) 262 | y = np.random.randint(h) 263 | x = np.random.randint(w) 264 | 265 | y1 = np.clip(y - self.length // 2, 0, h) 266 | y2 = np.clip(y + self.length // 2, 0, h) 267 | x1 = np.clip(x - self.length // 2, 0, w) 268 | x2 = np.clip(x + self.length // 2, 0, w) 269 | 270 | mask[y1: y2, x1: x2] = 0. 271 | mask = torch.from_numpy(mask) 272 | mask = mask.expand_as(img) 273 | img *= mask 274 | return img 275 | 276 | 277 | def _data_transforms(args): 278 | 279 | if 'cifar' in args.dataset: 280 | norm_mean = [0.49139968, 0.48215827, 0.44653124] 281 | norm_std = [0.24703233, 0.24348505, 0.26158768] 282 | elif 'cinic' in args.dataset: 283 | norm_mean = [0.47889522, 0.47227842, 0.43047404] 284 | norm_std = [0.24205776, 0.23828046, 0.25874835] 285 | else: 286 | raise KeyError 287 | 288 | train_transform = transforms.Compose([ 289 | transforms.RandomCrop(32, padding=4), 290 | # transforms.Resize(224, interpolation=3), # BICUBIC interpolation 291 | transforms.RandomHorizontalFlip(), 292 | ]) 293 | 294 | if args.autoaugment: 295 | train_transform.transforms.append(CIFAR10Policy()) 296 | 297 | train_transform.transforms.append(transforms.ToTensor()) 298 | 299 | if args.cutout: 300 | train_transform.transforms.append(Cutout(args.cutout_length)) 301 | 302 | train_transform.transforms.append(transforms.Normalize(norm_mean, norm_std)) 303 | 304 | valid_transform = transforms.Compose([ 305 | transforms.Resize(args.img_size, interpolation=3), # BICUBIC interpolation 306 | transforms.ToTensor(), 307 | transforms.Normalize(norm_mean, norm_std), 308 | ]) 309 | return train_transform, valid_transform 310 | 311 | 312 | if __name__ == '__main__': 313 | main() 314 | -------------------------------------------------------------------------------- /codebase/run_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from codebase.data_providers.imagenet import * 2 | from codebase.data_providers.cifar import * 3 | from codebase.data_providers.flowers102 import * 4 | from codebase.data_providers.stl10 import * 5 | from codebase.data_providers.dtd import * 6 | from codebase.data_providers.pets import * 7 | from codebase.data_providers.aircraft import * 8 | 9 | from ofa.imagenet_classification.run_manager.run_config import RunConfig 10 | 11 | 12 | class ImagenetRunConfig(RunConfig): 13 | 14 | def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None, 15 | dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None, 16 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None, 17 | mixup_alpha=None, 18 | model_init='he_fout', validation_frequency=1, print_frequency=10, 19 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, 20 | data_path='/mnt/datastore/ILSVRC2012', 21 | **kwargs): 22 | super(ImagenetRunConfig, self).__init__( 23 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 24 | dataset, train_batch_size, test_batch_size, valid_size, 25 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 26 | mixup_alpha, 27 | model_init, validation_frequency, print_frequency 28 | ) 29 | self.n_worker = n_worker 30 | self.resize_scale = resize_scale 31 | self.distort_color = distort_color 32 | self.image_size = image_size 33 | self.imagenet_data_path = data_path 34 | 35 | @property 36 | def data_provider(self): 37 | if self.__dict__.get('_data_provider', None) is None: 38 | if self.dataset == ImagenetDataProvider.name(): 39 | DataProviderClass = ImagenetDataProvider 40 | else: 41 | raise NotImplementedError 42 | self.__dict__['_data_provider'] = DataProviderClass( 43 | save_path=self.imagenet_data_path, 44 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, 45 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, 46 | distort_color=self.distort_color, image_size=self.image_size, 47 | ) 48 | return self.__dict__['_data_provider'] 49 | 50 | 51 | class CIFARRunConfig(RunConfig): 52 | def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None, 53 | dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None, 54 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None, 55 | mixup_alpha=None, 56 | model_init='he_fout', validation_frequency=1, print_frequency=10, 57 | n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, 58 | data_path='/mnt/datastore/CIFAR', 59 | **kwargs): 60 | super(CIFARRunConfig, self).__init__( 61 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 62 | dataset, train_batch_size, test_batch_size, valid_size, 63 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 64 | mixup_alpha, 65 | model_init, validation_frequency, print_frequency 66 | ) 67 | 68 | self.n_worker = n_worker 69 | self.resize_scale = resize_scale 70 | self.distort_color = distort_color 71 | self.image_size = image_size 72 | self.cifar_data_path = data_path 73 | 74 | @property 75 | def data_provider(self): 76 | if self.__dict__.get('_data_provider', None) is None: 77 | if self.dataset == CIFAR10DataProvider.name(): 78 | DataProviderClass = CIFAR10DataProvider 79 | elif self.dataset == CIFAR100DataProvider.name(): 80 | DataProviderClass = CIFAR100DataProvider 81 | elif self.dataset == CINIC10DataProvider.name(): 82 | DataProviderClass = CINIC10DataProvider 83 | else: 84 | raise NotImplementedError 85 | self.__dict__['_data_provider'] = DataProviderClass( 86 | save_path=self.cifar_data_path, 87 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, 88 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, 89 | distort_color=self.distort_color, image_size=self.image_size, 90 | ) 91 | return self.__dict__['_data_provider'] 92 | 93 | 94 | class Flowers102RunConfig(RunConfig): 95 | 96 | def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None, 97 | dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None, 98 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None, 99 | mixup_alpha=None, 100 | model_init='he_fout', validation_frequency=1, print_frequency=10, 101 | n_worker=4, resize_scale=0.08, distort_color=None, image_size=224, 102 | data_path='/mnt/datastore/Flowers102', 103 | **kwargs): 104 | super(Flowers102RunConfig, self).__init__( 105 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 106 | dataset, train_batch_size, test_batch_size, valid_size, 107 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 108 | mixup_alpha, 109 | model_init, validation_frequency, print_frequency 110 | ) 111 | 112 | self.n_worker = n_worker 113 | self.resize_scale = resize_scale 114 | self.distort_color = distort_color 115 | self.image_size = image_size 116 | self.flowers102_data_path = data_path 117 | 118 | @property 119 | def data_provider(self): 120 | if self.__dict__.get('_data_provider', None) is None: 121 | if self.dataset == Flowers102DataProvider.name(): 122 | DataProviderClass = Flowers102DataProvider 123 | else: 124 | raise NotImplementedError 125 | self.__dict__['_data_provider'] = DataProviderClass( 126 | save_path=self.flowers102_data_path, 127 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, 128 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, 129 | distort_color=self.distort_color, image_size=self.image_size, 130 | ) 131 | return self.__dict__['_data_provider'] 132 | 133 | 134 | class STL10RunConfig(RunConfig): 135 | 136 | def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None, 137 | dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None, 138 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None, 139 | mixup_alpha=None, 140 | model_init='he_fout', validation_frequency=1, print_frequency=10, 141 | n_worker=4, resize_scale=0.08, distort_color=None, image_size=224, 142 | data_path='/mnt/datastore/STL10', 143 | **kwargs): 144 | super(STL10RunConfig, self).__init__( 145 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 146 | dataset, train_batch_size, test_batch_size, valid_size, 147 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 148 | mixup_alpha, 149 | model_init, validation_frequency, print_frequency 150 | ) 151 | 152 | self.n_worker = n_worker 153 | self.resize_scale = resize_scale 154 | self.distort_color = distort_color 155 | self.image_size = image_size 156 | self.stl10_data_path = data_path 157 | 158 | @property 159 | def data_provider(self): 160 | if self.__dict__.get('_data_provider', None) is None: 161 | if self.dataset == STL10DataProvider.name(): 162 | DataProviderClass = STL10DataProvider 163 | else: 164 | raise NotImplementedError 165 | self.__dict__['_data_provider'] = DataProviderClass( 166 | save_path=self.stl10_data_path, 167 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, 168 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, 169 | distort_color=self.distort_color, image_size=self.image_size, 170 | ) 171 | return self.__dict__['_data_provider'] 172 | 173 | 174 | class DTDRunConfig(RunConfig): 175 | 176 | def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 177 | dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None, 178 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None, 179 | mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10, 180 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, 181 | data_path='/mnt/datastore/dtd', 182 | **kwargs): 183 | super(DTDRunConfig, self).__init__( 184 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 185 | dataset, train_batch_size, test_batch_size, valid_size, 186 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 187 | mixup_alpha, 188 | model_init, validation_frequency, print_frequency 189 | ) 190 | self.n_worker = n_worker 191 | self.resize_scale = resize_scale 192 | self.distort_color = distort_color 193 | self.image_size = image_size 194 | self.data_path = data_path 195 | 196 | @property 197 | def data_provider(self): 198 | if self.__dict__.get('_data_provider', None) is None: 199 | if self.dataset == DTDDataProvider.name(): 200 | DataProviderClass = DTDDataProvider 201 | else: 202 | raise NotImplementedError 203 | self.__dict__['_data_provider'] = DataProviderClass( 204 | save_path=self.data_path, 205 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, 206 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, 207 | distort_color=self.distort_color, image_size=self.image_size, 208 | ) 209 | return self.__dict__['_data_provider'] 210 | 211 | 212 | class PetsRunConfig(RunConfig): 213 | 214 | def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 215 | dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None, 216 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None, 217 | mixup_alpha=None, 218 | model_init='he_fout', validation_frequency=1, print_frequency=10, 219 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, 220 | data_path='/mnt/datastore/Oxford-IIITPets', 221 | **kwargs): 222 | super(PetsRunConfig, self).__init__( 223 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 224 | dataset, train_batch_size, test_batch_size, valid_size, 225 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 226 | mixup_alpha, 227 | model_init, validation_frequency, print_frequency 228 | ) 229 | self.n_worker = n_worker 230 | self.resize_scale = resize_scale 231 | self.distort_color = distort_color 232 | self.image_size = image_size 233 | self.imagenet_data_path = data_path 234 | 235 | @property 236 | def data_provider(self): 237 | if self.__dict__.get('_data_provider', None) is None: 238 | if self.dataset == OxfordIIITPetsDataProvider.name(): 239 | DataProviderClass = OxfordIIITPetsDataProvider 240 | else: 241 | raise NotImplementedError 242 | self.__dict__['_data_provider'] = DataProviderClass( 243 | save_path=self.imagenet_data_path, 244 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, 245 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, 246 | distort_color=self.distort_color, image_size=self.image_size, 247 | ) 248 | return self.__dict__['_data_provider'] 249 | 250 | 251 | class AircraftRunConfig(RunConfig): 252 | 253 | def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None, 254 | dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None, 255 | opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None, 256 | mixup_alpha=None, 257 | model_init='he_fout', validation_frequency=1, print_frequency=10, 258 | n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, 259 | data_path='/mnt/datastore/Aircraft', 260 | **kwargs): 261 | super(AircraftRunConfig, self).__init__( 262 | n_epochs, init_lr, lr_schedule_type, lr_schedule_param, 263 | dataset, train_batch_size, test_batch_size, valid_size, 264 | opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys, 265 | mixup_alpha, 266 | model_init, validation_frequency, print_frequency 267 | ) 268 | self.n_worker = n_worker 269 | self.resize_scale = resize_scale 270 | self.distort_color = distort_color 271 | self.image_size = image_size 272 | self.data_path = data_path 273 | 274 | @property 275 | def data_provider(self): 276 | if self.__dict__.get('_data_provider', None) is None: 277 | if self.dataset == FGVCAircraftDataProvider.name(): 278 | DataProviderClass = FGVCAircraftDataProvider 279 | else: 280 | raise NotImplementedError 281 | self.__dict__['_data_provider'] = DataProviderClass( 282 | save_path=self.data_path, 283 | train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size, 284 | valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale, 285 | distort_color=self.distort_color, image_size=self.image_size, 286 | ) 287 | return self.__dict__['_data_provider'] 288 | 289 | 290 | def get_run_config(**kwargs): 291 | if kwargs['dataset'] == 'imagenet': 292 | run_config = ImagenetRunConfig(**kwargs) 293 | elif kwargs['dataset'].startswith('cifar') or kwargs['dataset'].startswith('cinic'): 294 | run_config = CIFARRunConfig(**kwargs) 295 | elif kwargs['dataset'] == 'flowers102': 296 | run_config = Flowers102RunConfig(**kwargs) 297 | elif kwargs['dataset'] == 'stl10': 298 | run_config = STL10RunConfig(**kwargs) 299 | elif kwargs['dataset'] == 'dtd': 300 | run_config = DTDRunConfig(**kwargs) 301 | elif kwargs['dataset'] == 'pets': 302 | run_config = PetsRunConfig(**kwargs) 303 | elif kwargs['dataset'] == 'aircraft': 304 | run_config = AircraftRunConfig(**kwargs) 305 | else: 306 | raise NotImplementedError 307 | 308 | return run_config 309 | 310 | 311 | -------------------------------------------------------------------------------- /msunas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import argparse 5 | import subprocess 6 | import numpy as np 7 | from utils import get_correlation 8 | from evaluator import OFAEvaluator, get_net_info 9 | 10 | from pymoo.optimize import minimize 11 | from pymoo.model.problem import Problem 12 | from pymoo.factory import get_performance_indicator 13 | from pymoo.algorithms.so_genetic_algorithm import GA 14 | from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting 15 | from pymoo.factory import get_algorithm, get_crossover, get_mutation 16 | 17 | from search_space.ofa import OFASearchSpace 18 | from acc_predictor.factory import get_acc_predictor 19 | from utils import prepare_eval_folder, MySampling, BinaryCrossover, MyMutation 20 | 21 | _DEBUG = False 22 | if _DEBUG: from pymoo.visualization.scatter import Scatter 23 | 24 | 25 | class MSuNAS: 26 | def __init__(self, kwargs): 27 | self.search_space = OFASearchSpace() 28 | self.save_path = kwargs.pop('save', '.tmp') # path to save results 29 | self.resume = kwargs.pop('resume', None) # resume search from a checkpoint 30 | self.sec_obj = kwargs.pop('sec_obj', 'flops') # second objective to optimize simultaneously 31 | self.iterations = kwargs.pop('iterations', 30) # number of iterations to run search 32 | self.n_doe = kwargs.pop('n_doe', 100) # number of architectures to train before fit surrogate model 33 | self.n_iter = kwargs.pop('n_iter', 8) # number of architectures to train in each iteration 34 | self.predictor = kwargs.pop('predictor', 'rbf') # which surrogate model to fit 35 | self.n_gpus = kwargs.pop('n_gpus', 1) # number of available gpus 36 | self.gpu = kwargs.pop('gpu', 1) # required number of gpus per evaluation job 37 | self.data = kwargs.pop('data', '../data') # location of the data files 38 | self.dataset = kwargs.pop('dataset', 'imagenet') # which dataset to run search on 39 | self.n_classes = kwargs.pop('n_classes', 1000) # number of classes of the given dataset 40 | self.n_workers = kwargs.pop('n_workers', 6) # number of threads for dataloader 41 | self.vld_size = kwargs.pop('vld_size', 10000) # number of images from train set to validate performance 42 | self.trn_batch_size = kwargs.pop('trn_batch_size', 96) # batch size for SGD training 43 | self.vld_batch_size = kwargs.pop('vld_batch_size', 250) # batch size for validation 44 | self.n_epochs = kwargs.pop('n_epochs', 5) # number of epochs to SGD training 45 | self.test = kwargs.pop('test', False) # evaluate performance on test set 46 | self.supernet_path = kwargs.pop( 47 | 'supernet_path', './data/ofa_mbv3_d234_e346_k357_w1.0') # supernet model path 48 | self.latency = self.sec_obj if "cpu" in self.sec_obj or "gpu" in self.sec_obj else None 49 | 50 | def search(self): 51 | 52 | if self.resume: 53 | archive = self._resume_from_dir() 54 | else: 55 | # the following lines corresponding to Algo 1 line 1-7 in the paper 56 | archive = [] # initialize an empty archive to store all trained CNNs 57 | 58 | # Design Of Experiment 59 | if self.iterations < 1: 60 | arch_doe = self.search_space.sample(self.n_doe) 61 | else: 62 | arch_doe = self.search_space.initialize(self.n_doe) 63 | 64 | # parallel evaluation of arch_doe 65 | top1_err, complexity = self._evaluate(arch_doe, it=0) 66 | 67 | # store evaluated / trained architectures 68 | for member in zip(arch_doe, top1_err, complexity): 69 | archive.append(member) 70 | 71 | # reference point (nadir point) for calculating hypervolume 72 | ref_pt = np.array([np.max([x[1] for x in archive]), np.max([x[2] for x in archive])]) 73 | 74 | # main loop of the search 75 | for it in range(1, self.iterations + 1): 76 | 77 | # construct accuracy predictor surrogate model from archive 78 | # Algo 1 line 9 / Fig. 3(a) in the paper 79 | acc_predictor, a_top1_err_pred = self._fit_acc_predictor(archive) 80 | 81 | # search for the next set of candidates for high-fidelity evaluation (lower level) 82 | # Algo 1 line 10-11 / Fig. 3(b)-(d) in the paper 83 | candidates, c_top1_err_pred = self._next(archive, acc_predictor, self.n_iter) 84 | 85 | # high-fidelity evaluation (lower level) 86 | # Algo 1 line 13-14 / Fig. 3(e) in the paper 87 | c_top1_err, complexity = self._evaluate(candidates, it=it) 88 | 89 | # check for accuracy predictor's performance 90 | rmse, rho, tau = get_correlation( 91 | np.vstack((a_top1_err_pred, c_top1_err_pred)), np.array([x[1] for x in archive] + c_top1_err)) 92 | 93 | # add to archive 94 | # Algo 1 line 15 / Fig. 3(e) in the paper 95 | for member in zip(candidates, c_top1_err, complexity): 96 | archive.append(member) 97 | 98 | # calculate hypervolume 99 | hv = self._calc_hv( 100 | ref_pt, np.column_stack(([x[1] for x in archive], [x[2] for x in archive]))) 101 | 102 | # print iteration-wise statistics 103 | print("Iter {}: hv = {:.2f}".format(it, hv)) 104 | print("fitting {}: RMSE = {:.4f}, Spearman's Rho = {:.4f}, Kendall’s Tau = {:.4f}".format( 105 | self.predictor, rmse, rho, tau)) 106 | 107 | # dump the statistics 108 | with open(os.path.join(self.save_path, "iter_{}.stats".format(it)), "w") as handle: 109 | json.dump({'archive': archive, 'candidates': archive[-self.n_iter:], 'hv': hv, 110 | 'surrogate': { 111 | 'model': self.predictor, 'name': acc_predictor.name, 112 | 'winner': acc_predictor.winner if self.predictor == 'as' else acc_predictor.name, 113 | 'rmse': rmse, 'rho': rho, 'tau': tau}}, handle) 114 | if _DEBUG: 115 | # plot 116 | plot = Scatter(legend={'loc': 'lower right'}) 117 | F = np.full((len(archive), 2), np.nan) 118 | F[:, 0] = np.array([x[2] for x in archive]) # second obj. (complexity) 119 | F[:, 1] = 100 - np.array([x[1] for x in archive]) # top-1 accuracy 120 | plot.add(F, s=15, facecolors='none', edgecolors='b', label='archive') 121 | F = np.full((len(candidates), 2), np.nan) 122 | F[:, 0] = np.array(complexity) 123 | F[:, 1] = 100 - np.array(c_top1_err) 124 | plot.add(F, s=30, color='r', label='candidates evaluated') 125 | F = np.full((len(candidates), 2), np.nan) 126 | F[:, 0] = np.array(complexity) 127 | F[:, 1] = 100 - c_top1_err_pred[:, 0] 128 | plot.add(F, s=20, facecolors='none', edgecolors='g', label='candidates predicted') 129 | plot.save(os.path.join(self.save_path, 'iter_{}.png'.format(it))) 130 | 131 | return 132 | 133 | def _resume_from_dir(self): 134 | """ resume search from a previous iteration """ 135 | import glob 136 | 137 | archive = [] 138 | for file in glob.glob(os.path.join(self.resume, "net_*.subnet")): 139 | arch = json.load(open(file)) 140 | pre, ext = os.path.splitext(file) 141 | stats = json.load(open(pre + ".stats")) 142 | archive.append((arch, 100 - stats['top1'], stats[self.sec_obj])) 143 | 144 | return archive 145 | 146 | def _evaluate(self, archs, it): 147 | gen_dir = os.path.join(self.save_path, "iter_{}".format(it)) 148 | prepare_eval_folder( 149 | gen_dir, archs, self.gpu, self.n_gpus, data=self.data, dataset=self.dataset, 150 | n_classes=self.n_classes, supernet_path=self.supernet_path, 151 | num_workers=self.n_workers, valid_size=self.vld_size, 152 | trn_batch_size=self.trn_batch_size, vld_batch_size=self.vld_batch_size, 153 | n_epochs=self.n_epochs, test=self.test, latency=self.latency, verbose=False) 154 | 155 | subprocess.call("sh {}/run_bash.sh".format(gen_dir), shell=True) 156 | 157 | top1_err, complexity = [], [] 158 | 159 | for i in range(len(archs)): 160 | try: 161 | stats = json.load(open(os.path.join(gen_dir, "net_{}.stats".format(i)))) 162 | except FileNotFoundError: 163 | # just in case the subprocess evaluation failed 164 | stats = {'top1': 0, self.sec_obj: 1000} # makes the solution artificially bad so it won't survive 165 | # store this architecture to a separate in case we want to revisit after the search 166 | os.makedirs(os.path.join(self.save_path, "failed"), exist_ok=True) 167 | shutil.copy(os.path.join(gen_dir, "net_{}.subnet".format(i)), 168 | os.path.join(self.save_path, "failed", "it_{}_net_{}".format(it, i))) 169 | 170 | top1_err.append(100 - stats['top1']) 171 | complexity.append(stats[self.sec_obj]) 172 | 173 | return top1_err, complexity 174 | 175 | def _fit_acc_predictor(self, archive): 176 | inputs = np.array([self.search_space.encode(x[0]) for x in archive]) 177 | targets = np.array([x[1] for x in archive]) 178 | assert len(inputs) > len(inputs[0]), "# of training samples have to be > # of dimensions" 179 | 180 | acc_predictor = get_acc_predictor(self.predictor, inputs, targets) 181 | 182 | return acc_predictor, acc_predictor.predict(inputs) 183 | 184 | def _next(self, archive, predictor, K): 185 | """ searching for next K candidate for high-fidelity evaluation (lower level) """ 186 | 187 | # the following lines corresponding to Algo 1 line 10 / Fig. 3(b) in the paper 188 | # get non-dominated architectures from archive 189 | F = np.column_stack(([x[1] for x in archive], [x[2] for x in archive])) 190 | front = NonDominatedSorting().do(F, only_non_dominated_front=True) 191 | # non-dominated arch bit-strings 192 | nd_X = np.array([self.search_space.encode(x[0]) for x in archive])[front] 193 | 194 | # initialize the candidate finding optimization problem 195 | problem = AuxiliarySingleLevelProblem( 196 | self.search_space, predictor, self.sec_obj, 197 | {'n_classes': self.n_classes, 'model_path': self.supernet_path}) 198 | 199 | # initiate a multi-objective solver to optimize the problem 200 | method = get_algorithm( 201 | "nsga2", pop_size=40, sampling=nd_X, # initialize with current nd archs 202 | crossover=get_crossover("int_two_point", prob=0.9), 203 | mutation=get_mutation("int_pm", eta=1.0), 204 | eliminate_duplicates=True) 205 | 206 | # kick-off the search 207 | res = minimize( 208 | problem, method, termination=('n_gen', 20), save_history=True, verbose=True) 209 | 210 | # check for duplicates 211 | not_duplicate = np.logical_not( 212 | [x in [x[0] for x in archive] for x in [self.search_space.decode(x) for x in res.pop.get("X")]]) 213 | 214 | # the following lines corresponding to Algo 1 line 11 / Fig. 3(c)-(d) in the paper 215 | # form a subset selection problem to short list K from pop_size 216 | indices = self._subset_selection(res.pop[not_duplicate], F[front, 1], K) 217 | pop = res.pop[not_duplicate][indices] 218 | 219 | candidates = [] 220 | for x in pop.get("X"): 221 | candidates.append(self.search_space.decode(x)) 222 | 223 | # decode integer bit-string to config and also return predicted top1_err 224 | return candidates, predictor.predict(pop.get("X")) 225 | 226 | @staticmethod 227 | def _subset_selection(pop, nd_F, K): 228 | problem = SubsetProblem(pop.get("F")[:, 1], nd_F, K) 229 | algorithm = GA( 230 | pop_size=100, sampling=MySampling(), crossover=BinaryCrossover(), 231 | mutation=MyMutation(), eliminate_duplicates=True) 232 | 233 | res = minimize( 234 | problem, algorithm, ('n_gen', 60), verbose=False) 235 | 236 | return res.X 237 | 238 | @staticmethod 239 | def _calc_hv(ref_pt, F, normalized=True): 240 | # calculate hypervolume on the non-dominated set of F 241 | front = NonDominatedSorting().do(F, only_non_dominated_front=True) 242 | nd_F = F[front, :] 243 | ref_point = 1.01 * ref_pt 244 | hv = get_performance_indicator("hv", ref_point=ref_point).calc(nd_F) 245 | if normalized: 246 | hv = hv / np.prod(ref_point) 247 | return hv 248 | 249 | 250 | class AuxiliarySingleLevelProblem(Problem): 251 | """ The optimization problem for finding the next N candidate architectures """ 252 | 253 | def __init__(self, search_space, predictor, sec_obj='flops', supernet=None): 254 | super().__init__(n_var=46, n_obj=2, n_constr=0, type_var=np.int) 255 | 256 | self.ss = search_space 257 | self.predictor = predictor 258 | self.xl = np.zeros(self.n_var) 259 | self.xu = 2 * np.ones(self.n_var) 260 | self.xu[-1] = int(len(self.ss.resolution) - 1) 261 | self.sec_obj = sec_obj 262 | self.lut = {'cpu': 'data/i7-8700K_lut.yaml'} 263 | 264 | # supernet engine for measuring complexity 265 | self.engine = OFAEvaluator( 266 | n_classes=supernet['n_classes'], model_path=supernet['model_path']) 267 | 268 | def _evaluate(self, x, out, *args, **kwargs): 269 | f = np.full((x.shape[0], self.n_obj), np.nan) 270 | 271 | top1_err = self.predictor.predict(x)[:, 0] # predicted top1 error 272 | 273 | for i, (_x, err) in enumerate(zip(x, top1_err)): 274 | config = self.ss.decode(_x) 275 | subnet, _ = self.engine.sample({'ks': config['ks'], 'e': config['e'], 'd': config['d']}) 276 | info = get_net_info(subnet, (3, config['r'], config['r']), 277 | measure_latency=self.sec_obj, print_info=False, clean=True, lut=self.lut) 278 | f[i, 0] = err 279 | f[i, 1] = info[self.sec_obj] 280 | 281 | out["F"] = f 282 | 283 | 284 | class SubsetProblem(Problem): 285 | """ select a subset to diversify the pareto front """ 286 | def __init__(self, candidates, archive, K): 287 | super().__init__(n_var=len(candidates), n_obj=1, 288 | n_constr=1, xl=0, xu=1, type_var=np.bool) 289 | self.archive = archive 290 | self.candidates = candidates 291 | self.n_max = K 292 | 293 | def _evaluate(self, x, out, *args, **kwargs): 294 | f = np.full((x.shape[0], 1), np.nan) 295 | g = np.full((x.shape[0], 1), np.nan) 296 | 297 | for i, _x in enumerate(x): 298 | # append selected candidates to archive then sort 299 | tmp = np.sort(np.concatenate((self.archive, self.candidates[_x]))) 300 | f[i, 0] = np.std(np.diff(tmp)) 301 | # we penalize if the number of selected candidates is not exactly K 302 | g[i, 0] = (self.n_max - np.sum(_x)) ** 2 303 | 304 | out["F"] = f 305 | out["G"] = g 306 | 307 | 308 | def main(args): 309 | engine = MSuNAS(vars(args)) 310 | engine.search() 311 | return 312 | 313 | 314 | if __name__ == '__main__': 315 | parser = argparse.ArgumentParser() 316 | parser.add_argument('--save', type=str, default='.tmp', 317 | help='location of dir to save') 318 | parser.add_argument('--resume', type=str, default=None, 319 | help='resume search from a checkpoint') 320 | parser.add_argument('--sec_obj', type=str, default='flops', 321 | help='second objective to optimize simultaneously') 322 | parser.add_argument('--iterations', type=int, default=30, 323 | help='number of search iterations') 324 | parser.add_argument('--n_doe', type=int, default=100, 325 | help='initial sample size for DOE') 326 | parser.add_argument('--n_iter', type=int, default=8, 327 | help='number of architectures to high-fidelity eval (low level) in each iteration') 328 | parser.add_argument('--predictor', type=str, default='rbf', 329 | help='which accuracy predictor model to fit (rbf/gp/cart/mlp/as)') 330 | parser.add_argument('--n_gpus', type=int, default=8, 331 | help='total number of available gpus') 332 | parser.add_argument('--gpu', type=int, default=1, 333 | help='number of gpus per evaluation job') 334 | parser.add_argument('--data', type=str, default='/mnt/datastore/ILSVRC2012', 335 | help='location of the data corpus') 336 | parser.add_argument('--dataset', type=str, default='imagenet', 337 | help='name of the dataset (imagenet, cifar10, cifar100, ...)') 338 | parser.add_argument('--n_classes', type=int, default=1000, 339 | help='number of classes of the given dataset') 340 | parser.add_argument('--supernet_path', type=str, default='./data/ofa_mbv3_d234_e346_k357_w1.0', 341 | help='file path to supernet weights') 342 | parser.add_argument('--n_workers', type=int, default=4, 343 | help='number of workers for dataloader per evaluation job') 344 | parser.add_argument('--vld_size', type=int, default=None, 345 | help='validation set size, randomly sampled from training set') 346 | parser.add_argument('--trn_batch_size', type=int, default=128, 347 | help='train batch size for training') 348 | parser.add_argument('--vld_batch_size', type=int, default=200, 349 | help='test batch size for inference') 350 | parser.add_argument('--n_epochs', type=int, default=5, 351 | help='number of epochs for CNN training') 352 | parser.add_argument('--test', action='store_true', default=False, 353 | help='evaluation performance on testing set') 354 | cfgs = parser.parse_args() 355 | main(cfgs) 356 | 357 | -------------------------------------------------------------------------------- /codebase/data_providers/aircraft.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import math 5 | import warnings 6 | import numpy as np 7 | 8 | from timm.data.transforms import str_to_pil_interp 9 | from timm.data.auto_augment import rand_augment_transform 10 | 11 | import torch.utils.data 12 | import torchvision.transforms as transforms 13 | from torchvision.datasets.folder import default_loader 14 | 15 | from ofa.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler 16 | from ofa.imagenet_classification.data_providers.base_provider import DataProvider 17 | 18 | 19 | def make_dataset(dir, image_ids, targets): 20 | assert(len(image_ids) == len(targets)) 21 | images = [] 22 | dir = os.path.expanduser(dir) 23 | for i in range(len(image_ids)): 24 | item = (os.path.join(dir, 'data', 'images', 25 | '%s.jpg' % image_ids[i]), targets[i]) 26 | images.append(item) 27 | return images 28 | 29 | 30 | def find_classes(classes_file): 31 | # read classes file, separating out image IDs and class names 32 | image_ids = [] 33 | targets = [] 34 | f = open(classes_file, 'r') 35 | for line in f: 36 | split_line = line.split(' ') 37 | image_ids.append(split_line[0]) 38 | targets.append(' '.join(split_line[1:])) 39 | f.close() 40 | 41 | # index class names 42 | classes = np.unique(targets) 43 | class_to_idx = {classes[i]: i for i in range(len(classes))} 44 | targets = [class_to_idx[c] for c in targets] 45 | 46 | return (image_ids, targets, classes, class_to_idx) 47 | 48 | 49 | class FGVCAircraft(torch.utils.data.Dataset): 50 | """`FGVC-Aircraft `_ Dataset. 51 | Args: 52 | root (string): Root directory path to dataset. 53 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification 54 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). 55 | transform (callable, optional): A function/transform that takes in a PIL image 56 | and returns a transformed version. E.g. ``transforms.RandomCrop`` 57 | target_transform (callable, optional): A function/transform that takes in the 58 | target and transforms it. 59 | loader (callable, optional): A function to load an image given its path. 60 | download (bool, optional): If true, downloads the dataset from the internet and 61 | puts it in the root directory. If dataset is already downloaded, it is not 62 | downloaded again. 63 | """ 64 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' 65 | class_types = ('variant', 'family', 'manufacturer') 66 | splits = ('train', 'val', 'trainval', 'test') 67 | 68 | def __init__(self, root, class_type='variant', split='train', transform=None, 69 | target_transform=None, loader=default_loader, download=False): 70 | if split not in self.splits: 71 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 72 | split, ', '.join(self.splits), 73 | )) 74 | if class_type not in self.class_types: 75 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( 76 | class_type, ', '.join(self.class_types), 77 | )) 78 | self.root = os.path.expanduser(root) 79 | self.class_type = class_type 80 | self.split = split 81 | self.classes_file = os.path.join(self.root, 'data', 82 | 'images_%s_%s.txt' % (self.class_type, self.split)) 83 | 84 | if download: 85 | self.download() 86 | 87 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) 88 | samples = make_dataset(self.root, image_ids, targets) 89 | 90 | self.transform = transform 91 | self.target_transform = target_transform 92 | self.loader = loader 93 | 94 | self.samples = samples 95 | self.classes = classes 96 | self.class_to_idx = class_to_idx 97 | 98 | def __getitem__(self, index): 99 | """ 100 | Args: 101 | index (int): Index 102 | Returns: 103 | tuple: (sample, target) where target is class_index of the target class. 104 | """ 105 | 106 | path, target = self.samples[index] 107 | sample = self.loader(path) 108 | if self.transform is not None: 109 | sample = self.transform(sample) 110 | if self.target_transform is not None: 111 | target = self.target_transform(target) 112 | 113 | return sample, target 114 | 115 | def __len__(self): 116 | return len(self.samples) 117 | 118 | def __repr__(self): 119 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 120 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 121 | fmt_str += ' Root Location: {}\n'.format(self.root) 122 | tmp = ' Transforms (if any): ' 123 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 124 | tmp = ' Target Transforms (if any): ' 125 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 126 | return fmt_str 127 | 128 | def _check_exists(self): 129 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ 130 | os.path.exists(self.classes_file) 131 | 132 | def download(self): 133 | """Download the FGVC-Aircraft data if it doesn't exist already.""" 134 | from six.moves import urllib 135 | import tarfile 136 | 137 | if self._check_exists(): 138 | return 139 | 140 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz 141 | print('Downloading %s ... (may take a few minutes)' % self.url) 142 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir)) 143 | tar_name = self.url.rpartition('/')[-1] 144 | tar_path = os.path.join(parent_dir, tar_name) 145 | data = urllib.request.urlopen(self.url) 146 | 147 | # download .tar.gz file 148 | with open(tar_path, 'wb') as f: 149 | f.write(data.read()) 150 | 151 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b 152 | data_folder = tar_path.strip('.tar.gz') 153 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder)) 154 | tar = tarfile.open(tar_path) 155 | tar.extractall(parent_dir) 156 | 157 | # if necessary, rename data folder to self.root 158 | if not os.path.samefile(data_folder, self.root): 159 | print('Renaming %s to %s ...' % (data_folder, self.root)) 160 | os.rename(data_folder, self.root) 161 | 162 | # delete .tar.gz file 163 | print('Deleting %s ...' % tar_path) 164 | os.remove(tar_path) 165 | 166 | print('Done!') 167 | 168 | 169 | class FGVCAircraftDataProvider(DataProvider): 170 | 171 | def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32, 172 | resize_scale=0.08, distort_color=None, image_size=224, 173 | num_replicas=None, rank=None): 174 | 175 | warnings.filterwarnings('ignore') 176 | self._save_path = save_path 177 | 178 | self.image_size = image_size # int or list of int 179 | self.distort_color = distort_color 180 | self.resize_scale = resize_scale 181 | 182 | self._valid_transform_dict = {} 183 | if not isinstance(self.image_size, int): 184 | assert isinstance(self.image_size, list) 185 | from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader 186 | self.image_size.sort() # e.g., 160 -> 224 187 | MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy() 188 | MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size) 189 | 190 | for img_size in self.image_size: 191 | self._valid_transform_dict[img_size] = self.build_valid_transform(img_size) 192 | self.active_img_size = max(self.image_size) 193 | valid_transforms = self._valid_transform_dict[self.active_img_size] 194 | train_loader_class = MyDataLoader # randomly sample image size for each batch of training image 195 | else: 196 | self.active_img_size = self.image_size 197 | valid_transforms = self.build_valid_transform() 198 | train_loader_class = torch.utils.data.DataLoader 199 | 200 | train_transforms = self.build_train_transform() 201 | train_dataset = self.train_dataset(train_transforms) 202 | 203 | if valid_size is not None: 204 | if not isinstance(valid_size, int): 205 | assert isinstance(valid_size, float) and 0 < valid_size < 1 206 | valid_size = int(len(train_dataset.samples) * valid_size) 207 | 208 | valid_dataset = self.train_dataset(valid_transforms) 209 | train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size) 210 | 211 | if num_replicas is not None: 212 | train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes)) 213 | valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes)) 214 | else: 215 | train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes) 216 | valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes) 217 | 218 | self.train = train_loader_class( 219 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 220 | num_workers=n_worker, pin_memory=True, 221 | ) 222 | self.valid = torch.utils.data.DataLoader( 223 | valid_dataset, batch_size=test_batch_size, sampler=valid_sampler, 224 | num_workers=n_worker, pin_memory=True, 225 | ) 226 | else: 227 | if num_replicas is not None: 228 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank) 229 | self.train = train_loader_class( 230 | train_dataset, batch_size=train_batch_size, sampler=train_sampler, 231 | num_workers=n_worker, pin_memory=True 232 | ) 233 | else: 234 | self.train = train_loader_class( 235 | train_dataset, batch_size=train_batch_size, shuffle=True, 236 | num_workers=n_worker, pin_memory=True, 237 | ) 238 | self.valid = None 239 | 240 | test_dataset = self.test_dataset(valid_transforms) 241 | if num_replicas is not None: 242 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank) 243 | self.test = torch.utils.data.DataLoader( 244 | test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True, 245 | ) 246 | else: 247 | self.test = torch.utils.data.DataLoader( 248 | test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True, 249 | ) 250 | 251 | if self.valid is None: 252 | self.valid = self.test 253 | 254 | @staticmethod 255 | def name(): 256 | return 'aircraft' 257 | 258 | @property 259 | def data_shape(self): 260 | return 3, self.active_img_size, self.active_img_size # C, H, W 261 | 262 | @property 263 | def n_classes(self): 264 | return 100 265 | 266 | @property 267 | def save_path(self): 268 | if self._save_path is None: 269 | self._save_path = '/mnt/datastore/Aircraft' # home server 270 | 271 | if not os.path.exists(self._save_path): 272 | self._save_path = '/mnt/datastore/Aircraft' # home server 273 | return self._save_path 274 | 275 | @property 276 | def data_url(self): 277 | raise ValueError('unable to download %s' % self.name()) 278 | 279 | def train_dataset(self, _transforms): 280 | # dataset = datasets.ImageFolder(self.train_path, _transforms) 281 | dataset = FGVCAircraft( 282 | root=self.train_path, split='trainval', download=True, transform=_transforms) 283 | return dataset 284 | 285 | def test_dataset(self, _transforms): 286 | # dataset = datasets.ImageFolder(self.valid_path, _transforms) 287 | dataset = FGVCAircraft( 288 | root=self.valid_path, split='test', download=True, transform=_transforms) 289 | return dataset 290 | 291 | @property 292 | def train_path(self): 293 | return self.save_path 294 | 295 | @property 296 | def valid_path(self): 297 | return self.save_path 298 | 299 | @property 300 | def normalize(self): 301 | return transforms.Normalize( 302 | mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883], 303 | std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115]) 304 | 305 | def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'): 306 | if image_size is None: 307 | image_size = self.image_size 308 | # if print_log: 309 | # print('Color jitter: %s, resize_scale: %s, img_size: %s' % 310 | # (self.distort_color, self.resize_scale, image_size)) 311 | 312 | # if self.distort_color == 'torch': 313 | # color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1) 314 | # elif self.distort_color == 'tf': 315 | # color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5) 316 | # else: 317 | # color_transform = None 318 | 319 | if isinstance(image_size, list): 320 | resize_transform_class = MyRandomResizedCrop 321 | print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(), 322 | 'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS)) 323 | img_size_min = min(image_size) 324 | else: 325 | resize_transform_class = transforms.RandomResizedCrop 326 | img_size_min = image_size 327 | 328 | train_transforms = [ 329 | resize_transform_class(image_size, scale=(self.resize_scale, 1.0)), 330 | transforms.RandomHorizontalFlip(), 331 | ] 332 | 333 | aa_params = dict( 334 | translate_const=int(img_size_min * 0.45), 335 | img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618, 336 | 0.5387914411673883]]), 337 | ) 338 | aa_params['interpolation'] = str_to_pil_interp('bicubic') 339 | train_transforms += [rand_augment_transform(auto_augment, aa_params)] 340 | 341 | # if color_transform is not None: 342 | # train_transforms.append(color_transform) 343 | train_transforms += [ 344 | transforms.ToTensor(), 345 | self.normalize, 346 | ] 347 | 348 | train_transforms = transforms.Compose(train_transforms) 349 | return train_transforms 350 | 351 | def build_valid_transform(self, image_size=None): 352 | if image_size is None: 353 | image_size = self.active_img_size 354 | return transforms.Compose([ 355 | transforms.Resize(int(math.ceil(image_size / 0.875))), 356 | transforms.CenterCrop(image_size), 357 | transforms.ToTensor(), 358 | self.normalize, 359 | ]) 360 | 361 | def assign_active_img_size(self, new_img_size): 362 | self.active_img_size = new_img_size 363 | if self.active_img_size not in self._valid_transform_dict: 364 | self._valid_transform_dict[self.active_img_size] = self.build_valid_transform() 365 | # change the transform of the valid and test set 366 | self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size] 367 | self.test.dataset.transform = self._valid_transform_dict[self.active_img_size] 368 | 369 | def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None): 370 | # used for resetting running statistics 371 | if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None: 372 | if num_worker is None: 373 | num_worker = self.train.num_workers 374 | 375 | n_samples = len(self.train.dataset.samples) 376 | g = torch.Generator() 377 | g.manual_seed(DataProvider.SUB_SEED) 378 | rand_indexes = torch.randperm(n_samples, generator=g).tolist() 379 | 380 | new_train_dataset = self.train_dataset( 381 | self.build_train_transform(image_size=self.active_img_size, print_log=False)) 382 | chosen_indexes = rand_indexes[:n_images] 383 | if num_replicas is not None: 384 | sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes)) 385 | else: 386 | sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes) 387 | sub_data_loader = torch.utils.data.DataLoader( 388 | new_train_dataset, batch_size=batch_size, sampler=sub_sampler, 389 | num_workers=num_worker, pin_memory=True, 390 | ) 391 | self.__dict__['sub_train_%d' % self.active_img_size] = [] 392 | for images, labels in sub_data_loader: 393 | self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels)) 394 | return self.__dict__['sub_train_%d' % self.active_img_size] 395 | 396 | 397 | if __name__ == '__main__': 398 | data = FGVCAircraft(root='/mnt/datastore/Aircraft', 399 | split='trainval', download=True) 400 | print(len(data.classes)) 401 | print(len(data.samples)) --------------------------------------------------------------------------------