├── .editorconfig ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE ├── README.md ├── benchmarks ├── evaluation │ ├── objective.py │ ├── train.py │ └── utils.py ├── objectives │ ├── NOTICE │ ├── addNIST.py │ ├── cifarTile.py │ ├── cifar_activation.py │ ├── custom_nb201 │ │ ├── DownsampledImageNet.py │ │ ├── cell_operations.py │ │ ├── config_utils.py │ │ ├── configs │ │ │ ├── CIFAR.config │ │ │ ├── ImageNet-16.config │ │ │ ├── ImageNet16-120-split.txt │ │ │ ├── LESS.config │ │ │ ├── cifar-split.txt │ │ │ ├── cifar100-test-split.txt │ │ │ └── imagenet-16-120-test-split.txt │ │ ├── custom_augmentations.py │ │ ├── evaluate_utils.py │ │ ├── genotypes.py │ │ ├── infer_cell.py │ │ └── tiny_network.py │ ├── darts_cnn.py │ ├── darts_utils │ │ ├── train.py │ │ └── train_search.py │ ├── hierarchical_nb201.py │ └── utils.py ├── search_spaces │ ├── activation_function_search │ │ ├── cifar_models │ │ │ ├── __init__.py │ │ │ └── resnet.py │ │ ├── grammar.cfg │ │ ├── graph.py │ │ ├── kvary_operations.py │ │ ├── stacking.py │ │ ├── topologies.py │ │ └── unary_operations.py │ ├── darts_cnn │ │ ├── cell.cfg │ │ ├── genotypes.py │ │ ├── graph.py │ │ ├── model.py │ │ ├── net2wider.py │ │ ├── operations.py │ │ ├── primitives.py │ │ ├── topologies.py │ │ ├── utils.py │ │ └── visualize.py │ └── hierarchical_nb201 │ │ ├── grammars │ │ ├── cell.cfg │ │ ├── cell_flexible.cfg │ │ ├── conv_block.cfg │ │ ├── macro.cfg │ │ └── macro_fixed_repetitive.cfg │ │ ├── graph.py │ │ ├── primitives.py │ │ └── topologies.py └── utils │ ├── objective.py │ ├── torch_error_message.py │ └── utils.py ├── experiments ├── darts_evaluate.py ├── darts_train_search.py ├── darts_utils │ ├── architect.py │ ├── cell_operations.py │ ├── genotypes.py │ ├── net2wider.py │ ├── search_cells.py │ ├── search_model.py │ ├── search_model_gdas.py │ └── utils.py ├── optimize.py ├── optimize_naswot.py ├── surrogate_regression.py ├── utils │ └── dataset_generation │ │ ├── NOTICE │ │ ├── data_packager.py │ │ 
├── gen_cifartile.py │ │ ├── gen_gutenberg.py │ │ ├── gen_language_data.py │ │ ├── gen_multnist_data.py │ │ └── visualize_examples.py ├── utils_modelsummary.py ├── zero_cost_proxies_utils │ ├── NOTICE │ ├── __init__.py │ ├── epe_nas.py │ ├── fisher.py │ ├── grad_norm.py │ ├── grasp.py │ ├── jacov.py │ ├── l2_norm.py │ ├── model_stats.py │ ├── nwot.py │ ├── p_utils.py │ ├── plain.py │ ├── snip.py │ ├── synflow.py │ └── zen.py └── zero_cost_rank_correlation.py ├── install_dev_utils └── poetry.sh ├── poetry.lock └── pyproject.toml /.editorconfig: -------------------------------------------------------------------------------- 1 | # https://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | charset = utf-8 7 | end_of_line = lf 8 | insert_final_newline = true 9 | trim_trailing_whitespace = true 10 | 11 | [*.py] 12 | max_line_length = 90 13 | indent_style = space 14 | indent_size = 4 15 | 16 | [*.tex] 17 | max_line_length = 120 18 | indent_style = space 19 | indent_size = 4 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | 4 | # slurm scripts 5 | slurm_scripts/* 6 | 7 | 8 | # IDE related 9 | .vscode/ 10 | .idea/ 11 | 12 | # Latex 13 | *.aux 14 | *.fdb_latexmk 15 | *.fls 16 | *.log 17 | *.synctex.gz 18 | *.dpth 19 | *.md5 20 | *.dep 21 | *.auxlock 22 | *.snm 23 | *.out 24 | *.nav 25 | *.toc 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: check-merge-conflict 7 | - id: end-of-file-fixer 8 | - id: debug-statements 9 | - id: check-yaml 10 | - id: check-toml 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 21.12b0 14 | hooks: 15 
| - id: black 16 | 17 | - repo: https://github.com/PyCQA/isort 18 | rev: 5.10.1 19 | hooks: 20 | - id: isort 21 | 22 | - repo: https://github.com/asottile/pyupgrade 23 | rev: v2.29.1 24 | hooks: 25 | - id: pyupgrade 26 | args: [--py37-plus] 27 | 28 | - repo: https://github.com/executablebooks/mdformat 29 | rev: 0.7.11 30 | hooks: 31 | - id: mdformat 32 | additional_dependencies: 33 | - mdformat-gfm 34 | - mdformat-tables 35 | - mdformat-beautysh 36 | - mdformat-black 37 | 38 | - repo: https://github.com/PyCQA/pylint 39 | rev: v2.12.2 40 | hooks: 41 | - id: pylint 42 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you find our approach interesting for your own work, please cite the corresponding paper." 3 | authors: 4 | - family-names: Schrodi 5 | given-names: Simon 6 | - family-names: Stoll 7 | given-names: Danny 8 | - family-names: Ru 9 | given-names: Binxin 10 | - family-names: Sukthanker 11 | given-names: Rhea 12 | - family-names: Brox 13 | given-names: Thomas 14 | - family-names: Hutter 15 | given-names: Frank 16 | title: "Towards Discovering Neural Architectures from Scratch" 17 | version: 0.1.0 18 | date-released: 2022-11-03 19 | url: "https://github.com/automl/towards_nas_from_scratch" 20 | preferred-citation: 21 | type: misc 22 | doi: 10.48550/ARXIV.2211.01842 23 | url: "https://arxiv.org/abs/2211.01842" 24 | authors: 25 | - family-names: Schrodi 26 | given-names: Simon 27 | - family-names: Stoll 28 | given-names: Danny 29 | - family-names: Ru 30 | given-names: Binxin 31 | - family-names: Sukthanker 32 | given-names: Rhea 33 | - family-names: Brox 34 | given-names: Thomas 35 | - family-names: Hutter 36 | given-names: Frank 37 | keywords: "Machine Learning (cs.LG), Artificial Intelligence (cs.AI), Computer Vision and Pattern Recognition (cs.CV), Machine Learning (stat.ML), FOS: Computer and 
information sciences, FOS: Computer and information sciences" 38 | title: "Towards Discovering Neural Architectures from Scratch" 39 | publisher: arXiv 40 | year: 2022 41 | copyright: "arXiv.org perpetual, non-exclusive license" 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Simon Schrodi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
class Objective:
    """Base class for benchmark objectives.

    Stores the seed and the two flags that control how raw objective values are
    mapped to the value handed to the optimizer (see ``transform``).

    Args:
        seed: Random seed stored on the instance.
        log_scale: If true, ``transform`` applies ``log(val + 1e-8)``.
        negative: If true, ``transform`` flips the sign of the value.
    """

    def __init__(self, seed: int, log_scale: bool, negative: bool = False) -> None:
        self.seed = seed
        self.log_scale = log_scale
        self.negative = negative

    @abstractmethod
    def __call__(self, config, **kwargs):
        """Evaluate ``config``; subclasses must override this."""
        raise NotImplementedError

    def set_seed(self, seed: int):
        """Replace the stored seed."""
        self.seed = seed

    def transform(self, val):
        """Map a raw objective value to the optimizer's scale.

        Applies the optional log scaling first and the optional sign flip second.
        The input is never modified in place, so NumPy arrays passed by the
        caller keep their contents.
        """
        if self.log_scale:
            val = np.log(val + 1e-8)  # avoid log(0)
        if self.negative:
            # Rebind instead of ``val *= -1``: the augmented form would write
            # through to a caller-owned ndarray.
            val = val * -1
        return val

    def inv_transform(self, val):
        """Undo ``transform``: sign flip first, then ``exp`` if log-scaled.

        The ``1e-8`` offset added by ``transform`` is not subtracted again, so
        the round trip is exact only up to that offset. The input is never
        modified in place.
        """
        if self.negative:
            val = val * -1
        if self.log_scale:
            val = np.exp(val)
        return val


class ObjectiveWithAPI(Objective):
    """Objective that additionally keeps a handle to a benchmark ``api`` object."""

    def __init__(self, seed: int, log_scale: bool, negative: bool, api) -> None:
        super().__init__(seed, log_scale, negative)
        self.api = api

    def __call__(self, config, **kwargs):
        """Evaluate ``config``; subclasses must override this."""
        raise NotImplementedError
= 20 23 | 24 | def __init__( 25 | self, 26 | data_path, 27 | seed, 28 | log_scale: bool = True, 29 | negative: bool = False, 30 | eval_mode: bool = False, 31 | ) -> None: 32 | super().__init__(seed, log_scale, negative) 33 | self.data_path = data_path 34 | self.failed_runs = 0 35 | 36 | self.eval_mode = eval_mode 37 | if self.eval_mode: 38 | self.n_epochs = 64 39 | 40 | def __call__(self, working_directory, previous_working_director, architecture, **hp): 41 | start = time.time() 42 | if isinstance(architecture, SearchSpace): 43 | model = architecture.hyperparameters["graph"].get_model_for_evaluation() 44 | for key in self.optim_kwargs: 45 | if key in architecture.hyperparameters: 46 | self.optim_kwargs[key] = architecture.hyperparameters[key].value 47 | elif hasattr(architecture, "get_model_for_evaluation"): 48 | model = architecture.get_model_for_evaluation() 49 | elif hasattr(architecture, "to_pytorch"): 50 | model = architecture.to_pytorch() 51 | elif isinstance(architecture, torch.nn.Module): # assumes to be a PyTorch model 52 | model = architecture 53 | else: 54 | raise NotImplementedError 55 | 56 | model.cuda() 57 | model.train() 58 | train_criterion = get_loss("CrossEntropyLoss") 59 | evaluation_metric = get_evaluation_metric("Accuracy", top_k=1) 60 | evaluation_metric.cuda() 61 | 62 | optimizer = get_optimizer("SGD", model, **self.optim_kwargs) 63 | scheduler = get_scheduler( 64 | scheduler="CosineAnnealingLR", optimizer=optimizer, T_max=self.n_epochs 65 | ) 66 | train_loader, valid_loader, test_loader = get_train_val_test_loaders( 67 | dataset=self.dataset, 68 | data=self.data_path, 69 | batch_size=self.batch_size, 70 | eval_mode=self.eval_mode, 71 | ) 72 | results = training_pipeline( 73 | model=model, 74 | train_criterion=train_criterion, 75 | evaluation_metric=evaluation_metric, 76 | optimizer=optimizer, 77 | scheduler=scheduler, 78 | train_loader=train_loader, 79 | valid_loader=valid_loader, 80 | test_loader=test_loader, 81 | n_epochs=self.n_epochs, 82 
| eval_mode=self.eval_mode, 83 | ) 84 | try: 85 | if not self.eval_mode: 86 | val_err = 1 - results["val_scores"][-1] 87 | except Exception as e: 88 | print(e) 89 | val_err = 1.0 90 | self.failed_runs += 1 91 | if self.failed_runs > 10: 92 | raise Exception("Too many failed runs!") 93 | end = time.time() 94 | del model 95 | del train_criterion 96 | del evaluation_metric 97 | del optimizer 98 | del scheduler 99 | del train_loader 100 | del valid_loader 101 | del test_loader 102 | if torch.cuda.is_available(): 103 | torch.cuda.empty_cache() 104 | if self.eval_mode: 105 | results["train_time"] = end - start 106 | return results 107 | if isinstance(architecture, SearchSpace): 108 | return { 109 | "loss": self.transform(val_err), 110 | "info_dict": { 111 | "val_score": val_err, 112 | "val_scores": results["val_scores"], 113 | "test_score": 1 - results["test_scores"][-1], 114 | "test_scores": results["test_scores"], 115 | "train_time": end - start, 116 | "timestamp": end, 117 | }, 118 | } 119 | return { 120 | "loss": self.transform(val_err), 121 | "info_dict": { 122 | "val_score": val_err, 123 | "val_scores": results["val_scores"], 124 | "test_score": 1 - results["test_scores"][-1], 125 | "test_scores": results["test_scores"], 126 | "train_time": end - start, 127 | "timestamp": end, 128 | }, 129 | } 130 | 131 | def get_train_loader(self): 132 | train_loader, _, _ = get_train_val_test_loaders( 133 | dataset=self.dataset, 134 | data=self.data_path, 135 | batch_size=self.batch_size, 136 | eval_mode=self.eval_mode, 137 | ) 138 | return train_loader 139 | -------------------------------------------------------------------------------- /benchmarks/objectives/cifarTile.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from benchmarks.evaluation.objective import Objective 8 | from benchmarks.evaluation.train import training_pipeline 9 | from 
class CifarTileObjective(Objective):
    """Objective that trains an architecture on the ``cifarTile`` dataset.

    Class attributes hold the training setup: dataset name, batch size, SGD
    keyword arguments, number of epochs, the thread count passed to
    ``prepare_seed``, and the number of classes.
    """

    dataset = "cifarTile"
    batch_size = 64
    optim_kwargs = {"lr": 0.01, "momentum": 0.9, "weight_decay": 3e-4}
    n_epochs = 64
    workers = 2
    num_classes = 4

    def __init__(
        self,
        data_path,
        seed: int,
        log_scale: bool = True,
        negative: bool = False,
        eval_mode: bool = False,
    ) -> None:
        """Store the data location and evaluation mode.

        Args:
            data_path: Passed as ``data`` to ``get_train_val_test_loaders``.
            seed: Seed handed to ``prepare_seed`` on every call.
            log_scale: Forwarded to ``Objective``.
            negative: Forwarded to ``Objective``.
            eval_mode: Forwarded to the data loaders and the training pipeline.
        """
        super().__init__(seed, log_scale, negative)
        self.data_path = data_path
        self.eval_mode = eval_mode

    def __call__(
        self,
        working_directory=None,
        previous_working_director=None,
        architecture=None,
        **kwargs,
    ):
        """Train ``architecture`` and return its transformed validation error.

        Args:
            working_directory: Not used by this method. Defaults to ``None`` so
                the objective can be called with ``architecture=...`` only, as
                this module's command-line entry point does.
            previous_working_director: Not used by this method; same default.
            architecture: An object exposing ``to_pytorch`` or a
                ``torch.nn.Module``.
            **kwargs: Ignored.

        Returns:
            A dict with ``"loss"`` (transformed validation error) and an
            ``"info_dict"`` holding validation/test scores and timing.

        Raises:
            NotImplementedError: If ``architecture`` is neither convertible via
                ``to_pytorch`` nor a ``torch.nn.Module`` (this includes the
                ``None`` default).
        """
        start = time.time()

        prepare_seed(self.seed, self.workers)

        if hasattr(architecture, "to_pytorch"):
            model = architecture.to_pytorch()
        elif isinstance(architecture, torch.nn.Module):
            model = architecture
        else:
            raise NotImplementedError

        model.cuda()
        model.train()
        train_criterion = get_loss("CrossEntropyLoss")
        evaluation_metric = get_evaluation_metric("Accuracy", top_k=1)
        evaluation_metric.cuda()

        optimizer = get_optimizer("SGD", model, **self.optim_kwargs)
        scheduler = get_scheduler(
            scheduler="CosineAnnealingLR", optimizer=optimizer, T_max=self.n_epochs
        )
        train_loader, valid_loader, test_loader = get_train_val_test_loaders(
            dataset=self.dataset,
            data=self.data_path,
            batch_size=self.batch_size,
            eval_mode=self.eval_mode,
        )
        results = training_pipeline(
            model=model,
            train_criterion=train_criterion,
            evaluation_metric=evaluation_metric,
            optimizer=optimizer,
            scheduler=scheduler,
            train_loader=train_loader,
            valid_loader=valid_loader,
            test_loader=test_loader,
            n_epochs=self.n_epochs,
            eval_mode=self.eval_mode,
        )
        end = time.time()

        # Drop references before emptying the CUDA cache so the memory can
        # actually be released.
        del model
        del train_criterion
        del evaluation_metric
        del optimizer
        del scheduler
        del train_loader
        del valid_loader
        del test_loader
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        val_err = 1 - results["val_scores"][-1]
        return {
            "loss": self.transform(val_err),
            "info_dict": {
                "val_score": val_err,
                "val_scores": results["val_scores"],
                "test_score": 1 - results["test_scores"][-1],
                "test_scores": results["test_scores"],
                "train_time": end - start,
                "timestamp": end,
            },
        }

    def get_train_loader(self):
        """Return only the training loader for this objective's dataset."""
        train_loader, _, _ = get_train_val_test_loaders(
            dataset=self.dataset,
            data=self.data_path,
            batch_size=self.batch_size,
            eval_mode=self.eval_mode,
        )
        return train_loader
class ImageNet16(data.Dataset):
    """Dataset over the pickled 16x16 ImageNet batches found in ``<root>/ImageNet16``.

    Labels in the batch files are 1-based; ``__getitem__`` returns them shifted
    to start at 0.
    """

    # http://image-net.org/download-images
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
    # https://arxiv.org/pdf/1707.08819.pdf

    # [file name, md5 checksum] pairs for the training batches.
    train_list = [
        ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"],
        ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"],
        ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"],
        ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"],
        ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"],
        ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"],
        ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"],
        ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"],
        ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"],
        ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"],
    ]
    # [file name, md5 checksum] pair for the validation batch.
    valid_list = [
        ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"],
    ]

    def __init__(self, root, train, transform, use_num_of_class_only=None):
        """Load every batch of the requested split into memory.

        Args:
            root: Directory that contains the ``ImageNet16`` folder.
            train: Truthy selects ``train_list``, falsy selects ``valid_list``.
            transform: Optional callable applied to each PIL image.
            use_num_of_class_only: If given, keep only samples whose 1-based
                label is at most this value; must be an int in ``(0, 1000)``.

        Raises:
            RuntimeError: If any listed batch file is missing or fails its
                checksum.
        """
        self.root = os.path.join(root, "ImageNet16")
        self.transform = transform
        self.train = train  # training set or valid set
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted.")

        batch_files = self.train_list if self.train else self.valid_list
        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, _checksum in batch_files:
            with open(os.path.join(self.root, file_name), "rb") as handle:
                if sys.version_info[0] == 2:
                    entry = pickle.load(handle)
                else:
                    entry = pickle.load(handle, encoding="latin1")
            self.data.append(entry["data"])
            self.targets.extend(entry["labels"])

        stacked = np.vstack(self.data).reshape(-1, 3, 16, 16)
        self.data = stacked.transpose((0, 2, 3, 1))  # convert to HWC

        if use_num_of_class_only is None:
            return

        assert (
            isinstance(use_num_of_class_only, int) and 0 < use_num_of_class_only < 1000
        ), f"invalid use_num_of_class_only : {use_num_of_class_only}"
        kept_images, kept_labels = [], []
        for image, label in zip(self.data, self.targets):
            if not 1 <= label <= use_num_of_class_only:
                continue
            kept_images.append(image)
            kept_labels.append(label)
        self.data = kept_images
        self.targets = kept_labels

    def __repr__(self):
        """Summarise the number of images and distinct labels."""
        n_classes = len(set(self.targets))
        return f"{type(self).__name__}({len(self.data)} images, {n_classes} classes)"

    def __getitem__(self, index):
        """Return ``(image, label)`` with the label shifted to be 0-based."""
        image = Image.fromarray(self.data[index])
        label = self.targets[index] - 1

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

    def _check_integrity(self):
        """Return True only if every train and validation file passes its md5 check."""
        return all(
            check_integrity(os.path.join(self.root, entry[0]), entry[1])
            for entry in self.train_list + self.valid_list
        )
def load_config(path, extra):
    """Read a JSON config file and return it as a ``Configure`` namedtuple.

    Each JSON value is converted with ``convert_param``. Entries of ``extra``
    are then merged in, overriding file entries with the same key.

    Args:
        path: Location of the JSON file; converted with ``str``.
        extra: ``None`` or a dict of additional/overriding fields.

    Returns:
        A namedtuple instance named ``Configure`` with one field per key.
    """
    config_path = str(path)
    assert os.path.exists(config_path), f"Can not find {config_path}"

    with open(config_path, encoding="utf-8") as handle:
        raw_entries = json.load(handle)

    fields = {}
    for key, spec in raw_entries.items():
        fields[key] = convert_param(spec)

    assert extra is None or isinstance(extra, dict), "invalid type of extra : {:}".format(
        extra
    )
    if isinstance(extra, dict):
        fields.update(extra)

    configure_cls = namedtuple("Configure", " ".join(fields))
    return configure_cls(**fields)
class CUTOUT:
    """Augmentation that zeroes a randomly placed square patch of an image tensor."""

    def __init__(self, length):
        # Side length of the patch; half of it is taken on each side of the centre.
        self.length = length

    def __repr__(self):
        return f"{type(self).__name__}(length={self.length})"

    def __call__(self, img):
        """Zero a patch of ``img`` in place and return the same tensor.

        The patch centre is drawn from NumPy's global random state (row first,
        then column) and the patch is clipped to the image borders. Sizes are
        read from dimensions 1 and 2 of ``img``.
        """
        height, width = img.size(1), img.size(2)
        center_y = np.random.randint(height)
        center_x = np.random.randint(width)

        half = self.length // 2
        top = min(max(center_y - half, 0), height)
        bottom = min(max(center_y + half, 0), height)
        left = min(max(center_x - half, 0), width)
        right = min(max(center_x + half, 0), width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.0
        # The single-channel mask is broadcast over all channels of the image.
        img.mul_(torch.from_numpy(keep).expand_as(img))
        return img
# Cell for NAS-Bench-201
class InferCell(nn.Module):
    """Cell whose nodes sum the outputs of operations applied to earlier nodes.

    ``genotype[k]`` lists ``(op_name, input_node)`` pairs for node ``k + 1``;
    node 0 is the cell input.
    """

    def __init__(
        self, genotype, C_in, C_out, stride, affine=True, track_running_stats=True
    ):
        """Instantiate one layer from ``OPS`` per genotype edge.

        Edges that read the cell input (node 0) map ``C_in`` to ``C_out`` with
        the given ``stride``; all other edges map ``C_out`` to ``C_out`` with
        stride 1.
        """
        super().__init__()

        self.layers = nn.ModuleList()
        self.node_IN = []  # per node: indices of the nodes it reads from
        self.node_IX = []  # per node: indices into ``self.layers``
        self.genotype = deepcopy(genotype)
        for node_pos in range(len(genotype) - 1):
            layer_ids = []
            source_ids = []
            for op_name, op_in in genotype[node_pos]:
                in_channels, op_stride = (C_in, stride) if op_in == 0 else (C_out, 1)
                layer = OPS[op_name](
                    in_channels, C_out, op_stride, affine, track_running_stats
                )
                layer_ids.append(len(self.layers))
                source_ids.append(op_in)
                self.layers.append(layer)
            self.node_IX.append(layer_ids)
            self.node_IN.append(source_ids)
        self.nodes = len(genotype)
        self.in_dim = C_in
        self.out_dim = C_out

    def extra_repr(self):
        """Describe node count, channel sizes, per-node wiring and the genotype."""
        header = f"info :: nodes={self.nodes}, inC={self.in_dim}, outC={self.out_dim}"
        wiring = []
        pairs = zip(self.node_IX, self.node_IN)
        for node_no, (layer_ids, source_ids) in enumerate(pairs, start=1):
            edges = ",".join(
                f"I{src}-L{lay}" for lay, src in zip(layer_ids, source_ids)
            )
            wiring.append(f"{node_no}<-({edges})")
        joined = " | ".join(wiring)
        return f"{header}, [{joined}], {self.genotype.tostr()}"

    def forward(self, inputs):
        """Compute every node in order and return the last one."""
        features = [inputs]
        for layer_ids, source_ids in zip(self.node_IX, self.node_IN):
            contributions = [
                self.layers[lay](features[src])
                for lay, src in zip(layer_ids, source_ids)
            ]
            features.append(sum(contributions))
        return features[-1]
class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classifier: pooling and two convolutions followed by a linear layer."""

    def __init__(self, C, num_classes):
        """assuming input size 8x8"""
        super().__init__()
        # Layers are created in this exact order so parameter names
        # (``features.<idx>``) stay as before.
        stages = [
            nn.ReLU(inplace=True),
            # image size = 2 x 2
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return class logits for ``x``.

        The first ReLU is in-place, so ``x`` itself is rectified as a side effect.
        """
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)
# The macro structure for architectures in NAS-Bench-201
class TinyNetwork(nn.Module):
    """Stem, three stages of ``N`` cells separated by two reduction blocks, classifier."""

    def __init__(self, C, N, genotype, num_classes):
        """Assemble the network.

        Args:
            C: Channel count of the stem and the first stage; later stages use
                ``2 * C`` and ``4 * C``.
            N: Number of ``InferCell`` instances per stage.
            genotype: Passed unchanged to every ``InferCell``.
            num_classes: Output size of the final linear layer.
        """
        super().__init__()
        self._C = C
        self._layerN = N

        stem_conv = nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False)
        self.stem = nn.Sequential(stem_conv, nn.BatchNorm2d(C))

        # (output channels, is_reduction) for each cell, in forward order.
        plan = []
        for stage, width in enumerate((C, C * 2, C * 4)):
            if stage > 0:
                plan.append((width, True))
            plan.extend([(width, False)] * N)

        in_channels = C
        self.cells = nn.ModuleList()
        for out_channels, is_reduction in plan:
            if is_reduction:
                block = ResNetBasicblock(in_channels, out_channels, 2, True)
            else:
                block = InferCell(genotype, in_channels, out_channels, 1)
            self.cells.append(block)
            in_channels = block.out_dim
        self._Layer = len(self.cells)

        self.lastact = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(in_channels, num_classes)

    def get_message(self):
        """Return ``extra_repr()`` followed by one numbered line per cell."""
        total = len(self.cells)
        lines = [
            f"\n {idx:02d}/{total:02d} :: {cell.extra_repr()}"
            for idx, cell in enumerate(self.cells)
        ]
        return self.extra_repr() + "".join(lines)

    def extra_repr(self):
        """Summarise the channel count, cells per stage and total cell count."""
        return f"{type(self).__name__}(C={self._C}, N={self._layerN}, L={self._Layer})"

    def forward(self, inputs):
        """Return class logits for a batch of images."""
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)

        pooled = self.global_pooling(self.lastact(feature))
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)
class DARTSCnn(Objective):
    """Objective that trains a DARTS-style CNN assembled from two searched cells.

    Calling the objective converts the ``normal`` and ``reduce`` cells to their
    PyTorch/genotype form, trains the resulting network with either
    ``train_search`` (default) or ``train_evaluation`` (``eval_mode=True``) and
    turns the recorded validation accuracies into a loss.

    Args:
        data_path: Folder holding the dataset (passed to the trainer as ``data``).
        eval_policy: How the per-epoch validation accuracies are reduced:
            ``"best"`` (maximum), ``"last"`` (final epoch) or ``"last5"``
            (mean of the final five epochs).
        seed: Seed forwarded to the base class and to the trainer.
        log_scale: Forwarded to ``Objective``.
        negative: Forwarded to ``Objective``.
        eval_mode: Use ``train_evaluation`` instead of ``train_search``.
    """

    def __init__(
        self,
        data_path: Union[str, Path],
        eval_policy: str = "last5",
        seed: int = 777,
        log_scale: bool = True,
        negative: bool = False,
        eval_mode: bool = False,
    ) -> None:
        super().__init__(seed, log_scale, negative)
        assert eval_policy in ["best", "last", "last5"], (
            f"eval_policy must be 'best', 'last' or 'last5', got {eval_policy!r}"
        )

        self.data_path = data_path
        self.eval_policy = eval_policy
        self.eval_mode = eval_mode

    @staticmethod
    def _cell_to_pytorch(cell, role):
        """Return ``cell.to_pytorch()`` or raise a descriptive NotImplementedError."""
        if not hasattr(cell, "to_pytorch"):
            raise NotImplementedError(
                f"The {role} cell ({type(cell).__name__}) does not provide to_pytorch()"
            )
        return cell.to_pytorch()

    def __call__(self, working_directory, previous_working_directory, normal, reduce):
        """Train the architecture and return ``{"loss": ..., "info_dict": ...}``.

        Args:
            working_directory: Passed to the trainer as ``save_path``.
            previous_working_directory: Accepted for interface compatibility; unused.
            normal: Normal cell; must provide ``to_pytorch()``.
            reduce: Reduction cell; must provide ``to_pytorch()``.

        Raises:
            NotImplementedError: If a cell has no ``to_pytorch`` method.
            ValueError: If ``self.eval_policy`` is not one of the supported names.
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        genotype = Genotype(
            normal=self._cell_to_pytorch(normal, "normal"),
            normal_concat=range(2, 6),
            reduce=self._cell_to_pytorch(reduce, "reduce"),
            reduce_concat=range(2, 6),
        )

        trainer = train_evaluation if self.eval_mode else train_search
        start = time.time()
        valid_accs = trainer(
            genotype=genotype,
            data=self.data_path,
            seed=self.seed,
            save_path=working_directory,
        )
        end = time.time()

        # accuracies are percentages, hence the division by 100
        if self.eval_policy == "best":
            val_error = 1 - max(valid_accs) / 100
        elif self.eval_policy == "last":
            val_error = 1 - valid_accs[-1] / 100
        elif self.eval_policy == "last5":
            val_error = 1 - np.mean(valid_accs[-5:]) / 100
        else:
            # only reachable when asserts are disabled or the attribute was mutated;
            # previously this surfaced as an UnboundLocalError further down
            raise ValueError(f"Unknown eval_policy {self.eval_policy!r}")

        return {
            "loss": self.transform(val_error),
            "info_dict": {
                "accs": valid_accs,
                "best_acc": max(valid_accs),
                "last_acc": valid_accs[-1],
                "time": end - start,
                "timestamp": end,
            },
        }


if __name__ == "__main__":
    import argparse
    import json
    import os
    import shutil

    import yaml
    from neps.search_spaces.search_space import SearchSpace

    # pylint: disable=ungrouped-imports
    from benchmarks.search_spaces.darts_cnn.graph import DARTSSpace

    # pylint: enable=ungrouped-imports

    parser = argparse.ArgumentParser(description="Train DARTS")
    parser.add_argument(
        "--data",
        default="",
        help="Path to folder with data or where data should be saved to if downloaded.",
        type=str,
    )
    parser.add_argument(
        "--save",
        default="",
        type=str,
    )
    parser.add_argument(
        "--arch", default="ours", type=str, choices=["ours", "drnas", "nasbowl"]
    )
    parser.add_argument("--seed", default=777, type=int)
    parser.add_argument("--eval", action="store_true")
    args = parser.parse_args()

    pipeline_space = dict(
        normal=DARTSSpace(),
        reduce=DARTSSpace(),
    )
    pipeline_space = SearchSpace(**pipeline_space)

    run_pipeline_fn = DARTSCnn(data_path=args.data, seed=args.seed, eval_mode=args.eval)

    if args.arch == "drnas":
        # DrNAS cells
        pipeline_space.load_from(
            {
                "normal": "(CELL DARTS (OP sep_conv_3x3) (IN1 0) (OP sep_conv_5x5) (IN1 1) (OP sep_conv_3x3) (IN2 1) (OP sep_conv_3x3) (IN2 2) (OP skip_connect) (IN3 0) (OP sep_conv_3x3) (IN3 1) (OP sep_conv_3x3) (IN4 2) (OP dil_conv_5x5) (IN4 3))",
                "reduce": "(CELL DARTS (OP max_pool_3x3) (IN1 0) (OP sep_conv_5x5) (IN1 1) (OP dil_conv_5x5) (IN2 2) (OP sep_conv_5x5) (IN2 1) (OP sep_conv_5x5) (IN3 1) (OP dil_conv_5x5) (IN3 3) (OP skip_connect) (IN4 4) (OP sep_conv_5x5) (IN4 1))",
            }
        )
        working_directory = Path(args.save) / "drnas"
    elif args.arch == "nasbowl":
        # nasbowl cells
        pipeline_space.load_from(
            {
                "normal": "(CELL DARTS (OP skip_connect) (IN1 1) (OP sep_conv_3x3) (IN1 0) (OP sep_conv_3x3) (IN2 1) (OP max_pool_3x3) (IN2 0) (OP sep_conv_5x5) (IN3 1) (OP sep_conv_3x3) (IN3 0) (OP dil_conv_5x5) (IN4 2) (OP sep_conv_3x3) (IN4 1))",
                "reduce": "(CELL DARTS (OP skip_connect) (IN1 1) (OP sep_conv_3x3) (IN1 0) (OP sep_conv_3x3) (IN2 1) (OP max_pool_3x3) (IN2 0) (OP sep_conv_5x5) (IN3 1) (OP sep_conv_3x3) (IN3 0) (OP dil_conv_5x5) (IN4 2) (OP sep_conv_3x3) (IN4 1))",
            }
        )
        # NOTE(review): the nasbowl run is stored under "bananas" — confirm this is intended
        working_directory = Path(args.save) / "bananas"
    elif args.arch == "ours":
        # pick the configuration with the lowest recorded loss from a finished search
        args.save = Path(args.save)
        results_dir = args.save / "results"
        if not os.path.isdir(results_dir):
            raise NotADirectoryError(f"No results directory found at {results_dir}")
        config_loss_dict = {}
        for config_number in os.listdir(results_dir):
            results_yaml = results_dir / config_number / "result.yaml"
            if os.path.isfile(results_yaml):
                with open(results_yaml) as f:
                    data = yaml.safe_load(f)
                config_loss_dict[config_number] = data["loss"]
        if not config_loss_dict:
            # min() on an empty dict would only report "arg is an empty sequence"
            raise ValueError(f"No result.yaml files found below {results_dir}")
        best_config = min(config_loss_dict, key=config_loss_dict.get)
        config_yaml = results_dir / best_config / "config.yaml"
        with open(config_yaml) as f:
            identifier = yaml.safe_load(f)
        pipeline_space.load_from(identifier)
        working_directory = args.save / f"best_config_eseed_{args.seed}"
        # the directory must exist before the config can be copied into it
        working_directory.makedirs_p()
        shutil.copyfile(config_yaml, working_directory / "config.yaml")
    else:
        raise NotImplementedError
    working_directory.makedirs_p()

    res = run_pipeline_fn(
        working_directory,
        "",
        pipeline_space.hyperparameters["normal"],
        pipeline_space.hyperparameters["reduce"],
    )
    print(args.arch, res)

    # NOTE(review): baseline results are written relative to the current directory,
    # not to working_directory — confirm this is intended
    result_file = (
        working_directory / "results.json"
        if args.arch == "ours"
        else f"{args.arch}.json"
    )
    with open(result_file, "w") as f:
        json.dump(res, f, indent=4)
def train_search(genotype, data, seed, save_path):
    """Train a small CIFAR-10 network for ``genotype`` and return validation accuracies.

    The CIFAR-10 training set is split in half (``train_portion = 0.5``): the
    first half is used for SGD training, the second half for validation. The
    model is trained for 50 epochs with a cosine-annealed learning rate and the
    validation top-1 accuracy of every epoch is collected.

    Progress is logged to the root logger and to ``<save_path>/log.txt``. The
    file handler created for this run is removed and closed when the function
    exits, including on errors.

    Args:
        genotype: Cell genotype handed to ``NetworkCIFAR``.
        data: Dataset root directory; the dataset is downloaded there if missing.
        seed: Seed for NumPy and PyTorch (CPU and CUDA).
        save_path: Existing directory that receives ``log.txt``.

    Returns:
        List with one validation top-1 accuracy (float) per epoch.

    Note:
        Terminates the process via ``sys.exit(1)`` when no CUDA device is available.
    """
    dataset = "cifar10"
    epochs = 50
    batch_size = 64
    learning_rate = 0.025
    learning_rate_min = 0.0
    momentum = 0.9
    weight_decay = 3e-4
    report_freq = 50
    gpu = 0
    init_channels = 16
    layers = 8
    cutout = False
    cutout_length = 16
    grad_clip = 5
    train_portion = 0.5
    auxiliary = False

    log_format = "%(asctime)s %(message)s"
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=log_format,
        datefmt="%m/%d %I:%M:%S %p",
    )
    root_logger = logging.getLogger()
    fh = logging.FileHandler(os.path.join(save_path, "log.txt"))
    fh.setFormatter(logging.Formatter(log_format))
    root_logger.addHandler(fh)

    try:
        CIFAR_CLASSES = 10
        if dataset == "cifar100":
            CIFAR_CLASSES = 100

        if not torch.cuda.is_available():
            logging.info("no gpu device available")
            sys.exit(1)

        np.random.seed(seed)
        torch.cuda.set_device(gpu)
        cudnn.benchmark = True
        torch.manual_seed(seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(seed)
        logging.info("gpu device = %d", gpu)

        criterion = nn.CrossEntropyLoss()
        criterion = criterion.cuda()
        model = Network(
            init_channels, CIFAR_CLASSES, layers, auxiliary=auxiliary, genotype=genotype
        )
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

        optimizer = torch.optim.SGD(
            model.parameters(),
            learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        (
            train_transform,
            _,
        ) = utils._data_transforms_cifar10(  # pylint: disable=protected-access
            cutout, cutout_length
        )
        if dataset == "cifar100":
            train_data = dset.CIFAR100(
                root=data, train=True, download=True, transform=train_transform
            )
        else:
            train_data = dset.CIFAR10(
                root=data, train=True, download=True, transform=train_transform
            )

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(train_portion * num_train))

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True,
        )

        # NOTE(review): validation samples come from the training set and therefore
        # go through train_transform as well — confirm this is intended
        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:num_train]
            ),
            pin_memory=True,
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, epochs, eta_min=learning_rate_min
        )

        valid_accs = []
        model.drop_path_prob = 0.0
        for epoch in range(epochs):
            lr = scheduler.get_last_lr()[0]
            logging.info("epoch %d lr %e", epoch, lr)

            # training
            train_acc, _ = train(
                train_queue=train_queue,
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                grad_clip=grad_clip,
                report_freq=report_freq,
            )
            logging.info("train_acc %f", train_acc)

            # validation
            valid_acc, _ = infer(
                valid_queue=valid_queue,
                model=model,
                criterion=criterion,
                report_freq=report_freq,
            )
            valid_accs.append(float(valid_acc.cpu().detach().numpy()))
            logging.info("valid_acc %f", valid_acc)

            scheduler.step()

        # drop the references to the GPU-resident objects before returning
        del model
        del criterion
        del optimizer
        del scheduler

        return valid_accs
    finally:
        # Detach exactly the handler created for this run. Removing handlers[0]
        # (as done before) dropped the stdout handler on the first call and left
        # this run's log file attached to later runs.
        root_logger.removeHandler(fh)
        fh.close()


def train(train_queue, model, criterion, optimizer, grad_clip, report_freq):
    """Run one training epoch and return ``(top-1 average, loss average)``.

    Args:
        train_queue: Iterable of ``(input, target)`` batches.
        model: Network whose forward pass returns ``(logits, <ignored>)``.
        criterion: Loss applied to ``(logits, target)``.
        optimizer: Optimizer stepping ``model.parameters()``.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        report_freq: Log the running statistics every ``report_freq`` steps.
    """
    # presumably running-average trackers; only update(value, n) and .avg are used here
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    model.train()
    for step, (inputs, target) in enumerate(train_queue):
        n = inputs.size(0)
        inputs = inputs.cuda()
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()

        logits, _ = model(inputs)
        loss = criterion(logits, target)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        # gradients are cleared again right after the update, as before
        optimizer.zero_grad()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.data, n)
        top1.update(prec1.data, n)
        top5.update(prec5.data, n)

        if step % report_freq == 0:
            logging.info("train %03d %e %f %f", step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion, report_freq):
    """Evaluate ``model`` without gradients and return ``(top-1 average, loss average)``.

    Args:
        valid_queue: Iterable of ``(input, target)`` batches.
        model: Network whose forward pass returns ``(logits, <ignored>)``.
        criterion: Loss applied to ``(logits, target)``.
        report_freq: Log the running statistics every ``report_freq`` steps.
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (inputs, target) in enumerate(valid_queue):
            inputs = inputs.cuda()
            target = target.cuda(non_blocking=True)

            logits, _ = model(inputs)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)

            if step % report_freq == 0:
                logging.info(
                    "valid %03d %e %f %f", step, objs.avg, top1.avg, top5.avg
                )

    return top1.avg, objs.avg
class ObjectiveWithAPI(Objective):
    """Objective that delegates evaluation to a benchmark ``api`` object.

    The ``api`` must provide ``eval(config)``, ``test(config)`` and
    ``transform(loss)``; these are the only members used here.
    """

    def __init__(self, api) -> None:
        super().__init__(None, None, None)
        self.api = api

    def __call__(self, architecture):
        """Evaluate ``architecture`` through the API and return loss plus info.

        A ``SearchSpace`` must hold exactly one hyperparameter, whose
        ``get_model_for_evaluation()`` result is evaluated; any other object is
        handed to the API unchanged.

        Raises:
            Exception: If a ``SearchSpace`` holds more or fewer than one hyperparameter.
        """
        if not isinstance(architecture, SearchSpace):
            _config = architecture
        else:
            hyperparameters = list(architecture.hyperparameters.values())
            if len(hyperparameters) != 1:
                raise Exception(
                    "Only one hyperparameter is allowed for this objective!")
            _config = hyperparameters[0].get_model_for_evaluation()

        # only the validation query is timed
        started = time.time()
        loss = self.api.eval(_config)
        elapsed = time.time() - started

        transformed_loss = self.api.transform(loss)
        info = {
            "config_id": architecture.id,
            "val_score": loss,
            "test_score": self.api.test(_config),
            "train_time": elapsed,
        }
        return {"loss": transformed_loss, "info_dict": info}
# directory containing this module; the grammar file is read relative to it
DIR_PATH = Path(os.path.dirname(os.path.realpath(__file__)))

# Maps every terminal of grammar.cfg to a topology constructor or an operation.
# Parameter-free operations are stored as ready-made instances; operations that
# carry a learnable parameter are stored as classes and instantiated with the
# channel count in ``build`` (see the ``trainable`` check there).
PRIMITIVES = {
    "UnaryTopo": partial(topos.LinearNEdge, number_of_edges=1),
    "BinaryTopo": BinaryTopo,

    # unary ops
    "id": UnaryOp.Identity(),
    "neg": UnaryOp.Negate(),
    "abs": UnaryOp.Absolute(),
    "square": UnaryOp.Square(),
    "cubic": UnaryOp.Cubic(),
    "square_root": UnaryOp.SquareRoot(),
    "mconst": UnaryOp.MultConst,
    "aconst": UnaryOp.AddConst,
    "log": UnaryOp.Log(),
    "exp": UnaryOp.Exp(),
    "sin": UnaryOp.Sin(),
    "cos": UnaryOp.Cos(),
    "sinh": UnaryOp.Sinh(),
    "cosh": UnaryOp.Cosh(),
    "tanh": UnaryOp.Tanh(),
    "asinh": UnaryOp.aSinh(),
    "atanh": UnaryOp.aTanh(),
    "sinc": UnaryOp.Sinc(),
    "umax": UnaryOp.UnaryMax(),
    "umin": UnaryOp.UnaryMin(),
    "sigmoid": UnaryOp.Sigmoid(),
    "logexp": UnaryOp.LogExp(),
    "gaussian": UnaryOp.Gaussian(),
    "erf": UnaryOp.Erf(),
    "const": UnaryOp.Constant,

    # binary ops
    "add": BinaryOp.Addition(),
    "multi": BinaryOp.Multiplication(),
    "sub": BinaryOp.Subtraction(),
    "div": BinaryOp.Division(),
    "bmax": BinaryOp.BinaryMax(),
    "bmin": BinaryOp.BinaryMin(),
    "bsigmoid": BinaryOp.SigmoidMult(),
    "bgaussian_sq": BinaryOp.BinaryGaussianSquare,
    "bgaussian_abs": BinaryOp.BinaryGaussianAbs,
    "wavg": BinaryOp.WeightedAvg,
}

def set_comb_op(node, **kwargs):
    """Set the node's combine operation to ``Stacking``; ``kwargs`` are ignored.

    ``node`` is indexed as ``node[1]``, so it is presumably a ``(node_id, data)``
    pair — verify against ``Graph.update_nodes``.
    """
    node[1]["comb_op"] = Stacking()

# every class found in torch.nn.modules.activation except Module, Tensor and Parameter
pytorch_activation_functions = [act[1] for act in inspect.getmembers(nn.modules.activation, inspect.isclass) if not (act[1] == nn.Module or act[1] == torch.Tensor or act[1] == nn.Parameter)]

def build(activation_function: Graph, base_architecture: str = "resnet20", num_classes: int = 10):
    """Create ``base_architecture`` and swap its activation modules for the searched graph.

    Args:
        activation_function: Graph describing the searched activation function.
        base_architecture: Attribute name of a model constructor, looked up first in
            ``cifar_models`` and then in ``torchvision.models``.
        num_classes: Number of output classes passed to the model constructor.

    Returns:
        The base model with its activation modules replaced.

    Raises:
        NotImplementedError: If ``base_architecture`` exists in neither module.
    """
    def replace_activation_functions(base_module: nn.Module, activation_function: Graph, channels: int = -1):
        """Recursively replace activation children of ``base_module`` in place and return it."""
        for name, module in base_module.named_children():
            # remember the width of the most recent layer that exposes one; it is
            # used to size the parameters of trainable operations
            if hasattr(module, "out_channels") and module.out_channels > 0:
                channels = module.out_channels
            elif hasattr(module, "num_features") and module.num_features > 0:
                channels = module.num_features

            if any(isinstance(module, act) for act in pytorch_activation_functions):
                # each replaced activation gets its own copy of the graph
                activation_function_copied = deepcopy(activation_function)
                for _, _, data in activation_function_copied.edges(data=True):
                    # trainable ops are still classes here (see PRIMITIVES): instantiate them
                    if hasattr(data["op"], "trainable") and data["op"].trainable:
                        data.update({"op": data["op"](channels)})
                activation_function_copied.compile()
                activation_function_copied.update_op_names()
                activation_function_copied = activation_function_copied._to_pytorch() # pylint: disable=protected-access
                setattr(base_module, name, activation_function_copied)
            elif isinstance(module, nn.Module):
                # descend into containers, carrying the current channel count along
                new_module = replace_activation_functions(base_module=module, activation_function=activation_function, channels=channels)
                setattr(base_module, name, new_module)
        return base_module


    print(base_architecture)
    if hasattr(cifar_models, base_architecture):
        base_model = getattr(cifar_models, base_architecture)(num_classes=num_classes)
    elif hasattr(models, base_architecture):
        base_model = getattr(models, base_architecture)(pretrained=False, num_classes=num_classes)
    else:
        raise NotImplementedError(f"Model {base_architecture} is not implemented!")

    # set stacking as combo op
    activation_function.update_nodes(
        update_func=lambda node, in_edges, out_edges: set_comb_op(
            node, **{"a": in_edges, "b": out_edges}
        ),
        single_instances=False,
    )

    model = replace_activation_functions(base_module=base_model, activation_function=activation_function)

    return model

class ActivationSpace:
    """Factory for the activation-function search space.

    ``__new__`` does not return an ``ActivationSpace`` instance; it returns a
    ``FunctionParameter`` configured with the grammar from ``grammar.cfg``,
    ``PRIMITIVES`` and a ``build`` function bound to the chosen base
    architecture and dataset.
    """

    def __new__(cls, base_architecture:str="resnet20", dataset: str ="cifar10", return_graph_per_hierarchy: bool = True):
        """Return the ``FunctionParameter`` for ``base_architecture`` on ``dataset``.

        Raises:
            AssertionError: If ``base_architecture`` is found in neither
                ``cifar_models`` nor ``torchvision.models``.
            NotImplementedError: If ``dataset`` is neither "cifar10" nor "cifar100".
        """
        assert hasattr(cifar_models, base_architecture) or hasattr(models, base_architecture)

        # the dataset only determines the number of output classes
        if dataset == "cifar10":
            build_fn = partial(build, base_architecture=base_architecture, num_classes=10)
        elif dataset == "cifar100":
            build_fn = partial(build, base_architecture=base_architecture, num_classes=100)
        else:
            raise NotImplementedError(f"Dataset {dataset} is not supported")

        productions = cls._read_grammar("grammar.cfg")

        return FunctionParameter(
            set_recursive_attribute=build_fn,
            old_build_api=True,
            name=f"activation_{dataset}_{base_architecture}",
            structure=productions,
            primitives=PRIMITIVES,
            return_graph_per_hierarchy=return_graph_per_hierarchy,
            constraint_kwargs=None,
            prior=None,
        )

    @staticmethod
    def _read_grammar(grammar_file: str) -> str:
        """Return the text of ``grammar_file`` located next to this module."""
        with open(Path(DIR_PATH) / grammar_file) as f:
            productions = f.read()
        return productions

if __name__ == "__main__":
    # smoke test: report the search-space size and build one example architecture
    from neps.search_spaces.search_space import SearchSpace
    import math

    pipeline_space = dict(
        architecture=ActivationSpace(base_architecture="resnet20"),
    )
    pipeline_space = SearchSpace(**pipeline_space)
    print(
        "benchmark",
        math.log10(pipeline_space.hyperparameters["architecture"].search_space_size),
    )

    pipeline_space.load({
        "architecture": "(L2 UnaryTopo (L1 UnaryTopo (umax)))"
    })
    print(pipeline_space["architecture"].id)

    model = pipeline_space.hyperparameters["architecture"].to_pytorch()
    print(model)
EPS = 1e-6


def _split_pair(stacked):
    """Unbind ``stacked`` along dim 0 (callers unpack exactly two operands)."""
    return torch.unbind(stacked, dim=0)


class BinaryOperation(AbstractPrimitive):
    """Base class for parameter-free operations on operands stacked along dim 0."""

    trainable = False

    def __init__(
        self, **kwargs
    ):  # pylint:disable=W0613
        super().__init__(locals())

    def forward(self, x):  # pylint: disable=W0613
        """Subclasses implement the actual operation."""
        raise NotImplementedError

    @staticmethod
    def get_embedded_ops():
        """This primitive embeds no further operations."""
        return None


class TrainableBinaryOperation(AbstractPrimitive):
    """Base class for binary operations with a learnable ``beta``.

    ``beta`` is created as ones of shape ``(in_channels, 1, 1)``.
    """

    trainable = True

    def __init__(
        self, in_channels, **kwargs
    ):  # pylint:disable=W0613
        super().__init__(locals())
        self.beta = torch.nn.Parameter(torch.ones(in_channels, 1, 1))

    def forward(self, x):  # pylint: disable=W0613
        """Subclasses implement the actual operation."""
        raise NotImplementedError

    @staticmethod
    def get_embedded_ops():
        """This primitive embeds no further operations."""
        return None


class Addition(BinaryOperation):
    """Sum of all operands stacked along dim 0."""

    def forward(self, x):
        total = torch.sum(x, dim=0)
        return total


class Multiplication(BinaryOperation):
    """Element-wise product of the two operands."""

    def forward(self, x):
        lhs, rhs = _split_pair(x)
        return lhs * rhs


class Subtraction(BinaryOperation):
    """First operand minus second operand."""

    def forward(self, x):
        lhs, rhs = _split_pair(x)
        return lhs - rhs


class Division(BinaryOperation):
    """First operand divided by ``second + EPS``."""

    def forward(self, x):
        numerator, denominator = _split_pair(x)
        return numerator / (denominator + EPS)


class BinaryMax(BinaryOperation):
    """Element-wise maximum of the two operands."""

    def forward(self, x):
        lhs, rhs = _split_pair(x)
        return torch.maximum(lhs, rhs)


class BinaryMin(BinaryOperation):
    """Element-wise minimum of the two operands."""

    def forward(self, x):
        lhs, rhs = _split_pair(x)
        return torch.minimum(lhs, rhs)


class SigmoidMult(BinaryOperation):
    """``sigmoid(first) * second``."""

    def forward(self, x):
        gate, value = _split_pair(x)
        return torch.sigmoid(gate) * value


class BinaryGaussianSquare(TrainableBinaryOperation):
    """``exp(-beta * (first - second)^2)``."""

    get_op_name = "BinaryGaussianSquare"

    def forward(self, x):
        lhs, rhs = _split_pair(x)
        squared_gap = torch.square(lhs - rhs)
        return torch.exp(-self.beta * squared_gap)


class BinaryGaussianAbs(TrainableBinaryOperation):
    """``exp(-beta * |first - second|)``."""

    get_op_name = "BinaryGaussianAbs"

    def forward(self, x):
        lhs, rhs = _split_pair(x)
        absolute_gap = torch.abs(lhs - rhs)
        return torch.exp(-self.beta * absolute_gap)
class WeightedAvg(TrainableBinaryOperation):
    """``beta * first + (1 - beta) * second`` with the learnable ``beta``."""

    get_op_name = "WeightedAvg"

    def forward(self, x):
        first, second = torch.unbind(x, dim=0)
        weighted_first = self.beta * first
        weighted_second = (1 - self.beta) * second
        return weighted_first + weighted_second


class Stacking(AbstractPrimitive):
    """Combine operation that stacks its inputs along a new leading dimension."""

    def __init__(
        self, **kwargs
    ):  # pylint:disable=W0613
        super().__init__(locals())

    def forward(self, x):  # pylint: disable=W0613
        stacked = torch.stack(x, dim=0)
        return stacked

    @staticmethod
    def get_embedded_ops():
        """This primitive embeds no further operations."""
        return None


class BinaryTopo(AbstractTopology):
    """Topology with five fixed edges whose values are given positionally.

    ``edge_vals`` are paired with ``edge_list`` in order; the first value also
    becomes part of the topology name.
    """

    edge_list = [(4, 5), (1, 2), (1, 3), (2, 4), (3, 4)]

    def __init__(self, *edge_vals):
        super().__init__()

        self.name = f"binary_op_{edge_vals[0]}"
        # zip stops at the shorter sequence, so surplus values or edges are dropped
        edge_assignment = {edge: val for edge, val in zip(self.edge_list, edge_vals)}
        self.create_graph(edge_assignment)
        self.set_scope(self.name)
        self.graph_type = "edge_attr"
EPS = 1e-6


class UnaryOperation(AbstractPrimitive):
    """Base class for parameter-free element-wise operations on one tensor."""

    trainable = False

    def __init__(
        self, **kwargs
    ):  # pylint:disable=W0613
        super().__init__(locals())

    def forward(self, x):  # pylint: disable=W0613
        """Subclasses implement the actual operation."""
        raise NotImplementedError

    @staticmethod
    def get_embedded_ops():
        """This primitive embeds no further operations."""
        return None


class TrainableUnaryOperation(UnaryOperation):
    """Unary operation with a learnable ``beta`` of shape ``(in_channels, 1, 1)``."""

    trainable = True

    def __init__(self, in_channels, **kwargs):
        super().__init__(**kwargs)
        self.beta = torch.nn.Parameter(torch.ones(in_channels, 1, 1))


class Identity(UnaryOperation):
    """``x``."""

    def forward(self, x):
        return x


class Negate(UnaryOperation):
    """``-x``."""

    def forward(self, x):
        return -x


class Absolute(UnaryOperation):
    """``|x|``."""

    def forward(self, x):
        return torch.abs(x)


class Square(UnaryOperation):
    """``x^2``."""

    def forward(self, x):
        return x ** 2


class Cubic(UnaryOperation):
    """``x^3``."""

    def forward(self, x):
        return x ** 3


class SquareRoot(UnaryOperation):
    """``sqrt(x)``; negative inputs yield NaN because no clamping is applied."""

    def forward(self, x):
        return torch.sqrt(x)


class MultConst(TrainableUnaryOperation):
    """``beta * x``."""

    get_op_name = "multConst"

    def forward(self, x):
        return self.beta * x


class AddConst(TrainableUnaryOperation):
    """``beta + x``."""

    get_op_name = "addConst"

    def forward(self, x):
        return self.beta + x


class Log(UnaryOperation):
    """``log(|x| + EPS)``; the offset keeps the logarithm finite at zero."""

    def forward(self, x):
        return torch.log(torch.abs(x) + EPS)


class Exp(UnaryOperation):
    """``exp(x)``."""

    def forward(self, x):
        return torch.exp(x)


class Sin(UnaryOperation):
    """``sin(x)``."""

    def forward(self, x):
        return torch.sin(x)


class Cos(UnaryOperation):
    """``cos(x)``."""

    def forward(self, x):
        return torch.cos(x)


class Sinh(UnaryOperation):
    """``sinh(x)``."""

    def forward(self, x):
        return torch.sinh(x)


class Cosh(UnaryOperation):
    """``cosh(x)``."""

    def forward(self, x):
        return torch.cosh(x)


class Tanh(UnaryOperation):
    """``tanh(x)``."""

    def forward(self, x):
        return torch.tanh(x)


class aSinh(UnaryOperation):
    """``asinh(x)``."""

    def forward(self, x):
        return torch.asinh(x)


class aTanh(UnaryOperation):
    """``atanh(x)``."""

    def forward(self, x):
        return torch.atanh(x)


class Sinc(UnaryOperation):
    """``torch.sinc(x)``."""

    def forward(self, x):
        return torch.sinc(x)


class UnaryMax(UnaryOperation):
    """``max(x, 0)``."""

    def forward(self, x):
        return torch.maximum(x, torch.zeros_like(x))


class UnaryMin(UnaryOperation):
    """``min(x, 0)``."""

    def forward(self, x):
        return torch.minimum(x, torch.zeros_like(x))


class Sigmoid(UnaryOperation):
    """``sigmoid(x)``."""

    def forward(self, x):
        return torch.sigmoid(x)


class LogExp(UnaryOperation):
    """``log(1 + exp(x))`` (softplus)."""

    def forward(self, x):
        # The literal log(1 + exp(x)) overflows to inf once exp(x) leaves the
        # float range (and produces NaN gradients there); softplus evaluates the
        # same function in a numerically stable way.
        return torch.nn.functional.softplus(x)


class Gaussian(UnaryOperation):
    """``exp(-x^2)``."""

    def forward(self, x):
        return torch.exp(-x**2)


class Erf(UnaryOperation):
    """``erf(x)``."""

    def forward(self, x):
        return torch.erf(x)


class Constant(TrainableUnaryOperation):
    """Ignores the values of ``x`` and returns ``beta`` broadcast to the shape of ``x``."""

    get_op_name = "constant"

    def forward(self, x):
        return torch.ones_like(x) * self.beta
"sep_conv_5x5", 12 | "dil_conv_3x3", 13 | "dil_conv_5x5", 14 | ] 15 | 16 | NASNet = Genotype( 17 | normal=[ 18 | ("sep_conv_5x5", 1), 19 | ("sep_conv_3x3", 0), 20 | ("sep_conv_5x5", 0), 21 | ("sep_conv_3x3", 0), 22 | ("avg_pool_3x3", 1), 23 | ("skip_connect", 0), 24 | ("avg_pool_3x3", 0), 25 | ("avg_pool_3x3", 0), 26 | ("sep_conv_3x3", 1), 27 | ("skip_connect", 1), 28 | ], 29 | normal_concat=[2, 3, 4, 5, 6], 30 | reduce=[ 31 | ("sep_conv_5x5", 1), 32 | ("sep_conv_7x7", 0), 33 | ("max_pool_3x3", 1), 34 | ("sep_conv_7x7", 0), 35 | ("avg_pool_3x3", 1), 36 | ("sep_conv_5x5", 0), 37 | ("skip_connect", 3), 38 | ("avg_pool_3x3", 2), 39 | ("sep_conv_3x3", 2), 40 | ("max_pool_3x3", 1), 41 | ], 42 | reduce_concat=[4, 5, 6], 43 | ) 44 | 45 | AmoebaNet = Genotype( 46 | normal=[ 47 | ("avg_pool_3x3", 0), 48 | ("max_pool_3x3", 1), 49 | ("sep_conv_3x3", 0), 50 | ("sep_conv_5x5", 2), 51 | ("sep_conv_3x3", 0), 52 | ("avg_pool_3x3", 3), 53 | ("sep_conv_3x3", 1), 54 | ("skip_connect", 1), 55 | ("skip_connect", 0), 56 | ("avg_pool_3x3", 1), 57 | ], 58 | normal_concat=[4, 5, 6], 59 | reduce=[ 60 | ("avg_pool_3x3", 0), 61 | ("sep_conv_3x3", 1), 62 | ("max_pool_3x3", 0), 63 | ("sep_conv_7x7", 2), 64 | ("sep_conv_7x7", 0), 65 | ("avg_pool_3x3", 1), 66 | ("max_pool_3x3", 0), 67 | ("max_pool_3x3", 1), 68 | ("conv_7x1_1x7", 0), 69 | ("sep_conv_3x3", 5), 70 | ], 71 | reduce_concat=[3, 4, 6], 72 | ) 73 | 74 | DARTS_V1 = Genotype( 75 | normal=[ 76 | ("sep_conv_3x3", 1), 77 | ("sep_conv_3x3", 0), 78 | ("skip_connect", 0), 79 | ("sep_conv_3x3", 1), 80 | ("skip_connect", 0), 81 | ("sep_conv_3x3", 1), 82 | ("sep_conv_3x3", 0), 83 | ("skip_connect", 2), 84 | ], 85 | normal_concat=[2, 3, 4, 5], 86 | reduce=[ 87 | ("max_pool_3x3", 0), 88 | ("max_pool_3x3", 1), 89 | ("skip_connect", 2), 90 | ("max_pool_3x3", 0), 91 | ("max_pool_3x3", 0), 92 | ("skip_connect", 2), 93 | ("skip_connect", 2), 94 | ("avg_pool_3x3", 0), 95 | ], 96 | reduce_concat=[2, 3, 4, 5], 97 | ) 98 | DARTS_V2 = Genotype( 99 | normal=[ 
100 | ("sep_conv_3x3", 0), 101 | ("sep_conv_3x3", 1), 102 | ("sep_conv_3x3", 0), 103 | ("sep_conv_3x3", 1), 104 | ("sep_conv_3x3", 1), 105 | ("skip_connect", 0), 106 | ("skip_connect", 0), 107 | ("dil_conv_3x3", 2), 108 | ], 109 | normal_concat=[2, 3, 4, 5], 110 | reduce=[ 111 | ("max_pool_3x3", 0), 112 | ("max_pool_3x3", 1), 113 | ("skip_connect", 2), 114 | ("max_pool_3x3", 1), 115 | ("max_pool_3x3", 0), 116 | ("skip_connect", 2), 117 | ("skip_connect", 2), 118 | ("max_pool_3x3", 1), 119 | ], 120 | reduce_concat=[2, 3, 4, 5], 121 | ) 122 | 123 | 124 | PC_DARTS_cifar = Genotype( 125 | normal=[ 126 | ("sep_conv_3x3", 1), 127 | ("skip_connect", 0), 128 | ("sep_conv_3x3", 0), 129 | ("dil_conv_3x3", 1), 130 | ("sep_conv_5x5", 0), 131 | ("sep_conv_3x3", 1), 132 | ("avg_pool_3x3", 0), 133 | ("dil_conv_3x3", 1), 134 | ], 135 | normal_concat=range(2, 6), 136 | reduce=[ 137 | ("sep_conv_5x5", 1), 138 | ("max_pool_3x3", 0), 139 | ("sep_conv_5x5", 1), 140 | ("sep_conv_5x5", 2), 141 | ("sep_conv_3x3", 0), 142 | ("sep_conv_3x3", 3), 143 | ("sep_conv_3x3", 1), 144 | ("sep_conv_3x3", 2), 145 | ], 146 | reduce_concat=range(2, 6), 147 | ) 148 | PC_DARTS_image = Genotype( 149 | normal=[ 150 | ("skip_connect", 1), 151 | ("sep_conv_3x3", 0), 152 | ("sep_conv_3x3", 0), 153 | ("skip_connect", 1), 154 | ("sep_conv_3x3", 1), 155 | ("sep_conv_3x3", 3), 156 | ("sep_conv_3x3", 1), 157 | ("dil_conv_5x5", 4), 158 | ], 159 | normal_concat=range(2, 6), 160 | reduce=[ 161 | ("sep_conv_3x3", 0), 162 | ("skip_connect", 1), 163 | ("dil_conv_5x5", 2), 164 | ("max_pool_3x3", 1), 165 | ("sep_conv_3x3", 2), 166 | ("sep_conv_3x3", 1), 167 | ("sep_conv_5x5", 0), 168 | ("sep_conv_3x3", 3), 169 | ], 170 | reduce_concat=range(2, 6), 171 | ) 172 | 173 | 174 | DrNAS_cifar10 = Genotype( 175 | normal=[ 176 | ("sep_conv_3x3", 0), 177 | ("sep_conv_5x5", 1), 178 | ("sep_conv_3x3", 1), 179 | ("sep_conv_3x3", 2), 180 | ("skip_connect", 0), 181 | ("sep_conv_3x3", 1), 182 | ("sep_conv_3x3", 2), 183 | ("dil_conv_5x5", 
def InChannelWider(module, new_channels, index=None):
    """Widen the input-channel dimension of a convolution in place.

    Extra input channels are created by copying existing ones, so the new
    weight has ``new_channels`` entries along dim 1.

    Args:
        module: Layer whose 4-D ``weight`` is widened (its ``weight`` and
            ``in_channels`` attributes are overwritten).
        new_channels: Number of input channels after widening.
        index: Indices (along dim 1) of the channels to duplicate. When
            ``None``, ``new_channels - in_channels`` indices are drawn
            uniformly at random.

    Returns:
        ``(module, index)`` - the mutated module and the indices that were used.
    """
    old_weight = module.weight
    n_old = old_weight.size(1)

    if index is None:
        index = torch.randint(low=0, high=n_old, size=(new_channels - n_old,))

    duplicated = old_weight[:, index, :, :].clone()
    widened = nn.Parameter(torch.cat([old_weight, duplicated], dim=1), requires_grad=True)

    # Bookkeeping attributes stored on the parameter itself (index used, layer
    # kind, identity of the very first parameter).
    widened.in_index = index
    widened.t = "conv"
    if hasattr(old_weight, "out_index"):
        # Keep the record of an earlier output-side widening.
        widened.out_index = old_weight.out_index
    widened.raw_id = getattr(old_weight, "raw_id", id(old_weight))

    module.weight = widened
    module.in_channels = new_channels
    return module, index
def BNWider(module, new_features, index=None):
    """Widen a batch-norm layer in place by duplicating existing features.

    Running statistics and, for affine layers, ``weight``/``bias`` are extended
    with copies of the entries selected by ``index``.

    Args:
        module: Batch-norm layer to widen; mutated in place.
        new_features: Number of features after widening.
        index: Indices of the features to duplicate. When ``None``,
            ``new_features - num_features`` indices are drawn uniformly at
            random.

    Returns:
        ``(module, index)`` - the mutated module and the indices that were used.
    """
    num_features = module.num_features

    if index is None:
        index = torch.randint(
            low=0, high=num_features, size=(new_features - num_features,)
        )

    # The running buffers are None when the layer was created with
    # track_running_stats=False; indexing them unconditionally would crash.
    if module.running_mean is not None:
        module.running_mean = torch.cat(
            [module.running_mean, module.running_mean[index].clone()]
        )
    if module.running_var is not None:
        module.running_var = torch.cat(
            [module.running_var, module.running_var[index].clone()]
        )

    if module.affine:
        # weight and bias get identical treatment; the tagging must stay inside
        # this branch because both are None on non-affine layers.
        for name in ("weight", "bias"):
            old_param = getattr(module, name)
            new_param = nn.Parameter(
                torch.cat([old_param, old_param[index].clone()], dim=0),
                requires_grad=True,
            )
            new_param.out_index = index
            new_param.t = "bn"
            new_param.raw_id = getattr(old_param, "raw_id", id(old_param))
            setattr(module, name, new_param)

    module.num_features = new_features
    return module, index
def configure_scheduler(scheduler_old, scheduler_new):
    """Copy the state of ``scheduler_old`` into ``scheduler_new``.

    Args:
        scheduler_old: Object providing ``state_dict()``.
        scheduler_new: Object providing ``load_state_dict()``; updated in place.

    Returns:
        ``scheduler_new`` after loading the old state.
    """
    carried_state = scheduler_old.state_dict()
    scheduler_new.load_state_dict(carried_state)
    print("scheduler loaded")
    return scheduler_new
class FactorizedReduce(nn.Module):
    """Halve the spatial resolution with two offset stride-2 1x1 convolutions.

    Each convolution produces ``C_out // 2`` channels; the second one sees the
    input shifted by one pixel in both spatial dimensions. The two results are
    concatenated along the channel dimension and batch-normalised.
    """

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        assert C_out % 2 == 0
        half_channels = C_out // 2
        self.relu = nn.ReLU(inplace=False)
        self.conv_1 = nn.Conv2d(C_in, half_channels, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, half_channels, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        activated = self.relu(x)
        aligned_branch = self.conv_1(activated)
        # Drop the first row and column so this branch samples the pixels the
        # aligned branch skips.
        shifted_branch = self.conv_2(activated[:, :, 1:, 1:])
        merged = torch.cat([aligned_branch, shifted_branch], dim=1)
        return self.bn(merged)
class DilConv(DARTSAbstractPrimitive):
    """Dilated depthwise-separable convolution: ReLU, depthwise conv, 1x1 conv, BN.

    The depthwise convolution uses dilation 2 and padding ``2 * (kernel_size // 2)``;
    input and output both have ``C`` channels.
    """

    def __init__(self, kernel_size: int, C: int, stride: int, affine: bool = True):
        super().__init__(C, stride, affine)
        self.kernel_size = kernel_size

        dilation = 2
        depthwise = nn.Conv2d(
            C,
            C,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation * (kernel_size // 2),
            dilation=dilation,
            groups=C,
            bias=False,
        )
        pointwise = nn.Conv2d(C, C, kernel_size=1, padding=0, bias=False)
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            depthwise,
            pointwise,
            nn.BatchNorm2d(C, affine=affine),
        )

    def forward(self, x):
        return self.op(x)

    @property
    def get_op_name(self):
        size = self.kernel_size
        return f"dil_conv_{size}x{size}"
""" 156 | Implementation of the channel-wise concatination. 157 | """ 158 | 159 | def __init__(self): 160 | super().__init__(locals()) 161 | 162 | def forward(self, x): # pylint: disable=no-self-use 163 | """ 164 | Expecting a list of input tensors. Stacking them channel-wise. 165 | """ 166 | x = torch.cat(x, dim=1) 167 | return x 168 | 169 | 170 | class Unbinder(AbstractPrimitive): 171 | def __init__(self, idx): 172 | super().__init__(locals()) 173 | self.idx = idx 174 | 175 | def forward(self, x): 176 | return torch.unbind(x, dim=0)[self.idx] 177 | 178 | 179 | class Stacking(AbstractPrimitive): 180 | def __init__(self, **kwargs): # pylint: disable=W0613 181 | super().__init__(locals()) 182 | 183 | def forward(self, x): # pylint: disable=no-self-use 184 | return torch.stack(x, dim=0) 185 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/darts_cnn/topologies.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from neps.search_spaces.graph_grammar.topologies import AbstractTopology 3 | 4 | from benchmarks.search_spaces.darts_cnn.primitives import ( 5 | Concat, 6 | SkipConnect, 7 | Stacking, 8 | Unbinder, 9 | ) 10 | 11 | 12 | class DARTSCell(AbstractTopology): 13 | def __init__(self, *edge_vals): 14 | super().__init__() 15 | self.name = "darts_cell" 16 | 17 | in_nodes = [int(val) for idx, val in enumerate(edge_vals) if idx % 2 == 1] 18 | self.edge_list = [ 19 | (in_nodes[0], 2), 20 | (in_nodes[1], 2), 21 | (in_nodes[2], 3), 22 | (in_nodes[3], 3), 23 | (in_nodes[4], 4), 24 | (in_nodes[5], 4), 25 | (in_nodes[6], 5), 26 | (in_nodes[7], 5), 27 | (2, 6), 28 | (3, 6), 29 | (4, 6), 30 | (5, 6), 31 | ] 32 | edge_list = [(-1, 0), (-1, 1)] + self.edge_list # -1 is an helper input node 33 | 34 | self.edge_vals = list(edge_vals) 35 | op_on_edge = [{"op": Unbinder, "idx": 0}, {"op": Unbinder, "idx": 1}] 36 | op_on_edge += [val for idx, val in 
enumerate(edge_vals) if idx % 2 == 0] 37 | op_on_edge += [{"op": SkipConnect, "C": None, "stride": 1} for _ in range(4)] 38 | 39 | self.create_graph(dict(zip(edge_list, op_on_edge))) 40 | 41 | # Assign dummy variables as node attributes: 42 | for i in self.nodes: 43 | self.nodes[i]["op_name"] = "1" 44 | 45 | self.nodes[-1].update({"comb_op": Stacking()}) 46 | for idx in range(6): 47 | self.nodes[idx].update({"comb_op": sum}) 48 | self.nodes[6].update({"comb_op": Concat()}) 49 | 50 | self.graph_type = "edge_attr" 51 | self.set_scope(self.name, recursively=False) 52 | 53 | def get_node_list_and_ops(self): 54 | ops = [val["op_name"] for idx, val in enumerate(self.edge_vals) if idx % 2 == 0] 55 | in_nodes = [val for idx, val in enumerate(self.edge_vals) if idx % 2 == 1] 56 | cell = list(zip(ops, in_nodes)) 57 | 58 | G = nx.DiGraph() 59 | n_nodes = (8 // 2) * 3 + 3 60 | G.add_nodes_from(range(n_nodes), op_name=None) 61 | n_ops = 8 // 2 62 | G.nodes[0]["op_name"] = "input1" 63 | G.nodes[1]["op_name"] = "input2" 64 | G.nodes[n_nodes - 1]["op_name"] = "output" 65 | for i in range(n_ops): 66 | G.nodes[i * 3 + 2]["op_name"] = cell[i * 2][0] 67 | G.nodes[i * 3 + 3]["op_name"] = cell[i * 2 + 1][0] 68 | G.nodes[i * 3 + 4]["op_name"] = "add" 69 | G.add_edge(i * 3 + 2, i * 3 + 4) 70 | G.add_edge(i * 3 + 3, i * 3 + 4) 71 | 72 | for i in range(n_ops): 73 | # Add the connections to the input 74 | for offset in range(2): 75 | if cell[i * 2 + offset][1] == 0: 76 | G.add_edge(0, i * 3 + 2 + offset) 77 | elif cell[i * 2 + offset][1] == 1: 78 | G.add_edge(1, i * 3 + 2 + offset) 79 | else: 80 | k = cell[i * 2 + offset][1] - 2 81 | # Add a connection from the output of another block 82 | G.add_edge(int(k) * 3 + 4, i * 3 + 2 + offset) 83 | # Add connections to the output 84 | for i in range(2, 6): 85 | if i <= 1: 86 | G.add_edge(i, n_nodes - 1) # Directly from either input to the output 87 | else: 88 | op_number = i - 2 89 | G.add_edge(op_number * 3 + 4, n_nodes - 1) 90 | # Remove the skip 
link nodes, do another sweep of the graph 91 | for j in range(n_nodes): 92 | try: 93 | G.nodes[j] 94 | except KeyError: 95 | continue 96 | if G.nodes[j]["op_name"] == "skip_connect": 97 | in_edges = list(G.in_edges(j)) 98 | out_edge = list(G.out_edges(j))[0][ 99 | 1 100 | ] # There should be only one out edge really... 101 | for in_edge in in_edges: 102 | G.add_edge(in_edge[0], out_edge) 103 | G.remove_node(j) 104 | elif G.nodes[j]["op_name"] == "none": 105 | G.remove_node(j) 106 | for j in range(n_nodes): 107 | try: 108 | G.nodes[j] 109 | except KeyError: 110 | continue 111 | 112 | if G.nodes[j]["op_name"] not in ["input1", "input2"]: 113 | # excepting the input nodes, if the node has no incoming edge, remove it 114 | if len(list(G.in_edges(j))) == 0: 115 | G.remove_node(j) 116 | elif G.nodes[j]["op_name"] != "output": 117 | # excepting the output nodes, if the node has no outgoing edge, remove it 118 | if len(list(G.out_edges(j))) == 0: 119 | G.remove_node(j) 120 | elif ( 121 | G.nodes[j]["op_name"] == "add" 122 | ): # If add has one incoming edge only, remove the node 123 | in_edges = list(G.in_edges(j)) 124 | out_edges = list(G.out_edges(j)) 125 | if len(in_edges) == 1 and len(out_edges) == 1: 126 | G.add_edge(in_edges[0][0], out_edges[0][1]) 127 | G.remove_node(j) 128 | 129 | return G 130 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/darts_cnn/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision.transforms as transforms 8 | from torch.autograd import Variable 9 | 10 | TORCH_VERSION = torch.__version__ 11 | 12 | 13 | class AvgrageMeter: 14 | def __init__(self): 15 | self.avg = None 16 | self.reset() 17 | 18 | def reset(self): 19 | self.avg = 0 20 | self.sum = 0 21 | self.cnt = 0 22 | 23 | def update(self, val, n=1): 24 | self.sum 
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracies (in percent) for a batch of predictions.

    Args:
        output: Score tensor; ``topk`` is taken along dim 1.
        target: Tensor of ground-truth class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one scalar tensor per entry of ``topk``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` comes from a transposed tensor, so its slices may be
        # non-contiguous; ``reshape`` handles that on every torch version,
        # whereas ``view`` raises.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class Cutout:
    """Zero out one random square patch of an image tensor.

    Args:
        length: Side length of the patch (clipped at the image border).
        prob: Probability of applying the cutout to a given image. Defaults to
            ``1.0`` (always apply), which is the original single-argument
            behaviour.
    """

    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, img):
        """Apply the cutout to ``img`` in place and return it.

        ``img.size(1)`` and ``img.size(2)`` are used as height and width.
        """
        # Only draw the extra random number when it can matter, so the default
        # configuration consumes exactly the same random stream as before.
        if self.prob < 1.0 and np.random.rand() >= self.prob:
            return img

        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1:y2, x1:x2] = 0.0
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img
def count_parameters_in_MB(model):
    """Return the number of parameters of ``model`` in millions.

    Parameters whose name contains ``"auxiliary"`` are not counted. Despite the
    function name, the value is a parameter count divided by ``1e6``, not a
    size in bytes.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        Parameter count divided by ``1e6``, as a float.
    """
    # Builtin sum: passing a generator to np.sum is deprecated in NumPy.
    total = sum(
        int(np.prod(v.size()))
        for name, v in model.named_parameters()
        if "auxiliary" not in name
    )
    return total / 1e6
def process_step_vector(x, method, mask, tau=None):
    """Turn one vector of logits into a normalised weight vector.

    Args:
        x: Tensor of logits.
        method: ``"softmax"``, ``"dirichlet"`` or ``"gumbel"``.
        mask: Optional boolean tensor shaped like ``x``; entries where it is
            false are set to zero and the rest renormalised to sum to one.
        tau: Temperature for ``"gumbel"``; ``None`` falls back to ``1.0``,
            the default of ``F.gumbel_softmax``.

    Returns:
        Tensor shaped like ``x``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "softmax":
        output = F.softmax(x, dim=-1)
    elif method == "dirichlet":
        output = torch.distributions.dirichlet.Dirichlet(F.elu(x) + 1).rsample()
    elif method == "gumbel":
        output = F.gumbel_softmax(
            x, tau=1.0 if tau is None else tau, hard=False, dim=-1
        )
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise ValueError(f"Unknown method {method!r} in process_step_vector")

    if mask is None:
        return output

    output_pruned = torch.zeros_like(output)
    output_pruned[mask] = output[mask]
    output_pruned /= output_pruned.sum()
    assert (output_pruned[~mask] == 0.0).all()
    return output_pruned


def process_step_matrix(x, method, mask, tau=None):
    """Apply ``process_step_vector`` to every row of ``x`` and stack the results.

    Args:
        x: Iterable of logit vectors (iterated along the first dimension).
        method: Forwarded to ``process_step_vector``.
        mask: ``None`` or an indexable of per-row masks (``mask[i]`` for row ``i``).
        tau: Forwarded to ``process_step_vector``.

    Returns:
        The per-row results stacked with ``torch.stack``.
    """
    weights = [
        process_step_vector(line, method, None if mask is None else mask[i], tau)
        for i, line in enumerate(x)
    ]
    return torch.stack(weights)
def plot(genotype, filename):
    """Render a cell genotype as a PDF graph and open it in the default viewer.

    Args:
        genotype: Sequence of ``(op_name, input_index)`` pairs, two consecutive
            pairs per intermediate node. Input index ``0`` / ``1`` refers to the
            two cell inputs, ``j >= 2`` to intermediate node ``j - 2``.
        filename: Path of the Graphviz source file; the rendered output is
            written next to it by ``Digraph.render``.
    """
    g = Digraph(
        format="pdf",
        edge_attr=dict(fontsize="20", fontname="times"),
        node_attr=dict(
            style="filled",
            shape="rect",
            align="center",
            fontsize="20",
            height="0.5",
            width="0.5",
            penwidth="2",
            fontname="times",
        ),
        engine="dot",
    )
    g.body.extend(["rankdir=LR"])

    g.node("c_{k-2}", fillcolor="darkseagreen2")
    g.node("c_{k-1}", fillcolor="darkseagreen2")
    assert len(genotype) % 2 == 0
    steps = len(genotype) // 2

    for i in range(steps):
        g.node(str(i), fillcolor="lightblue")

    for i in range(steps):
        for k in [2 * i, 2 * i + 1]:
            op, j = genotype[k]
            if j == 0:
                u = "c_{k-2}"
            elif j == 1:
                u = "c_{k-1}"
            else:
                u = str(j - 2)
            v = str(i)
            g.edge(u, v, label=op, fillcolor="gray")

    g.node("c_{k}", fillcolor="palegoldenrod")
    for i in range(steps):
        g.edge(str(i), "c_{k}", fillcolor="gray")

    # Let graphviz delete the intermediate source file itself. The previous
    # manual os.remove() looked for it relative to this module's directory,
    # which only matches where render() wrote it when that is also the working
    # directory (or when ``filename`` is absolute).
    g.render(filename, view=True, cleanup=True)
graph.nodes(data=True): 74 | op_name = data["op_name"] 75 | if "input" in op_name: 76 | g.node(str(n), op_name, fillcolor="darkseagreen2") 77 | elif "output" == op_name: 78 | g.node(str(n), op_name, fillcolor="palegoldenrod") 79 | elif op_name == "add": 80 | g.node(str(n), op_name, fillcolor="gray", shape="ellipse") 81 | else: 82 | g.node(str(n), op_name, fillcolor="lightblue", shape="ellipse") 83 | for u, v in graph.edges: 84 | g.edge(str(u), str(v), fillcolor="gray") 85 | else: 86 | raise NotImplementedError 87 | 88 | g.render(filename, view=True) 89 | dir_path = Path(os.path.dirname(os.path.realpath(__file__))) 90 | os.remove(dir_path / filename) 91 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/hierarchical_nb201/grammars/cell.cfg: -------------------------------------------------------------------------------- 1 | CELL -> "Cell" OPS OPS OPS OPS OPS OPS 2 | OPS -> "id" | "zero" | "conv3x3" | "conv1x1" | "avg_pool" 3 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/hierarchical_nb201/grammars/cell_flexible.cfg: -------------------------------------------------------------------------------- 1 | CELL -> "Cell" OPS OPS OPS OPS OPS OPS 2 | OPS -> "id" | "zero" | "Linear1" CONVBLOCK | "Linear1" CONVBLOCK | "avg_pool" 3 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/hierarchical_nb201/grammars/conv_block.cfg: -------------------------------------------------------------------------------- 1 | CONVBLOCK -> "Linear3" ACT CONV NORM 2 | CONV -> "conv3x3o" | "conv1x1o" | "dconv3x3o" 3 | NORM -> "batch" | "instance" | "layer" 4 | ACT -> "relu" | "hardswish" | "mish" 5 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/hierarchical_nb201/grammars/macro.cfg: -------------------------------------------------------------------------------- 
1 | D2 -> "Linear3" D1 D1 D0 | "Linear3" D0 D1 D1 | "Linear4" D1 D1 D0 D0 2 | D1 -> "Linear3" C C DOWN | "Linear4" C C C DOWN | "Residual3" C C DOWN DOWN 3 | D0 -> "Linear3" C C CELL | "Linear4" C C C CELL | "Residual3" C C CELL CELL 4 | DOWN -> "Linear2" CELL "resBlock" | "Linear3" CELL CELL "resBlock" | "Residual2" CELL "resBlock" "resBlock" 5 | C -> "Linear2" CELL CELL | "Linear3" CELL CELL CELL | "Residual2" CELL CELL CELL 6 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/hierarchical_nb201/grammars/macro_fixed_repetitive.cfg: -------------------------------------------------------------------------------- 1 | D2 -> "Linear3" D1 D1 D0 2 | D1 -> "Linear3" C C DOWN 3 | D0 -> "Linear3" C C "SharedCell" 4 | DOWN -> "Linear2" "SharedCell" "resBlock" 5 | C -> "Linear2" "SharedCell" "SharedCell" 6 | -------------------------------------------------------------------------------- /benchmarks/search_spaces/hierarchical_nb201/primitives.py: -------------------------------------------------------------------------------- 1 | from neps.search_spaces.graph_grammar.primitives import AbstractPrimitive 2 | from torch import nn 3 | 4 | 5 | class ResNetBasicblock(AbstractPrimitive): 6 | def __init__( 7 | self, 8 | C_in: int, 9 | C_out: int, 10 | stride: int, 11 | affine: bool = True, 12 | track_running_stats: bool = True, 13 | ): 14 | super().__init__(locals()) 15 | assert stride == 1 or stride == 2, f"invalid stride {stride}" 16 | self.conv_a = ReLUConvBN( 17 | C_in, C_out, 3, stride, 1, 1, affine, track_running_stats 18 | ) 19 | self.conv_b = ReLUConvBN(C_out, C_out, 3, 1, 1, 1, affine, track_running_stats) 20 | if stride == 2: 21 | self.downsample = nn.Sequential( 22 | nn.AvgPool2d(kernel_size=2, stride=2, padding=0), 23 | nn.Conv2d(C_in, C_out, kernel_size=1, stride=1, padding=0, bias=False), 24 | ) 25 | elif C_in != C_out: 26 | self.downsample = ReLUConvBN( 27 | C_in, C_out, 1, 1, 0, 1, affine, track_running_stats 
class ReLUConvBN(AbstractPrimitive):
    """ReLU -> Conv2d -> BatchNorm2d block.

    The convolution carries a bias only when the batch norm is non-affine
    (``bias=not affine``).
    """

    def __init__(
        self,
        C_in,
        C_out,
        kernel_size,
        stride,
        padding,
        dilation,
        affine,
        track_running_stats=True,
    ):
        # locals() is captured before kernel_size/stride are cast below, so the
        # base class receives the arguments exactly as passed by the caller.
        # NOTE(review): presumably AbstractPrimitive records them - confirm.
        super().__init__(locals())
        # Cast to int before handing the values to nn.Conv2d.
        kernel_size = int(kernel_size)
        stride = int(stride)

        # Kept for get_op_name.
        self.kernel_size = kernel_size
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(
                C_in,
                C_out,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=not affine,
            ),
            nn.BatchNorm2d(C_out, affine=affine, track_running_stats=track_running_stats),
        )

    def forward(self, x):
        """Apply the ReLU/conv/BN sequence to ``x``."""
        return self.op(x)

    @property
    def get_op_name(self):
        """Base-class op name with ``"<k>x<k>"`` (the kernel size) appended."""
        op_name = super().get_op_name
        op_name += f"{self.kernel_size}x{self.kernel_size}"
        return op_name
class Normalization(AbstractPrimitive):
    """Normalisation layer selected by name.

    Args:
        C_out: Number of channels passed to the batch/instance norm.
        norm_type: ``"batch_norm"``, ``"layer_norm"`` or ``"instance_norm"``.
        affine: Whether the layer has learnable scale and shift.
        **kwargs: Accepted and ignored.

    Raises:
        NotImplementedError: If ``norm_type`` is not one of the names above.
    """

    def __init__(
        self,
        C_out: int,
        norm_type: str,
        affine: bool = True,
        **kwargs,  # pylint: disable=W0613
    ):
        super().__init__(locals())
        self.norm_type = norm_type
        self.affine = affine
        if norm_type == "batch_norm":
            self.norm = nn.BatchNorm2d(C_out, affine=affine)
        elif norm_type == "layer_norm":
            # Built lazily in forward(): the normalised shape is taken from the
            # first input.
            self.norm = None
        elif norm_type == "instance_norm":
            self.norm = nn.InstanceNorm2d(C_out, affine=affine)
        else:
            raise NotImplementedError(f"Unsupported norm_type={norm_type!r}")

    def forward(self, x):
        if self.norm_type == "layer_norm" and self.norm is None:
            # Place the new layer on the input's own device. The previous
            # ``.cuda()`` moved it to the *current* CUDA device, which can
            # differ from ``x.device`` when several GPUs are in use.
            # NOTE(review): parameters created here do not exist before the
            # first forward pass, so an optimizer built earlier will not
            # contain them - confirm callers run a forward pass first.
            self.norm = nn.LayerNorm(x.shape[1:], elementwise_affine=self.affine).to(
                x.device
            )
        return self.norm(x)
class ObjectiveWithAPI(Objective):
    """Objective that scores a configuration by querying an API object.

    The ``api`` object is expected to provide ``eval``, ``transform`` and
    ``test`` callables, which are the only members used here.
    """

    def __init__(self, api) -> None:
        super().__init__(None, None, None)
        self.api = api

    def __call__(self, config):
        """Evaluate ``config`` and return its loss plus an info dictionary.

        The wall-clock duration of the ``api.eval`` call is reported as
        ``train_time``.
        """
        model = config.get_model_for_evaluation()

        started = time.time()
        val_score = self.api.eval(model)
        elapsed = time.time() - started

        transformed = self.api.transform(val_score)
        info = {
            "config_id": config.id,
            "val_score": val_score,
            "test_score": self.api.test(model),
            "train_time": elapsed,
        }
        return {"loss": transformed, "info_dict": info}
def set_seed(seed):
    """Seed NumPy, ``random`` and torch, and pin cuDNN flags when CUDA is available."""
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.enabled = True
    cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)
def _top1_accuracy(segment):
    """Return the value of the last ``top1=<value>%`` field in ``segment``."""
    return float(segment.rsplit("top1=", 1)[1].split("%", 1)[0])


def distill(result):
    """Extract top-1 accuracies from a multi-line benchmark report string.

    Rows 5, 7 and 9 of ``result`` (split on newlines) are read as the
    CIFAR-10, CIFAR-100 and ImageNet16 rows respectively. Spaces are removed
    and each row is split on ``":"``, so every piece after the first holds
    one ``[loss=...,top1=...%]`` group.

    The accuracy is parsed from the ``top1=`` field itself rather than from a
    fixed-width slice, so values of any width (for example ``100.00``) are
    read correctly.

    Returns:
        A tuple ``(cifar10_train, cifar10_test, cifar100_train,
        cifar100_valid, cifar100_test, imagenet16_train, imagenet16_valid,
        imagenet16_test)`` of floats.
    """
    rows = result.split("\n")
    cifar10 = rows[5].replace(" ", "").split(":")
    cifar100 = rows[7].replace(" ", "").split(":")
    imagenet16 = rows[9].replace(" ", "").split(":")

    cifar10_train = _top1_accuracy(cifar10[1])
    cifar10_test = _top1_accuracy(cifar10[2])
    cifar100_train = _top1_accuracy(cifar100[1])
    cifar100_valid = _top1_accuracy(cifar100[2])
    cifar100_test = _top1_accuracy(cifar100[3])
    imagenet16_train = _top1_accuracy(imagenet16[1])
    imagenet16_valid = _top1_accuracy(imagenet16[2])
    imagenet16_test = _top1_accuracy(imagenet16[3])

    return (
        cifar10_train,
        cifar10_test,
        cifar100_train,
        cifar100_valid,
        cifar100_test,
        imagenet16_train,
        imagenet16_valid,
        imagenet16_test,
    )
args.working_directory.split("/") 90 | search_space = splits[-4] 91 | dataset = splits[-3] 92 | method = splits[-2] 93 | args.seed = int(splits[-1]) 94 | 95 | genotype = get_genotype(args.working_directory) 96 | 97 | run_pipeline_fn = ObjectiveMapping[dataset]( 98 | data_path=args.data_path, seed=args.seed, eval_mode=True 99 | ) 100 | 101 | if hasattr(run_pipeline_fn, "set_seed"): 102 | run_pipeline_fn.set_seed(args.seed) 103 | np.random.seed(args.seed) 104 | random.seed(args.seed) 105 | torch.manual_seed(args.seed) 106 | if torch.cuda.is_available(): 107 | torch.backends.cudnn.benchmark = False 108 | torch.backends.cudnn.enabled = True 109 | torch.backends.cudnn.deterministic = True 110 | torch.cuda.manual_seed_all(args.seed) 111 | 112 | search_space = NB201Spaces( 113 | space=search_space[6:-4], dataset=dataset[6:], adjust_params=False 114 | ) 115 | 116 | if not isinstance(search_space, dict) and not isinstance(search_space, SearchSpace): 117 | search_space = {"architecture": search_space} 118 | 119 | # read in best config 120 | best_config = SearchSpace(**search_space) 121 | identifier = genotype_to_identifier(genotype) 122 | best_config.load_from({"architecture": identifier}) 123 | model = best_config["architecture"].to_pytorch() 124 | 125 | # evaluate 126 | results = run_pipeline_fn("", "", architecture=model) 127 | 128 | if args.objective in ["cifar10", "cifar100", "ImageNet16-120"] and args.api_path: 129 | api = NASBench201API(args.api_path) 130 | result = api.query_by_arch(genotype, hp="200") 131 | print(result) 132 | ( 133 | cifar10_train, 134 | cifar10_test, 135 | cifar100_train, 136 | cifar100_valid, 137 | cifar100_test, 138 | imagenet16_train, 139 | imagenet16_valid, 140 | imagenet16_test, 141 | ) = distill(result) 142 | results["api"] = { 143 | "cifar10-test": cifar10_test, 144 | "cifar100-valid": cifar100_valid, 145 | "cifar100-test": cifar100_test, 146 | "ImageNet16-120-valid": imagenet16_valid, 147 | "ImageNet16-120-test": imagenet16_test, 148 | } 
def _concat(xs):
    """Flatten every tensor in ``xs`` and concatenate them into one 1-D tensor."""
    return torch.cat([x.view(-1) for x in xs])


class Architect:
    """Updates a model's architecture parameters with an Adam optimizer.

    The wrapped ``model`` must provide ``arch_parameters()`` and
    ``_loss(input, target)``. For the unrolled update it must also provide
    ``new()``, which returns a fresh model that accepts the same state dict.
    """

    def __init__(self, model, args):
        """Build the architecture optimizer.

        Args:
            model: network whose ``arch_parameters()`` are optimized.
            args: namespace providing ``momentum``, ``weight_decay``,
                ``reg_type`` (``"l2"`` or ``"kl"``), ``reg_scale`` and
                ``arch_learning_rate``.

        Raises:
            ValueError: if ``args.reg_type`` is neither ``"l2"`` nor ``"kl"``.
        """
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        if args.reg_type == "l2":
            weight_decay = args.reg_scale
        elif args.reg_type == "kl":
            # No Adam weight decay is applied for the "kl" setting.
            weight_decay = 0
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``weight_decay``.
            raise ValueError(
                f"Unsupported reg_type {args.reg_type!r}; expected 'l2' or 'kl'."
            )
        self.optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=weight_decay,
        )

    def _compute_unrolled_model(self, input, target, eta, network_optimizer):
        """Return a copy of the model after one virtual SGD step of size ``eta``.

        The step uses the training-loss gradient, the weight-decay term and,
        when available, the momentum buffers stored in ``network_optimizer``.
        """
        loss = self.model._loss(input, target)
        theta = _concat(self.model.parameters()).data
        try:
            moment = _concat(
                network_optimizer.state[v]["momentum_buffer"]
                for v in self.model.parameters()
            ).mul_(self.network_momentum)
        except (KeyError, AttributeError):
            # No usable momentum buffer yet (missing before the first
            # optimizer step, or stored as None): use zero momentum.
            moment = torch.zeros_like(theta)
        dtheta = (
            _concat(torch.autograd.grad(loss, list(self.model.parameters()))).data
            + self.network_weight_decay * theta
        )
        # Keyword ``alpha`` replaces the deprecated ``sub(alpha, other)`` overload.
        unrolled_model = self._construct_model_from_theta(
            theta.sub(moment + dtheta, alpha=eta)
        )
        return unrolled_model

    def step(
        self,
        input_train,
        target_train,
        input_valid,
        target_valid,
        eta,
        network_optimizer,
        unrolled,
    ):
        """Perform one architecture-parameter update.

        With ``unrolled`` true the gradient comes from
        ``_backward_step_unrolled``; otherwise it is the plain gradient of
        the validation loss.
        """
        self.optimizer.zero_grad()
        if unrolled:
            self._backward_step_unrolled(
                input_train,
                target_train,
                input_valid,
                target_valid,
                eta,
                network_optimizer,
            )
        else:
            self._backward_step(input_valid, target_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        """Backpropagate the validation loss through the current model."""
        loss = self.model._loss(input_valid, target_valid)
        loss.backward()

    def _backward_step_unrolled(
        self, input_train, target_train, input_valid, target_valid, eta, network_optimizer
    ):
        """Write the unrolled architecture gradient into ``arch_parameters().grad``."""
        unrolled_model = self._compute_unrolled_model(
            input_train, target_train, eta, network_optimizer
        )
        unrolled_loss = unrolled_model._loss(input_valid, target_valid)

        unrolled_loss.backward()
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        vector = [v.grad.data for v in unrolled_model.parameters()]
        implicit_grads = self._hessian_vector_product(vector, input_train, target_train)

        # Subtract the finite-difference term scaled by the step size.
        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta)

        # Transfer the gradients from the unrolled copy to the real model.
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = g.data
            else:
                v.grad.data.copy_(g.data)

    def _construct_model_from_theta(self, theta):
        """Create a new model whose parameters are read from the flat vector ``theta``.

        ``theta`` is consumed in ``named_parameters()`` order. The new model
        is moved to the device ``theta`` lives on.
        """
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            # numel() is an int for every shape, including 0-dim parameters.
            v_length = v.numel()
            params[k] = theta[offset : offset + v_length].view(v.size())
            offset += v_length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.to(theta.device)

    def _hessian_vector_product(self, vector, input, target, r=1e-2):
        """Central finite-difference estimate used by the unrolled update.

        The model parameters are shifted by ``+R * vector`` and
        ``-R * vector`` with ``R = r / ||vector||``. The architecture
        gradients of the loss at both points are differenced and divided by
        ``2 * R``. The parameters are shifted back afterwards.
        """
        R = r / _concat(vector).norm()
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)
        loss = self.model._loss(input, target)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(v, alpha=2 * R)
        loss = self.model._loss(input, target)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        # Undo the perturbation so the weights end where they started.
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
def BNWider(module, new_features, index=None):
    """Widen a batch-norm module in place to ``new_features`` channels.

    The extra channels are copies of existing channels selected by ``index``.
    When ``index`` is ``None`` it is drawn uniformly at random from the
    current channels.

    Running statistics are extended when the module tracks them. For affine
    modules, ``weight`` and ``bias`` are replaced by widened parameters
    tagged with ``out_index``, ``t = 'bn'`` and ``raw_id`` (the ``raw_id``
    of the parameter they replace if it has one, otherwise its ``id``).
    Non-affine modules and modules without running statistics are supported;
    the corresponding tensors are skipped.

    Args:
        module: batch-norm module to widen.
        new_features: number of channels after widening.
        index: optional index tensor of channels to duplicate.

    Returns:
        The tuple ``(module, index)``.
    """
    running_mean = module.running_mean
    running_var = module.running_var
    num_features = module.num_features

    if index is None:
        index = torch.randint(low=0, high=num_features,
                              size=(new_features - num_features,))

    # Running statistics are None when track_running_stats is disabled.
    if running_mean is not None:
        module.running_mean = torch.cat([running_mean, running_mean[index].clone()])
    if running_var is not None:
        module.running_var = torch.cat([running_var, running_var[index].clone()])

    if module.affine:
        weight = module.weight
        bias = module.bias
        module.weight = nn.Parameter(
            torch.cat([weight, weight[index].clone()], dim=0), requires_grad=True)
        module.bias = nn.Parameter(
            torch.cat([bias, bias[index].clone()], dim=0), requires_grad=True)

        # Tagging only applies to affine modules: without affine parameters
        # module.weight / module.bias are None and there is nothing to tag.
        for widened, previous in ((module.weight, weight), (module.bias, bias)):
            widened.out_index = index
            widened.t = 'bn'
            widened.raw_id = previous.raw_id if hasattr(
                previous, 'raw_id') else id(previous)

    module.num_features = new_features
    return module, index
def configure_scheduler(scheduler_old, scheduler_new):
    """Copy the state of ``scheduler_old`` into ``scheduler_new`` and return the latter."""
    previous_state = scheduler_old.state_dict()
    scheduler_new.load_state_dict(previous_state)
    print('scheduler loaded')
    return scheduler_new
for op_name in op_names] 25 | self.edges[ node_str ] = nn.ModuleList( xlists ) 26 | self.edge_keys = sorted(list(self.edges.keys())) 27 | self.edge2index = {key:i for i, key in enumerate(self.edge_keys)} 28 | self.num_edges = len(self.edges) 29 | 30 | def extra_repr(self): 31 | string = 'info :: {max_nodes} nodes, inC={in_dim}, outC={out_dim}'.format(**self.__dict__) 32 | return string 33 | 34 | def forward(self, inputs, weightss): 35 | nodes = [inputs] 36 | for i in range(1, self.max_nodes): 37 | inter_nodes = [] 38 | for j in range(i): 39 | node_str = '{:}<-{:}'.format(i, j) 40 | weights = weightss[ self.edge2index[node_str] ] 41 | inter_nodes.append( sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) ) 42 | nodes.append( sum(inter_nodes) ) 43 | return nodes[-1] 44 | 45 | # GDAS 46 | def forward_gdas(self, inputs, hardwts, index): 47 | nodes = [inputs] 48 | for i in range(1, self.max_nodes): 49 | inter_nodes = [] 50 | for j in range(i): 51 | node_str = '{:}<-{:}'.format(i, j) 52 | weights = hardwts[ self.edge2index[node_str] ] 53 | argmaxs = index[ self.edge2index[node_str] ].item() 54 | weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) ) 55 | inter_nodes.append( weigsum ) 56 | nodes.append( sum(inter_nodes) ) 57 | return nodes[-1] 58 | 59 | # joint 60 | def forward_joint(self, inputs, weightss): 61 | nodes = [inputs] 62 | for i in range(1, self.max_nodes): 63 | inter_nodes = [] 64 | for j in range(i): 65 | node_str = '{:}<-{:}'.format(i, j) 66 | weights = weightss[ self.edge2index[node_str] ] 67 | #aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) / weights.numel() 68 | aggregation = sum( layer(nodes[j]) * w for layer, w in zip(self.edges[node_str], weights) ) 69 | inter_nodes.append( aggregation ) 70 | nodes.append( sum(inter_nodes) ) 71 | return nodes[-1] 72 | 73 | # uniform random sampling per iteration, SETN 74 
def channel_shuffle(x, groups):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis (dimension 1) is viewed as ``(groups, channels // groups)``,
    the two axes are swapped, and the result is flattened back to one channel
    axis.
    """
    n, channels, h, w = x.data.size()
    per_group = channels // groups

    grouped = x.view(n, groups, per_group, h, w)
    interleaved = torch.transpose(grouped, 1, 2).contiguous()
    return interleaved.view(n, -1, h, w)
class TinyNetwork(nn.Module):
    """Search network: a stem, three stages of search cells and a linear classifier.

    The stages use ``C``, ``2C`` and ``4C`` channels with ``N`` search cells
    each. A stride-2 ``ResNetBasicblock`` sits between consecutive stages.
    All search cells share one architecture-parameter matrix of shape
    ``(num_edges, len(search_space))``.
    """

    def __init__(self, C, N, max_nodes, num_classes, criterion, search_space, affine=False,
                 track_running_stats=True, k=2, species='softmax', reg_type='l2', reg_scale=1e-3):
        super(TinyNetwork, self).__init__()
        self._C = C
        self._layerN = N
        self.max_nodes = max_nodes
        self._criterion = criterion
        self.k = k
        self.species = species
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C))

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev, num_edge, edge2index = C, None, None
        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = SearchCell(C_prev, C_curr, 1, max_nodes, search_space, affine, track_running_stats, k)
                # Every search cell must expose the same edge layout because
                # they all index into one shared architecture matrix.
                if num_edge is None:
                    num_edge, edge2index = cell.num_edges, cell.edge2index
                else:
                    assert num_edge == cell.num_edges and edge2index == cell.edge2index, 'invalid {:} vs. {:}.'.format(num_edge, cell.num_edges)
            self.cells.append(cell)
            C_prev = cell.out_dim
        self.op_names = deepcopy(search_space)
        self._Layer = len(self.cells)
        self.edge2index = edge2index
        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)
        self._arch_parameters = nn.Parameter(1e-3 * torch.randn(num_edge, len(search_space)))
        self.tau = 10 if species == 'gumbel' else None
        self._mask = None

        #### reg
        self.reg_type = reg_type
        self.reg_scale = reg_scale
        # The anchor is placed on the GPU when one is available (as before).
        # On CPU-only machines it stays on the CPU, so constructing the
        # network no longer fails there.
        anchor_concentration = torch.ones_like(self._arch_parameters)
        if torch.cuda.is_available():
            anchor_concentration = anchor_concentration.cuda()
        self.anchor = Dirichlet(anchor_concentration)

    def _loss(self, input, target):
        """Return the criterion loss, plus the KL regularizer when ``reg_type == 'kl'``."""
        logits = self(input)
        loss = self._criterion(logits, target)
        if self.reg_type == 'kl':
            loss += self._get_kl_reg()
        return loss

    def _get_kl_reg(self):
        """Return ``reg_scale * sum(KL(Dirichlet(elu(alpha) + 1) || anchor))``."""
        assert (self.species == 'dirichlet')  # kl implemented only for Dirichlet
        cons = (F.elu(self._arch_parameters) + 1)
        q = Dirichlet(cons)
        p = self.anchor
        kl_reg = self.reg_scale * torch.sum(kl_divergence(q, p))
        return kl_reg

    def get_weights(self):
        """Return the parameters of stem, cells, final activation, pooling and classifier."""
        xlist = list(self.stem.parameters()) + list(self.cells.parameters())
        xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters())
        xlist += list(self.classifier.parameters())
        return xlist

    def set_tau(self, tau):
        self.tau = tau

    def get_tau(self):
        return self.tau

    def arch_parameters(self):
        """Return the architecture-parameter matrix wrapped in a list."""
        return [self._arch_parameters]

    def show_arch_parameters(self):
        """Log the processed architecture parameters (and concentrations for 'dirichlet')."""
        with torch.no_grad():
            logging.info('arch-parameters :\n{:}'.format(process_step_matrix(self._arch_parameters, 'softmax', self._mask).cpu()))
            if self.species == 'dirichlet':
                logging.info('concentration :\n{:}'.format((F.elu(self._arch_parameters) + 1).cpu()))

    def get_message(self):
        """Return ``extra_repr()`` followed by one line per cell."""
        string = self.extra_repr()
        for i, cell in enumerate(self.cells):
            string += '\n {:02d}/{:02d} :: {:}'.format(i, len(self.cells), cell.extra_repr())
        return string

    def extra_repr(self):
        return ('{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})'.format(name=self.__class__.__name__, **self.__dict__))

    def genotype(self):
        """Return a ``Structure`` holding the argmax operation of every edge."""
        genotypes = []
        # NOTE(review): process_step_matrix is defined elsewhere; it presumably
        # normalizes each row and applies the pruning mask. Confirm there.
        alphas = process_step_matrix(self._arch_parameters, 'softmax', self._mask)
        for i in range(1, self.max_nodes):
            xlist = []
            for j in range(i):
                node_str = '{:}<-{:}'.format(i, j)
                with torch.no_grad():
                    weights = alphas[self.edge2index[node_str]]
                    op_name = self.op_names[weights.argmax().item()]
                xlist.append((op_name, j))
            genotypes.append(tuple(xlist))
        return Structure(genotypes)

    def pruning(self, num_keep):
        """Update the mask via ``prune`` with ``num_keep`` and the previous mask."""
        self._mask = prune(self._arch_parameters, num_keep, self._mask)

    def forward(self, inputs):
        """Compute class logits for ``inputs``."""
        alphas = process_step_matrix(self._arch_parameters, self.species, self._mask, self.tau)

        feature = self.stem(inputs)
        for cell in self.cells:
            # Only search cells take the edge weights; reduction blocks do not.
            if isinstance(cell, SearchCell):
                feature = cell(feature, alphas)
            else:
                feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits

    def wider(self, k):
        """Set ``k`` on the network and call ``wider(k)`` on every search cell."""
        self.k = k
        for cell in self.cells:
            if isinstance(cell, SearchCell):
                cell.wider(k)
ResNetBasicblock 8 | from experiments.darts_utils.genotypes import Structure 9 | from experiments.darts_utils.search_cells import ( 10 | NAS201SearchCell as SearchCell, 11 | ) 12 | from experiments.darts_utils.utils import process_step_matrix 13 | 14 | 15 | class TinyNetworkGDAS(nn.Module): 16 | def __init__( 17 | self, 18 | C, 19 | N, 20 | max_nodes, 21 | num_classes, 22 | criterion, 23 | search_space, 24 | affine=False, 25 | track_running_stats=True, 26 | ): 27 | super().__init__() 28 | self._C = C 29 | self._layerN = N 30 | self.max_nodes = max_nodes 31 | self._criterion = criterion 32 | self.stem = nn.Sequential( 33 | nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(C) 34 | ) 35 | 36 | layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N 37 | layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N 38 | 39 | C_prev, num_edge, edge2index = C, None, None 40 | self.cells = nn.ModuleList() 41 | for index, (C_curr, reduction) in enumerate( 42 | zip(layer_channels, layer_reductions) 43 | ): 44 | if reduction: 45 | cell = ResNetBasicblock(C_prev, C_curr, 2) 46 | else: 47 | cell = SearchCell( 48 | C_prev, 49 | C_curr, 50 | 1, 51 | max_nodes, 52 | search_space, 53 | affine, 54 | track_running_stats, 55 | ) 56 | if num_edge is None: 57 | num_edge, edge2index = cell.num_edges, cell.edge2index 58 | else: 59 | assert ( 60 | num_edge == cell.num_edges and edge2index == cell.edge2index 61 | ), f"invalid {num_edge} vs. {cell.num_edges}." 
62 | self.cells.append(cell) 63 | C_prev = cell.out_dim 64 | self.op_names = deepcopy(search_space) 65 | self._Layer = len(self.cells) 66 | self.edge2index = edge2index 67 | self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True)) 68 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 69 | self.classifier = nn.Linear(C_prev, num_classes) 70 | self._arch_parameters = nn.Parameter( 71 | 1e-3 * torch.randn(num_edge, len(search_space)) 72 | ) 73 | self.tau = 10 74 | 75 | def _loss(self, input, target, updateType=None): 76 | logits = self(input, updateType) 77 | return self._criterion(logits, target) 78 | 79 | def get_weights(self): 80 | xlist = list(self.stem.parameters()) + list(self.cells.parameters()) 81 | xlist += list(self.lastact.parameters()) + list(self.global_pooling.parameters()) 82 | xlist += list(self.classifier.parameters()) 83 | return xlist 84 | 85 | def set_tau(self, tau): 86 | self.tau = tau 87 | 88 | def get_tau(self): 89 | return self.tau 90 | 91 | def arch_parameters(self): 92 | return [self._arch_parameters] 93 | 94 | def show_arch_parameters(self): 95 | with torch.no_grad(): 96 | logging.info( 97 | "arch-parameters :\n{:}".format( 98 | process_step_matrix(self._arch_parameters, "softmax", None).cpu() 99 | ) 100 | ) 101 | 102 | def get_message(self): 103 | string = self.extra_repr() 104 | for i, cell in enumerate(self.cells): 105 | string += "\n {:02d}/{:02d} :: {:}".format( 106 | i, len(self.cells), cell.extra_repr() 107 | ) 108 | return string 109 | 110 | def extra_repr(self): 111 | return "{name}(C={_C}, Max-Nodes={max_nodes}, N={_layerN}, L={_Layer})".format( 112 | name=self.__class__.__name__, **self.__dict__ 113 | ) 114 | 115 | def genotype(self): 116 | genotypes = [] 117 | for i in range(1, self.max_nodes): 118 | xlist = [] 119 | for j in range(i): 120 | node_str = f"{i}<-{j}" 121 | with torch.no_grad(): 122 | weights = self._arch_parameters[self.edge2index[node_str]] 123 | op_name = self.op_names[weights.argmax().item()] 124 
class AvgrageMeter:
    """Running weighted mean.

    Attributes:
        sum: Weighted sum of every value passed to ``update`` since the last reset.
        cnt: Total weight (number of samples) seen since the last reset.
        avg: ``sum / cnt`` after the most recent ``update``; 0 right after a reset.
    """

    def __init__(self):
        self.avg = 0
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = self.sum = self.cnt = 0

    def update(self, val, n=1):
        """Fold in ``val`` (a per-sample mean) observed over ``n`` samples."""
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
class Cutout(object):
    """Cutout augmentation: zero one random square patch of an image tensor.

    Args:
        length: Side length of the square patch; the patch is clipped at the
            image borders, so it can be smaller near an edge.
        prob: Probability of applying the patch to a given image. Defaults to
            1.0 (always apply), which keeps ``Cutout(length)`` working exactly
            as before while also accepting the two-argument form
            ``Cutout(length, prob)`` used by the SVHN / CIFAR-100 transforms.
    """

    def __init__(self, length, prob=1.0):
        self.length = length
        self.prob = prob

    def __call__(self, img):
        """Apply the patch in place to ``img`` (indexed as [channel, h, w]) and return it."""
        # Only draw the extra random number when it can matter, so the RNG
        # stream for the default prob=1.0 is identical to the old behaviour.
        if self.prob < 1.0 and np.random.random() >= self.prob:
            return img

        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        half = self.length // 2
        y1 = np.clip(y - half, 0, h)
        y2 = np.clip(y + half, 0, h)
        x1 = np.clip(x - half, 0, w)
        x2 = np.clip(x + half, 0, w)

        mask[y1:y2, x1:x2] = 0.0
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img
def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place and rescale the survivors.

    Args:
        x: Floating-point tensor whose first dimension indexes samples; the
            mask has shape ``(x.size(0), 1, 1, 1)`` and is broadcast over the rest.
        drop_prob: Probability of zeroing each sample. ``0`` leaves ``x`` untouched.

    Returns:
        The same tensor object ``x`` (modified in place when ``drop_prob > 0``).
    """
    if drop_prob > 0.0:
        keep_prob = 1.0 - drop_prob
        # Allocate the mask on x's own device/dtype instead of the previous
        # hard-coded torch.cuda.FloatTensor, so CPU-only runs work too.
        mask = torch.empty(
            x.size(0), 1, 1, 1, device=x.device, dtype=x.dtype
        ).bernoulli_(keep_prob)
        # Dividing by keep_prob keeps the expected activation unchanged.
        x.div_(keep_prob)
        x.mul_(mask)
    return x
def process_step_vector(x, method, mask, tau=None):
    """Turn one row of architecture parameters into a weight vector.

    Args:
        x: 1-D tensor of raw architecture parameters.
        method: ``"softmax"``, ``"dirichlet"`` (reparameterised sample from
            ``Dirichlet(elu(x) + 1)``) or ``"gumbel"`` (soft Gumbel-softmax).
        mask: Optional boolean tensor shaped like ``x``; entries where it is
            False are forced to zero and the rest are renormalised to sum to 1.
        tau: Temperature, only used by the ``"gumbel"`` method.

    Returns:
        Tensor shaped like ``x`` holding the resulting weights.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == 'softmax':
        output = F.softmax(x, dim=-1)
    elif method == 'dirichlet':
        output = torch.distributions.dirichlet.Dirichlet(
            F.elu(x) + 1).rsample()
    elif method == 'gumbel':
        output = F.gumbel_softmax(x, tau=tau, hard=False, dim=-1)
    else:
        # Previously an unknown method surfaced as an UnboundLocalError below.
        raise ValueError(
            "Unknown method {!r}; expected 'softmax', 'dirichlet' or 'gumbel'".format(method)
        )

    if mask is None:
        return output

    output_pruned = torch.zeros_like(output)
    output_pruned[mask] = output[mask]
    output_pruned /= output_pruned.sum()
    assert (output_pruned[~mask] == 0.0).all()
    return output_pruned
def tlist_to_numpy(tlist):
    """Return a list with the ``.numpy()`` result of every tensor in ``tlist``."""
    arrays = []
    for tensor in tlist:
        arrays.append(tensor.numpy())
    return arrays
transform=transforms.Compose( 74 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 75 | ), 76 | ) 77 | test_data = datasets.FashionMNIST( 78 | "raw_data/" + name, 79 | train=False, 80 | download=download, 81 | transform=transforms.Compose( 82 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 83 | ), 84 | ) 85 | train_x, train_y = zip(*train_data) 86 | test_x, test_y = zip(*test_data) 87 | train_x = tlist_to_numpy(train_x) 88 | test_x = tlist_to_numpy(test_x) 89 | train_x, train_y = np.stack(train_x), np.array(train_y) 90 | test_x, test_y = np.stack(test_x), np.array(test_y) 91 | else: 92 | raise ValueError("Invalid dataset name!") 93 | # pylint: enable=W0632 94 | 95 | # split train data into train and valid 96 | train_x, valid_x, train_y, valid_y = tts( 97 | train_x, train_y, train_size=45000, test_size=15000 98 | ) 99 | 100 | # print out stats of label distribution across classes 101 | def sort_counter(c): 102 | return sorted(list(c.items()), key=lambda x: x[0]) 103 | 104 | if verbose: 105 | print("=== {} ===".format(name)) 106 | print("Train Bal:", sort_counter(collections.Counter(train_y))) 107 | print("Valid Bal:", sort_counter(collections.Counter(valid_y))) 108 | print("Test Bal:", sort_counter(collections.Counter(test_y))) 109 | 110 | # randomly shuffle arrays 111 | train_shuff = np.arange(len(train_y)) 112 | valid_shuff = np.arange(len(valid_y)) 113 | test_shuff = np.arange(len(test_y)) 114 | np.random.shuffle(train_shuff) 115 | np.random.shuffle(valid_shuff) 116 | np.random.shuffle(test_shuff) 117 | train_x, train_y = train_x[train_shuff], train_y[train_shuff] 118 | valid_x, valid_y = valid_x[valid_shuff], valid_y[valid_shuff] 119 | test_x, test_y = test_x[test_shuff], test_y[test_shuff] 120 | 121 | # print out data shapes of each split 122 | if verbose: 123 | print( 124 | "{} | Train: {}, {} | Valid: {}, {} | Test: {}, {} |".format( 125 | name, 126 | train_x.shape, 127 | train_y.shape, 128 | valid_x.shape, 129 | valid_y.shape, 
if __name__ == "__main__":
    # Script entry point: generate every development dataset under
    # --base_path and save a few example images under --save_image_path.
    base_path = args.base_path
    # set seed for reproducibility
    seed = args.seed
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    save_image_path = args.save_image_path

    # makedirs(exist_ok=True) also creates missing parent directories, which
    # the previous isdir()+mkdir() pair could not. base_path is created up
    # front so a missing output directory is not only discovered after the
    # (slow) dataset generation has already run.
    os.makedirs("raw_data", exist_ok=True)
    os.makedirs(save_image_path, exist_ok=True)
    os.makedirs(base_path, exist_ok=True)

    # load and save development datasets
    for dataset in ["AddNIST", "FashionMNIST", "MultNIST", "CIFARTile"]:
        process_torch_dataset(dataset, location=base_path)
        full_path = os.path.join(base_path, dataset)
        x, y = load_data(full_path)
        if dataset == "AddNIST" or dataset == "MultNIST":
            show_mnist_examples(x, y, save_image_path + f"/{dataset}.png")
        elif dataset == "FashionMNIST":
            show_fashionMNIST_examples(x, y, save_image_path + f"/{dataset}.png")
        elif dataset == "CIFARTile":
            show_CIFARTile(x, y, save_image_path + f"/{dataset}.png")
        else:
            print("WARNING: No visualizer for examples!")
    # The raw downloads are only needed while generating; remove them afterwards.
    shutil.rmtree("raw_data")
def generate_n(n, class_idx, class_dict, data):
    """Build ``n`` CIFARTile samples: 2x2 grids of 32x32 images.

    Each sample draws between one and four distinct classes; the label is the
    number of distinct classes minus one.

    Args:
        n: Number of tiles to generate.
        class_idx: Array of class ids to sample from.
        class_dict: Mapping ``class id -> list of indices into data``.
        data: Indexable dataset whose items are ``(image, label)`` pairs with
            the image assignable into a ``(3, 32, 32)`` slot.

    Returns:
        ``(xs, ys, metainfo)`` where ``xs`` is float32 of shape
        ``(n, 3, 64, 64)``, ``ys`` is int64 of shape ``(n,)`` and ``metainfo``
        lists the class ids drawn for every tile.
    """
    # Images taken from each drawn class, per number of distinct classes.
    # The RNG calls below happen in exactly the same order as the former
    # four-way if/elif chain, so seeded runs produce the same tiles.
    images_per_class = {1: (4,), 2: (2, 2), 3: (2, 1, 1), 4: (1, 1, 1, 1)}
    # (row slice, column slice) of the four quadrants, in fill order.
    quadrants = (
        (slice(None, 32), slice(None, 32)),
        (slice(None, 32), slice(32, None)),
        (slice(32, None), slice(None, 32)),
        (slice(32, None), slice(32, None)),
    )

    xs, ys, metainfo = [], [], []

    for i in range(n):
        print("\r{}/{}".format(i, n), end="")
        n_classes = np.random.choice([1, 2, 3, 4])

        classes = np.random.choice(class_idx, n_classes, replace=False)
        images = []
        for c, count in zip(classes, images_per_class[int(n_classes)]):
            images += list(np.random.choice(class_dict[c], size=count, replace=False))

        np.random.shuffle(images)
        metainfo.append(classes)
        out = np.zeros((3, 64, 64))
        for (rows, cols), image_idx in zip(quadrants, images):
            out[:, rows, cols] = data[image_idx][0]
        xs.append(out)
        ys.append(n_classes - 1)

    xs = np.array(xs).astype(np.float32)
    # np.long was removed in NumPy 1.24; int64 is the explicit equivalent.
    ys = np.array(ys).astype(np.int64)
    return xs, ys, metainfo
in os.listdir(data_path) 56 | 57 | MEAN = [0.49139968, 0.48215827, 0.44653124] 58 | STD = [0.24703233, 0.24348505, 0.26158768] 59 | 60 | train_data = datasets.CIFAR10( 61 | data_path + dataset, 62 | train=True, 63 | download=download, 64 | transform=transforms.Compose( 65 | [ 66 | transforms.RandomCrop(32, padding=4), 67 | transforms.RandomHorizontalFlip(), 68 | transforms.ToTensor(), 69 | transforms.Normalize(MEAN, STD), 70 | ] 71 | ), 72 | ) 73 | test_data = datasets.CIFAR10( 74 | data_path + dataset, 75 | train=False, 76 | download=download, 77 | transform=transforms.Compose( 78 | [transforms.ToTensor(), transforms.Normalize(MEAN, STD)] 79 | ), 80 | ) 81 | 82 | train_class_dict = {} 83 | for i, (_, y) in enumerate(train_data): 84 | if y not in train_class_dict: 85 | train_class_dict[y] = [] 86 | train_class_dict[y].append(i) 87 | 88 | test_class_dict = {} 89 | for i, (_, y) in enumerate(test_data): 90 | if y not in test_class_dict: 91 | test_class_dict[y] = [] 92 | test_class_dict[y].append(i) 93 | 94 | class_idx = np.arange(10) 95 | 96 | if metainfo: 97 | train_x, train_y, metainfo = generate_n( 98 | 600, class_idx, train_class_dict, train_data 99 | ) 100 | test_x, test_y, _ = generate_n(100, class_idx, test_class_dict, test_data) 101 | 102 | return [train_x, train_y], [test_x, test_y], metainfo 103 | else: 104 | train_x, train_y, metainfo = generate_n( 105 | 60000, class_idx, train_class_dict, train_data 106 | ) 107 | test_x, test_y, _ = generate_n(10000, class_idx, test_class_dict, test_data) 108 | 109 | return [train_x, train_y], [test_x, test_y] 110 | -------------------------------------------------------------------------------- /experiments/utils/dataset_generation/gen_gutenberg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import string 4 | 5 | import numpy as np 6 | from sklearn.model_selection import train_test_split as tts 7 | 8 | 9 | def is_all_upper(l): 10 | return 
def test_train_split(author_data, author_idx, return_metainfo=False):
    """Split every author's phrases into train/test and stack them into arrays.

    Args:
        author_data: Mapping ``author -> {"encodings": [...], "phrases": [...]}``
            with the two lists aligned by index.
        author_idx: Mapping ``author -> class label``.
        return_metainfo: When True, also return the training phrases.

    Returns:
        ``(train_xs, train_ys), (test_xs, test_ys)`` with float32 inputs and
        int64 labels, followed by ``metainfo`` (the training phrases, in the
        same order as ``train_xs``) when ``return_metainfo`` is True.
    """
    # All splits are drawn first, then the arrays are assembled, so the
    # order of random draws matches the previous implementation.
    author_splits = {}
    for author, records in author_data.items():
        idxs = np.arange(len(records["encodings"]))
        # Fixed split sizes: 11100 train / 1000 test phrases per author.
        train, test = tts(idxs, train_size=11100, test_size=1000)
        author_splits[author] = train, test

    train_xs, train_ys, test_xs, test_ys = [], [], [], []
    metainfo = []
    for author, (train_idx, test_idx) in author_splits.items():
        encodings = author_data[author]["encodings"]
        phrases = author_data[author]["phrases"]
        label = author_idx[author]

        train_xs += [encodings[i] for i in train_idx]
        train_ys += [label] * len(train_idx)
        metainfo += [phrases[i] for i in train_idx]
        test_xs += [encodings[i] for i in test_idx]
        test_ys += [label] * len(test_idx)

    # np.long was removed in NumPy 1.24; int64 is the explicit equivalent.
    train_xs = np.array(train_xs).astype(np.float32)
    train_ys = np.array(train_ys).astype(np.int64)
    test_xs = np.array(test_xs).astype(np.float32)
    test_ys = np.array(test_ys).astype(np.int64)

    if return_metainfo:
        return (train_xs, train_ys), (test_xs, test_ys), metainfo
    else:
        return (train_xs, train_ys), (test_xs, test_ys)
| "confucius", 70 | "hawthorne", 71 | "plato", 72 | "shakespeare", 73 | "tolstoy", 74 | ]: 75 | texts[text] = f.readlines()[95:] 76 | authors[text] = author 77 | all_authors = sorted(list(set(authors.values()))) 78 | author_idx = {a: float(i) for i, a in enumerate(all_authors)} 79 | lat_letters = "abcdefghijklmnopqrstuvwxyz " 80 | _RE_COMBINE_WHITESPACE = re.compile(r"\s+") 81 | punc = string.punctuation 82 | invalids = [ 83 | "prologue", 84 | "epilogue", 85 | "chapter", 86 | "scene", 87 | "act", 88 | "ii", 89 | "iii", 90 | "iv", 91 | "v", 92 | "vi", 93 | "vii", 94 | "viii", 95 | "ix", 96 | "x", 97 | "xi", 98 | "xii", 99 | "xiii", 100 | "xiv", 101 | "xv", 102 | "xvi", 103 | "xvii", 104 | "xviii", 105 | "xix", 106 | "xx", 107 | ] 108 | 109 | text_proc = {} 110 | for k, v in texts.items(): 111 | t = [l.strip().translate(str.maketrans(punc, " " * len(punc))) for l in v] 112 | t = [l for l in t if l and not is_all_upper(l)] 113 | t = " ".join(t).lower() 114 | t = "".join([c for c in t if c in lat_letters]) 115 | t = _RE_COMBINE_WHITESPACE.sub(" ", t).strip() 116 | t = [w for w in t.split() if w not in invalids] 117 | text_proc[k] = t 118 | 119 | # extract matching phrases 120 | author_phrase_set = {} 121 | for k, v in text_proc.items(): 122 | author = authors[k] 123 | 124 | if author not in author_phrase_set: 125 | author_phrase_set[author] = set() 126 | 127 | for word in range(0, len(v) - phrase_size): 128 | phrase = [w for w in v[word : word + phrase_size]] 129 | if all( 130 | [phrase_word_bounds[0] <= len(w) <= phrase_word_bounds[1] for w in phrase] 131 | ): 132 | exp_phrase = tuple([expand(w) for w in phrase]) 133 | author_phrase_set[author].add(exp_phrase) 134 | 135 | # filter overlapping phrases 136 | author_unique_phrases = {} 137 | for a, ws in author_phrase_set.items(): 138 | no_overlap = ws 139 | for other_a, other_ws in author_phrase_set.items(): 140 | if a != other_a: 141 | no_overlap = no_overlap.difference(other_ws) 142 | author_unique_phrases[a] = 
def filter_extra_chars(lang_dict):
    """Drop every word containing an apostrophe, hyphen or space.

    ``lang_dict`` (language -> iterable of words) is updated in place: each
    value is replaced by a list of the surviving words. The same dict object
    is returned.
    """
    banned = "'- "
    for lang in lang_dict:
        lang_dict[lang] = [
            word
            for word in lang_dict[lang]
            if all(ch not in word for ch in banned)
        ]
    return lang_dict
def test_train_split(lang_dict, return_metainfo=False):
    """Build the train/test arrays of four-word groups for every language.

    Args:
        lang_dict: Mapping ``language -> collection of words``.
        return_metainfo: When True, also return the training word groups and
            the language-to-label mapping.

    Returns:
        ``(train_xs, train_ys), (test_xs, test_ys)`` with float32 inputs (a
        channel axis is inserted at position 1) and int64 labels, shuffled.
        With ``return_metainfo`` the tuple is extended by ``metainfo`` (training
        word groups in the shuffled order) and ``lang_idxs``.
    """
    lang_splits = {}
    for lang, words in lang_dict.items():
        train, test = tts(list(words), train_size=1500, test_size=700)
        lang_splits[lang] = train, test

    n = 4  # words per group
    n_train = 6000
    n_test = 1000

    def sample_groups(words, n_groups):
        # Oversample by 500 so that enough groups remain after duplicates are
        # removed. dict.fromkeys de-duplicates while keeping draw order; the
        # former list(set(...)) order depended on per-process string hashing,
        # which made the result differ between runs even with a fixed seed.
        draws = zip(*[np.random.choice(words, n_groups + 500) for _ in range(n)])
        return list(dict.fromkeys(draws))[:n_groups]

    lang_groups = {}
    for lang, (train, test) in lang_splits.items():
        train_groups = sample_groups(train, n_train)
        test_groups = sample_groups(test, n_test)
        lang_groups[lang] = train_groups, test_groups

    train_xs, train_ys = [], []
    test_xs, test_ys = [], []
    metainfo = []
    lang_idxs = {l: i for i, l in enumerate(lang_groups.keys())}

    for lang, (train, test) in lang_groups.items():
        train_xs += [convert(ws) for ws in train]
        train_ys += [lang_idxs[lang] for _ in train]
        metainfo += list(train)
        test_xs += [convert(ws) for ws in test]
        test_ys += [lang_idxs[lang] for _ in test]

    train_xs = np.expand_dims(np.array(train_xs), axis=1).astype(np.float32)
    test_xs = np.expand_dims(np.array(test_xs), axis=1).astype(np.float32)
    # np.long was removed in NumPy 1.24; int64 is the explicit equivalent.
    train_ys = np.array(train_ys).astype(np.int64)
    test_ys = np.array(test_ys).astype(np.int64)

    train_shuff = np.arange(len(train_ys))
    np.random.shuffle(train_shuff)
    test_shuff = np.arange(len(test_ys))
    np.random.shuffle(test_shuff)

    train_xs = train_xs[train_shuff]
    train_ys = train_ys[train_shuff]
    metainfo = [metainfo[i] for i in train_shuff]
    test_xs = test_xs[test_shuff]
    test_ys = test_ys[test_shuff]

    if return_metainfo:
        return (train_xs, train_ys), (test_xs, test_ys), metainfo, lang_idxs
    else:
        return (train_xs, train_ys), (test_xs, test_ys)
def proc_weights(op, combs):
    """Compute per-label sampling weights that balance the label distribution.

    Args:
        op: Callable ``op(i, j, k)`` mapping a digit triple to its label.
        combs: Iterable of ``(i, j, k)`` digit triples.

    Returns:
        Dict ``label -> weight``. Weights are inversely proportional to how
        many triples produce the label, scaled so the most frequent label has
        weight 1.0. Empty ``combs`` gives an empty dict.
    """
    counts = collections.Counter(op(i, j, k) for i, j, k in combs)
    if not counts:
        return {}
    # Hoisted out of the comprehensions: the total and the minimum were
    # previously recomputed once per label.
    total = sum(counts.values())
    inverse = {label: total / count for label, count in counts.items()}
    smallest = min(inverse.values())
    return {label: weight / smallest for label, weight in inverse.items()}
def image_normalization(arr):
    """Linearly rescale ``arr`` so its minimum maps to 0 and its maximum to 1.

    A constant array (max == min) is mapped to all zeros instead of the
    NaN-filled result (and divide warning) that 0/0 would produce.
    """
    low = arr.min()
    span = arr.max() - low
    if span == 0:
        # Avoid 0/0; with span 1 the expression below yields zeros in the
        # same dtype a non-constant input would have produced.
        span = 1
    return (arr - low) / span
def load_data(path, mode="train"):
    """Load the ``{mode}_x.npy`` / ``{mode}_y.npy`` pair stored in ``path``.

    Args:
        path: Directory containing the saved arrays.
        mode: One of ``"train"``, ``"valid"`` or ``"test"``.

    Returns:
        Tuple ``(x, y)`` of the two loaded arrays.

    Raises:
        ValueError: If ``path`` is not a directory or ``mode`` is unknown.
    """
    if not os.path.isdir(path):
        raise ValueError(f"Path {path} is no valid directory!")
    if mode not in ("train", "valid", "test"):
        raise ValueError(f"Type {mode} does not exist in {path}: {os.listdir(path)}")

    inputs, labels = (
        np.load(os.path.join(path, f"{mode}_{part}.npy")) for part in ("x", "y")
    )
    return inputs, labels
def measure(name, bn=True, copy_net=True, force_clean=True, **impl_args):
    """Decorator factory that registers a function as a named measure.

    The wrapped implementation is stored in ``_measure_impls[name]`` and
    ``name`` is appended to ``available_measures``. The decorated function
    itself is returned unchanged.

    Args:
        name: Registry key for the measure.
        bn: Forwarded as ``bn=`` to ``net_orig.get_prunable_copy`` when a copy
            is made.
        copy_net: When True the measure runs on
            ``net_orig.get_prunable_copy(bn=bn).to(device)``; otherwise it runs
            on ``net_orig`` directly.
        force_clean: When True (and a copy was made) the copy is deleted and
            ``torch.cuda.empty_cache()`` / ``gc.collect()`` are called after
            the measure has been computed.
        **impl_args: Extra keyword arguments passed to the decorated function
            on every call.

    Raises:
        KeyError: At decoration time, if ``name`` is already registered.
    """
    def make_impl(func):
        def measure_impl(net_orig, device, *args, **kwargs):
            # `device` is only used when a copy of the network is made.
            if copy_net:
                net = net_orig.get_prunable_copy(bn=bn).to(device)
                # set model.train()
            else:
                net = net_orig
            ret = func(net, *args, **kwargs, **impl_args)
            if copy_net and force_clean:
                # Imported lazily, only on the cleanup path.
                import gc

                import torch

                del net
                torch.cuda.empty_cache()
                gc.collect()
            return ret

        global _measure_impls  # pylint: disable=global-variable-not-assigned
        if name in _measure_impls:
            raise KeyError(f"Duplicated measure! {name}")
        available_measures.append(name)
        _measure_impls[name] = measure_impl
        # Return the undecorated function: the wrapper is reachable only
        # through the registry.
        return func

    return make_impl
import ( 56 | epe_nas, 57 | fisher, 58 | grad_norm, 59 | grasp, 60 | jacov, 61 | l2_norm, 62 | nwot, 63 | plain, 64 | snip, 65 | synflow, 66 | zen, 67 | ) 68 | 69 | # pylint: enable=unused-import 70 | -------------------------------------------------------------------------------- /experiments/zero_cost_proxies_utils/epe_nas.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2021 VascoLopes 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | # Implementation of EPE-NAS: Efficient Performance Estimation Without Training 25 | # for Neural Architecture Search (https://arxiv.org/abs/2102.08099) taken from the 26 | # authors' Github repository https://github.com/VascoLopes/EPE-NAS/blob/main/search.py 27 | 28 | import numpy as np 29 | import torch 30 | 31 | from . 
def get_batch_jacobian(
    net, x, target, to, device, args=None
):  # pylint: disable=unused-argument
    """Back-propagate ``net(x)`` to the input and return the input gradient.

    Returns:
        A tuple ``(jacob, target, n_out)``: the detached gradient of the
        summed network output w.r.t. ``x``, the detached ``target`` and the
        size of the last output dimension (``y.shape[-1]``).

    Note:
        ``x`` is switched to ``requires_grad=True`` in place. Any gradient
        already stored on it is discarded first, because ``backward``
        accumulates into ``x.grad``; without that reset a tensor that is
        reused across measures would yield the sum of all previous Jacobians.
    """
    net.zero_grad()

    x.requires_grad_(True)
    if x.is_leaf:
        # Drop stale gradients left behind by an earlier backward pass.
        x.grad = None

    y = net(x)

    y.backward(torch.ones_like(y))
    jacob = x.grad.detach()

    return jacob, target.detach(), y.shape[-1]


def eval_score_perclass(jacob, labels=None, n_classes=10):
    """Score a batch of flattened Jacobians from per-class correlations.

    Rows of ``jacob`` are grouped by ``labels[0][i]``. For every class the
    correlation matrix of its rows is reduced to ``sum(log(|corr| + k))``
    (additionally divided by ``len(corr)`` when ``n_classes > 100``); classes
    for which this computation raises are skipped.

    The per-class values are combined as the sum of their absolute values
    (``n_classes <= 100``) or the sum of absolute pairwise differences
    (otherwise), divided by the number of classes that produced a value.

    Args:
        jacob: 2-D array, one flattened Jacobian per sample.
        labels: Sequence whose first element holds one label per row.
        n_classes: Number of output classes; selects the aggregation variant.
    """
    k = 1e-5

    # Stack the Jacobian rows of every class seen in this batch.
    per_class = {}
    for i, label in enumerate(labels[0]):
        if label in per_class:
            per_class[label] = np.vstack((per_class[label], jacob[i]))
        else:
            per_class[label] = jacob[i]

    ind_corr_matrix_score = {}
    for c in per_class:
        s = 0
        try:
            corrs = np.array(np.corrcoef(per_class[c]))

            s = np.sum(np.log(abs(corrs) + k))  # /len(corrs)
            if n_classes > 100:
                s /= len(corrs)
        except Exception:  # defensive programming
            continue
        ind_corr_matrix_score[c] = s

    # per class-corr matrix A and B
    score = 0
    ind_corr_matrix_score_keys = ind_corr_matrix_score.keys()
    if n_classes <= 100:
        for c in ind_corr_matrix_score_keys:
            # B)
            score += np.absolute(ind_corr_matrix_score[c])
    else:
        for c in ind_corr_matrix_score_keys:
            # A)
            for cj in ind_corr_matrix_score_keys:
                score += np.absolute(ind_corr_matrix_score[c] - ind_corr_matrix_score[cj])

    if len(ind_corr_matrix_score_keys) > 0:
        # should divide by number of classes seen
        score /= len(ind_corr_matrix_score_keys)

    return score


@measure("epe_nas")
def compute_epe_score(
    net, inputs, targets, loss_fn, split_data=1
):  # pylint: disable=unused-argument
    """Compute the EPE-NAS score of ``net`` on one batch.

    Any exception raised along the way is printed and turned into ``np.nan``
    (deliberate best-effort behaviour, the caller always gets a number).
    """
    jacobs = []
    labels = []

    try:
        jacobs_batch, target, n_classes = get_batch_jacobian(
            net, inputs, targets, None, None
        )
        jacobs.append(jacobs_batch.reshape(jacobs_batch.size(0), -1).cpu().numpy())

        if len(target.shape) == 2:  # Hack to handle TNB101 classification tasks
            target = torch.argmax(target, dim=1)

        labels.append(target.cpu().numpy())

        jacobs = np.concatenate(jacobs, axis=0)

        s = eval_score_perclass(jacobs, labels, n_classes)

    except Exception as e:
        print(e)
        s = np.nan

    return s
def fisher_forward_conv2d(self, x):
    """Conv2d forward that routes its output through ``self.dummy``.

    The result is kept on ``self.act`` so the backward hook registered on the
    identity op can read the activations.
    """
    conv_out = F.conv2d(
        x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
    )
    # intercept and store the activations after passing through 'hooked' identity op
    self.act = self.dummy(conv_out)
    return self.act


def fisher_forward_linear(self, x):
    """Linear forward that routes its output through ``self.dummy``."""
    self.act = self.dummy(F.linear(x, self.weight, self.bias))
    return self.act


def _make_fisher_hook(layer):
    """Build the backward hook that accumulates ``layer.fisher``."""

    def hook(module, grad_input, grad_output):  # pylint: disable=unused-argument
        act = layer.act.detach()
        grad = grad_output[0].detach()
        weighted = act * grad
        if len(act.shape) > 2:
            # collapse every axis after the first two
            weighted = torch.sum(weighted, list(range(2, len(act.shape))))
        del_k = weighted.pow(2).mean(0).mul(0.5)
        if layer.fisher is None:
            layer.fisher = del_k
        else:
            layer.fisher += del_k
        # without deleting this, a nasty memory leak occurs! related:
        # https://discuss.pytorch.org/t/memory-leak-when-using-forward-hook-and-backward-hook-simultaneously/27555
        del layer.act

    return hook


@measure("fisher", bn=True, mode="channel")
def compute_fisher_per_weight(net, inputs, targets, loss_fn, mode, split_data=1):
    """Compute the fisher metric for every Conv2d/Linear layer of ``net``.

    The forward methods of those layers are replaced so that their outputs
    pass through a hooked ``nn.Identity``; the hook accumulates a value per
    output channel during the backward passes. The per-channel values are
    finally broadcast to per-parameter tensors via ``reshape_elements``.

    Raises:
        ValueError: If ``mode == "param"``.
    """
    device = inputs.device

    if mode == "param":
        raise ValueError("Fisher pruning does not support parameter pruning.")

    net.train()
    for layer in net.modules():
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue

        # variables/op needed for fisher computation
        layer.fisher = None
        layer.act = 0.0
        layer.dummy = nn.Identity()

        # replace forward method of conv/linear
        patched_forward = (
            fisher_forward_linear if isinstance(layer, nn.Linear) else fisher_forward_conv2d
        )
        layer.forward = types.MethodType(patched_forward, layer)

        # register backward hook on identity fcn to compute fisher info
        layer.dummy.register_backward_hook(_make_fisher_hook(layer))

    total = inputs.shape[0]
    for part in range(split_data):
        lo = part * total // split_data
        hi = (part + 1) * total // split_data

        net.zero_grad()
        loss = loss_fn(net(inputs[lo:hi]), targets[lo:hi])
        loss.backward()

    # retrieve fisher info
    def fisher(layer):
        if layer.fisher is None:
            return torch.zeros(layer.weight.shape[0])  # size=ch
        return torch.abs(layer.fisher.detach())

    grads_abs_ch = get_layer_metric_array(net, fisher, mode)

    # broadcast channel value here to all parameters in that channel
    # to be compatible with stuff downstream (which expects per-parameter metrics)
    # TODO cleanup on the selectors/apply_prune_mask side (?)
    shapes = get_layer_metric_array(net, lambda l: l.weight.shape[1:], mode)

    return reshape_elements(grads_abs_ch, shapes, device)
@measure("grad_norm", bn=True)
def get_grad_norm_arr(
    net, inputs, targets, loss_fn, split_data=1, skip_grad=False
):  # pylint: disable=unused-argument
    """Return the weight-gradient norm of every Conv2d/Linear layer.

    The batch is processed in ``split_data`` consecutive slices; gradients are
    accumulated over all slices before the norms are taken. Layers whose
    weight has no gradient contribute ``torch.zeros_like(weight)`` instead.
    """
    net.zero_grad()
    total = inputs.shape[0]
    for part in range(split_data):
        lo = part * total // split_data
        hi = (part + 1) * total // split_data

        loss = loss_fn(net.forward(inputs[lo:hi]), targets[lo:hi])
        loss.backward()

    def weight_grad_norm(layer):
        grad = layer.weight.grad
        if grad is None:
            return torch.zeros_like(layer.weight)
        return grad.norm()

    return get_layer_metric_array(net, weight_grad_norm, mode="param")
# code from https://github.com/ultralytics/yolov5/blob/6ab589583c86f265a8a226620cc1e87c1c2266f1/utils/activations.py#L15
class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()
    """Hardswish built from ``F.hardtanh``."""

    @staticmethod
    def forward(x):
        # return x * F.hardsigmoid(x)  # for torchscript and CoreML
        return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0  # for torchscript, CoreML and ONNX


def replace_hardswish(base_module: nn.Module):
    """Recursively swap every ``nn.Hardswish`` child for ``Hardswish``.

    The module tree is modified in place and ``base_module`` is returned.
    """
    for name, module in base_module.named_children():
        if isinstance(module, nn.Hardswish):
            setattr(base_module, name, Hardswish())
        elif isinstance(module, nn.Module):
            new_module = replace_hardswish(module)
            setattr(base_module, name, new_module)
    return base_module


@measure("grasp", bn=True, mode="param")
def compute_grasp_per_weight(
    net, inputs, targets, mode, loss_fn, T=1, num_iters=1, split_data=1
):
    """Compute the GraSP metric for every Conv2d/Linear weight of ``net``.

    Pass 1 collects the first-order gradients of the loss w.r.t. the weights.
    Pass 2 recomputes those gradients with ``create_graph=True`` and
    back-propagates their inner product with the pass-1 gradients, which
    leaves a Hessian-gradient product in ``weight.grad``. The returned metric
    is ``-weight * weight.grad`` (zeros where no gradient exists).

    Args:
        net: Network to score.
        inputs: Input batch.
        targets: Targets matching ``inputs``.
        mode: Forwarded to ``get_layer_metric_array``.
        loss_fn: Callable ``loss_fn(outputs, targets)``.
        T: Temperature the network outputs are divided by.
        num_iters: Number of pass-1 repetitions per split.
        split_data: Number of consecutive slices the batch is processed in.

    Note:
        The pass-1 gradients are accumulated over *all* splits. Previously the
        accumulator was re-created inside the split loop, so for
        ``split_data > 1`` only the last slice contributed. Results for the
        default ``split_data=1`` are unchanged.
    """
    net = replace_hardswish(net)

    # get all applicable weights
    weights = []
    for layer in net.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            weights.append(layer.weight)
            layer.weight.requires_grad_(True)  # TODO isn't this already true?

    # NOTE original code had some input/target splitting into 2
    # I am guessing this was because of GPU mem limit
    net.zero_grad()
    N = inputs.shape[0]

    # forward/grad pass #1
    grad_w = None  # shared by all splits so that every slice contributes
    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        for _ in range(num_iters):
            # TODO get new data, otherwise num_iters is useless!
            outputs = net.forward(inputs[st:en]) / T
            loss = loss_fn(outputs, targets[st:en])
            grad_w_p = autograd.grad(loss, weights, allow_unused=True)
            if grad_w is None:
                grad_w = list(grad_w_p)
                continue
            for idx, new_grad in enumerate(grad_w_p):
                # allow_unused=True yields None for weights outside the graph
                if new_grad is None:
                    continue
                if grad_w[idx] is None:
                    grad_w[idx] = new_grad
                else:
                    grad_w[idx] = grad_w[idx] + new_grad

    for sp in range(split_data):
        st = sp * N // split_data
        en = (sp + 1) * N // split_data

        # forward/grad pass #2
        outputs = net.forward(inputs[st:en]) / T
        loss = loss_fn(outputs, targets[st:en])
        grad_f = autograd.grad(loss, weights, create_graph=True, allow_unused=True)

        # accumulate gradients computed in previous step and call backwards
        # (`weights` was collected in net.modules() order, so the tuples line up)
        z = 0
        for first_order, second_pass in zip(grad_w, grad_f):
            if first_order is not None and second_pass is not None:
                z += (first_order.data * second_pass).sum()
        if torch.is_tensor(z):
            z.backward()

    # compute final sensitivity metric and put in grads
    def grasp(layer):
        if layer.weight.grad is not None:
            return -layer.weight.data * layer.weight.grad  # -theta_q Hg
            # NOTE in the grasp code they take the *bottom* (1-p)% of values
            # but we take the *top* (1-p)%, therefore we remove the -ve sign
            # EDIT accuracy seems to be negatively correlated with this metric, so we add -ve sign here!
        else:
            return torch.zeros_like(layer.weight)

    grads = get_layer_metric_array(net, grasp, mode)

    return grads
def get_batch_jacobian(net, x, target):
    """Back-propagate ``net(x)`` to the input and return the input gradient.

    Returns:
        A tuple ``(jacob, target)`` with the detached gradient of the summed
        network output w.r.t. ``x`` and the detached ``target``.

    Note:
        ``x`` is switched to ``requires_grad=True`` in place. Any gradient
        already stored on it is discarded first, because ``backward``
        accumulates into ``x.grad``; without that reset a tensor that is
        reused across measures would yield the sum of all previous Jacobians.
    """
    net.zero_grad()

    x.requires_grad_(True)
    if x.is_leaf:
        # Drop stale gradients left behind by an earlier backward pass.
        x.grad = None

    y = net(x)

    y.backward(torch.ones_like(y))
    jacob = x.grad.detach()

    return jacob, target.detach()


def eval_score(jacob, labels=None):  # pylint: disable=unused-argument
    """Score the correlation matrix of the rows of ``jacob``.

    With ``v`` the eigenvalues of ``np.corrcoef(jacob)`` and ``k = 1e-5`` the
    result is ``-sum(log(v + k) + 1 / (v + k))``.
    """
    corrs = np.corrcoef(jacob)
    v, _ = np.linalg.eig(corrs)
    k = 1e-5
    return -np.sum(np.log(v + k) + 1.0 / (v + k))


@measure("jacov", bn=True)
def compute_jacob_cov(
    net, inputs, targets, split_data=1, loss_fn=None  # pylint: disable=unused-argument
):
    """Compute the jacov score of ``net`` on one batch.

    Any exception raised along the way is printed and turned into ``np.nan``
    (deliberate best-effort behaviour, the caller always gets a number).
    """
    try:
        # Compute gradients (but don't apply them)
        jacobs, labels = get_batch_jacobian(net, inputs, targets)
        jacobs = jacobs.reshape(jacobs.size(0), -1).cpu().numpy()
        jc = eval_score(jacobs, labels)
    except Exception as e:
        print(e)
        jc = np.nan

    return jc
# --- l2_norm.py ---
@measure("l2_norm", copy_net=False, mode="param")
def get_l2_norm_array(
    net, inputs, targets, mode, split_data=1, **kwargs
):  # pylint: disable=unused-argument
    """Return ``weight.norm()`` for every Conv2d/Linear layer of ``net``."""

    def weight_norm(layer):
        return layer.weight.norm()

    return get_layer_metric_array(net, weight_norm, mode=mode)


# --- model_stats.py ---
def get_model_stats(model, input_tensor_shape, clone_model=True) -> tw.ModelStats:
    """Build a ``tw.ModelStats`` object for ``model``.

    ``tw.ModelStats`` installs hooks, so this should be the last thing done
    with the model.
    """
    return tw.ModelStats(model, input_tensor_shape, clone_model=clone_model)
@measure("nwot", bn=True)
def compute_nwot(
    net, inputs, targets, split_data=1, loss_fn=None
):  # pylint: disable=unused-argument
    """Compute the nwot score of ``net`` on one batch.

    A forward hook on every module whose type name contains ``"ReLU"`` (and
    not ``"naslib"``) binarises the module input (``> 0``) and adds, for each
    pair of samples, the number of positions where both indicators agree to
    ``net.K``. The score is the log-determinant part of
    ``np.linalg.slogdet(net.K)`` after a single forward pass.
    """
    n_samples = len(targets)

    def counting_forward_hook(module, inp, out):  # pylint: disable=unused-argument
        flat = inp[0].view(inp[0].size(0), -1)
        active = (flat > 0).float()  # binary indicator
        both_on = active @ active.t()
        both_off = (1.0 - active) @ (1.0 - active.t())
        net.K = net.K + both_on.cpu().numpy() + both_off.cpu().numpy()  # hamming distance

    net.K = np.zeros((n_samples, n_samples))
    for module in net.modules():
        type_name = str(type(module))
        if "ReLU" in type_name and "naslib" not in type_name:
            module.register_forward_hook(counting_forward_hook)

    net(torch.clone(inputs))
    _, log_det = np.linalg.slogdet(net.K)

    return log_det
def get_some_data(train_dataloader, num_batches, device):
    """Concatenate the first ``num_batches`` batches of a dataloader.

    Returns:
        ``(inputs, targets)``, both moved to ``device``.
    """
    batch_iter = iter(train_dataloader)
    batches = [next(batch_iter) for _ in range(num_batches)]
    inputs = torch.cat([a for a, _ in batches]).to(device)
    targets = torch.cat([b for _, b in batches]).to(device)
    return inputs, targets


def get_some_data_grasp(train_dataloader, num_classes, samples_per_class, device):
    """Collect ``samples_per_class`` samples for each of ``num_classes`` classes.

    Batches are consumed until, for every class, a sample has been seen after
    that class was already full. The result is ordered by class index.

    Returns:
        ``(x, y)``, both moved to ``device``; ``y`` is flattened to 1-D.
    """
    per_class_inputs = [[] for _ in range(num_classes)]
    per_class_targets = [[] for _ in range(num_classes)]
    full_classes = set()
    batch_iter = iter(train_dataloader)
    while True:
        inputs, targets = next(batch_iter)
        for idx in range(inputs.shape[0]):
            sample = inputs[idx : idx + 1]
            target = targets[idx : idx + 1]
            category = target.item()
            if len(per_class_inputs[category]) == samples_per_class:
                full_classes.add(category)
            else:
                per_class_inputs[category].append(sample)
                per_class_targets[category].append(target)
        if len(full_classes) == num_classes:
            break

    x = torch.cat([torch.cat(chunk, 0) for chunk in per_class_inputs]).to(device)
    y = torch.cat([torch.cat(chunk) for chunk in per_class_targets]).view(-1).to(device)
    return x, y


def get_layer_metric_array(net, metric, mode):
    """Apply ``metric`` to every Conv2d/Linear layer of ``net``.

    In ``"channel"`` mode, modules that carry a ``dont_ch_prune`` attribute
    are skipped.

    Returns:
        List of ``metric(layer)`` values in ``net.modules()`` order.
    """
    skip_flagged = mode == "channel"
    return [
        metric(layer)
        for layer in net.modules()
        if not (skip_flagged and hasattr(layer, "dont_ch_prune"))
        and isinstance(layer, (nn.Conv2d, nn.Linear))
    ]


def reshape_elements(elements, shapes, device):
    """Broadcast every value in ``elements`` to a tensor of the paired shape.

    For each ``(e, sh)`` pair, every value ``v`` of ``e`` becomes a tensor of
    shape ``sh`` filled with ``v``; these are stacked along a new first axis
    and moved to ``device``. If ``elements[0]`` is a list, the same is done
    one nesting level deeper.
    """

    def broadcast_val(values, target_shapes):
        return [
            torch.stack([torch.Tensor(sh).fill_(v) for v in e], dim=0).to(device)
            for e, sh in zip(values, target_shapes)
        ]

    if not isinstance(elements[0], list):
        return broadcast_val(elements, shapes)
    return [broadcast_val(e, sh) for e, sh in zip(elements, shapes)]


def count_parameters(model):
    """Return the number of elements in all parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
@measure("plain", bn=True, mode="param")
def compute_plain_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
    """Return ``weight.grad * weight`` for every Conv2d/Linear layer.

    The batch is processed in ``split_data`` consecutive slices; gradients are
    accumulated over all slices. Layers whose weight has no gradient
    contribute ``torch.zeros_like(weight)``.
    """
    net.zero_grad()
    total = inputs.shape[0]
    for part in range(split_data):
        lo = part * total // split_data
        hi = (part + 1) * total // split_data

        loss = loss_fn(net.forward(inputs[lo:hi]), targets[lo:hi])
        loss.backward()

    # select the gradients that we want to use for search/prune
    def plain(layer):
        grad = layer.weight.grad
        if grad is None:
            return torch.zeros_like(layer.weight)
        return grad * layer.weight

    return get_layer_metric_array(net, plain, mode)
# --- snip.py ---
def snip_forward_conv2d(self, x):
    """Conv2d forward that multiplies the weight by ``self.weight_mask``."""
    masked_weight = self.weight * self.weight_mask
    return F.conv2d(
        x,
        masked_weight,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
    )


def snip_forward_linear(self, x):
    """Linear forward that multiplies the weight by ``self.weight_mask``."""
    masked_weight = self.weight * self.weight_mask
    return F.linear(x, masked_weight, self.bias)


@measure("snip", bn=True, mode="param")
def compute_snip_per_weight(net, inputs, targets, mode, loss_fn, split_data=1):
    """Return ``|weight_mask.grad|`` for every Conv2d/Linear layer.

    Each such layer gets an all-ones ``weight_mask`` parameter, its weight is
    frozen and its forward is replaced by a masked variant, so the loss
    gradient lands on the mask. Layers whose mask has no gradient contribute
    ``torch.zeros_like(weight)``.
    """
    for layer in net.modules():
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue
        layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
        layer.weight.requires_grad = False

        # Override the forward methods:
        masked_forward = (
            snip_forward_linear if isinstance(layer, nn.Linear) else snip_forward_conv2d
        )
        layer.forward = types.MethodType(masked_forward, layer)

    # Compute gradients (but don't apply them)
    net.zero_grad()
    total = inputs.shape[0]
    for part in range(split_data):
        lo = part * total // split_data
        hi = (part + 1) * total // split_data

        loss = loss_fn(net.forward(inputs[lo:hi]), targets[lo:hi])
        loss.backward()

    # select the gradients that we want to use for search/prune
    def snip(layer):
        mask_grad = layer.weight_mask.grad
        if mask_grad is None:
            return torch.zeros_like(layer.weight)
        return torch.abs(mask_grad)

    return get_layer_metric_array(net, snip, mode)


# --- synflow.py ---
@measure("synflow", bn=False, mode="param")
@measure("synflow_bn", bn=True, mode="param")
def compute_synflow_per_weight(
    net, inputs, targets, mode, split_data=1, loss_fn=None
):  # pylint: disable=unused-argument
    """Return ``|weight * weight.grad|`` for every Conv2d/Linear layer.

    The network is cast to double, all state-dict tensors are replaced by
    their absolute values, and the summed output for a single all-ones input
    is back-propagated. Afterwards the original signs are multiplied back in.
    Layers whose weight has no gradient contribute ``torch.zeros_like(weight)``.
    """
    device = inputs.device

    def ones_probe(reference):
        # single all-ones sample with the per-sample shape of `reference`
        sample_shape = list(reference[0, :].shape)
        return torch.ones([1] + sample_shape).double().to(device)

    # Dummy to lazy initialize layer norm
    net.zero_grad()
    net.double()
    inputs = ones_probe(inputs)
    net.forward(inputs)

    # convert params to their abs. Keep sign for converting it back.
    signs = {}
    with torch.no_grad():
        for key, tensor in net.double().state_dict().items():
            signs[key] = torch.sign(tensor)
            tensor.abs_()

    # Compute gradients with input of 1s
    net.zero_grad()
    net.double()
    inputs = ones_probe(inputs)
    output = net.forward(inputs)
    torch.sum(output).backward()

    # select the gradients that we want to use for search/prune
    def synflow(layer):
        grad = layer.weight.grad
        if grad is None:
            return torch.zeros_like(layer.weight)
        return torch.abs(layer.weight * grad)

    grads_abs = get_layer_metric_array(net, synflow, mode)

    # apply signs of all params (convert to orig values)
    with torch.no_grad():
        for key, tensor in net.state_dict().items():
            if "weight_mask" not in key:
                tensor.mul_(signs[key])

    return grads_abs
def network_weight_gaussian_init(net: nn.Module):
    """Re-initialize the weights of ``net`` in place and return ``net``.

    * ``nn.Conv2d`` and ``nn.Linear``: weight drawn with ``nn.init.normal_``
      (default arguments), bias set to zero when the module has one.
    * ``nn.BatchNorm2d`` and ``nn.GroupNorm``: weight set to one and bias set to
      zero; modules whose ``weight`` is ``None`` are skipped.

    All other module types are left untouched.
    """
    with torch.no_grad():
        for module in net.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(module.weight)
                if getattr(module, "bias", None) is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                if module.weight is None:
                    continue
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    return net


@measure("zen", bn=True)
def compute_zen_score(
    net,
    inputs,
    targets,  # pylint: disable=unused-argument
    loss_fn=None,  # pylint: disable=unused-argument
    split_data=1,  # pylint: disable=unused-argument
    repeat=1,
    mixup_gamma=1e-2,
    fp16=False,
):
    """Compute the Zen score of ``net``, averaged over ``repeat`` random trials.

    Each trial re-initializes ``net`` with ``network_weight_gaussian_init``
    (the existing weights are overwritten), draws a random normal input and a
    perturbed copy ``input + mixup_gamma * noise``, and measures the mean over
    the batch of the summed absolute output difference. The trial score is the
    log of that value plus, for every ``nn.BatchNorm2d`` module, the log of
    ``sqrt(mean(running_var))``.

    Args:
        net: Network to score; it is called directly on the random inputs.
        inputs: Batch whose shape and device define the random inputs; its
            values are not used.
        targets: Unused.
        loss_fn: Unused.
        split_data: Unused.
        repeat: Number of random trials to average. Should be at least 1;
            with 0 the mean is taken over an empty list.
        mixup_gamma: Scale of the perturbation added to the random input.
        fp16: If true, the random inputs are created as ``torch.half`` instead
            of ``torch.float32``. The network itself is not converted.

    Returns:
        The mean trial score as a Python ``float``.
    """
    device = inputs.device
    dtype = torch.half if fp16 else torch.float32
    input_shape = list(inputs.shape)

    nas_score_list = []
    with torch.no_grad():
        for _ in range(repeat):
            network_weight_gaussian_init(net)
            base_input = torch.randn(size=input_shape, device=device, dtype=dtype)
            noise = torch.randn(size=input_shape, device=device, dtype=dtype)
            mixup_input = base_input + mixup_gamma * noise

            # An earlier variant used net.forward_before_global_avg_pool here;
            # the plain forward pass may return a non-4-D tensor.
            output = net(base_input)
            mixup_output = net(mixup_input)

            # Sum over every non-batch dimension. For a 4-D output this is
            # exactly dim=[1, 2, 3]; other ranks no longer raise.
            abs_diff = torch.abs(output - mixup_output)
            reduce_dims = list(range(1, len(abs_diff.shape)))
            if reduce_dims:
                abs_diff = torch.sum(abs_diff, dim=reduce_dims)
            nas_score = torch.mean(abs_diff)

            # Compute BN scaling.
            log_bn_scaling_factor = 0.0
            for module in net.modules():
                if isinstance(module, nn.BatchNorm2d):
                    bn_scaling_factor = torch.sqrt(torch.mean(module.running_var))
                    log_bn_scaling_factor += torch.log(bn_scaling_factor)

            nas_score = torch.log(nas_score) + log_bn_scaling_factor
            nas_score_list.append(float(nas_score))

    return float(np.mean(nas_score_list))
5 | authors = [ 6 | "Simon Schrodi ", 7 | "Danny Stoll ", 8 | "Binxin Ru ", 9 | "Rhea Sukthanker ", 10 | "Thomas Brox ", 11 | "Frank Hutter ", 12 | ] 13 | readme = "README.md" 14 | license = "MIT" 15 | packages = [ 16 | { include = "benchmarks" }, 17 | { include = "experiments" }, 18 | ] 19 | 20 | [tool.poetry.dependencies] 21 | python = ">=3.7.1,<3.8" 22 | torch = "^1.11.0+cu102" 23 | neps = { git = "https://github.com/automl/neps.git", branch = "hnas" } 24 | more-itertools = "^8.12.0" 25 | numpy = "^1.21.4" 26 | scipy = "^1.7.3" 27 | pandas = "^1.3.5" 28 | scikit-learn = "0.24.2" 29 | seaborn = "^0.11.2" 30 | matplotlib = "^3.5.1" 31 | timm = "^0.4.12" 32 | networkx = "^2.6.1" 33 | tqdm = "^4.61.2" 34 | nltk = "^3.6.3" 35 | path = "^16.0.0" 36 | torchmetrics = "^0.8.1" 37 | tensorboard = "^2.9.0" 38 | debugpy = "^1.4.1" 39 | 40 | [tool.poetry.dev-dependencies] 41 | jupyter = "^1.0" 42 | pre-commit = "^2.10" 43 | black = "20.8b1" 44 | isort = "^5.7" 45 | pylint = "^2.6" 46 | 47 | [tool.black] 48 | line-length = 90 49 | target-version = ['py37'] 50 | 51 | [tool.isort] 52 | profile = 'black' 53 | line_length = 90 54 | 55 | [tool.pylint.messages_control] # Can use lists now, maybe update in future 56 | disable = """ 57 | all 58 | """ 59 | enable = """ 60 | # ------------------------------------ 61 | # Spelling 62 | # ------------------------------------ 63 | invalid-characters-in-docstring,wrong-spelling-in-comment,wrong-spelling-in-docstring, 64 | # ------------------------------------ 65 | # Basic checks 66 | # ------------------------------------ 67 | not-in-loop,function-redefined,continue-in-finally,abstract-class-instantiated,star-needs-assignment-target, 68 | duplicate-argument-name,return-in-init,too-many-star-expressions,nonlocal-and-global,return-outside-function, 69 | return-arg-in-generator,invalid-star-assignment-target,bad-reversed-sequence,nonexistent-operator, 70 | 
yield-outside-function,init-is-generator,nonlocal-without-binding,lost-exception,assert-on-tuple, 71 | dangerous-default-value,duplicate-key,useless-else-on-loop,expression-not-assigned,confusing-with-statement, 72 | unnecessary-lambda,pointless-statement,unnecessary-pass,unreachable,eval-used,exec-used,using-constant-test, 73 | deprecated-lambda,blacklisted-name,misplaced-comparison-constant,singleton-comparison,unneeded-not, 74 | consider-iterating-dictionary,consider-using-enumerate,empty-docstring,unidiomatic-typecheck, 75 | condition-evals-to-constant, 76 | # ------------------------------------ 77 | # Async 78 | # ------------------------------------ 79 | not-async-context-manager,yield-inside-async-function, 80 | # ------------------------------------ 81 | # Typecheck 82 | # ------------------------------------ 83 | invalid-unary-operand-type,unsupported-binary-operation,not-callable,redundant-keyword-arg,assignment-from-no-return, 84 | assignment-from-none,not-context-manager,repeated-keyword,missing-kwoa,no-value-for-parameter,invalid-sequence-index, 85 | invalid-slice-index,unexpected-keyword-arg,unsupported-membership-test,unsubscriptable-object, 86 | # ------------------------------------ 87 | # Exceptions 88 | # ------------------------------------ 89 | bad-except-order,catching-non-exception,bad-exception-context,notimplemented-raised,raising-bad-type, 90 | raising-non-exception,misplaced-bare-raise,duplicate-except,nonstandard-exception,binary-op-exception, 91 | bare-except, 92 | # ------------------------------------ 93 | # Stdlib 94 | # ------------------------------------ 95 | bad-open-mode,redundant-unittest-assert,boolean-datetime,deprecated-method 96 | # ------------------------------------ 97 | # Imports 98 | # ------------------------------------ 99 | import-error,import-self,reimported,relative-import,deprecated-module,wildcard-import,misplaced-future,cyclic-import, 100 | wrong-import-position,ungrouped-imports,multiple-imports, 101 | # 
------------------------------------ 102 | # Variables 103 | # ------------------------------------ 104 | unpacking-non-sequence,invalid-all-object,unbalanced-tuple-unpacking,undefined-variable,undefined-all-variable, 105 | used-before-assignment,cell-var-from-loop,global-variable-undefined,redefined-builtin,redefine-in-handler, 106 | unused-import,unused-argument,unused-wildcard-import,unused-variable,global-variable-not-assigned, 107 | undefined-loop-variable,global-statement,global-at-module-level, 108 | # ------------------------------------ 109 | # Strings 110 | # ------------------------------------ 111 | format-needs-mapping,truncated-format-string,missing-format-string-key,mixed-format-string,too-few-format-args, 112 | bad-str-strip-call,too-many-format-args,bad-format-character,format-combined-specification,bad-format-string-key, 113 | bad-format-string,missing-format-attribute,missing-format-argument-key,unused-format-string-argument, 114 | unused-format-string-key,invalid-format-index,f-string-without-interpolation 115 | # ------------------------------------ 116 | # String Constant 117 | # ------------------------------------ 118 | anomalous-unicode-escape-in-string,anomalous-backslash-in-string, 119 | # ------------------------------------ 120 | # Elif 121 | # ------------------------------------ 122 | simplifiable-if-statement, 123 | # ------------------------------------ 124 | # Logging 125 | # ------------------------------------ 126 | logging-format-truncated,logging-too-few-args,logging-too-many-args,logging-unsupported-format, 127 | # ------------------------------------ 128 | # Iterable 129 | # ------------------------------------ 130 | not-an-iterable,not-a-mapping, 131 | # ----------------------------------- 132 | # Format 133 | # ----------------------------------- 134 | bad-indentation,unnecessary-semicolon,missing-final-newline,mixed-line-endings,multiple-statements,trailing-newlines, 135 | 
trailing-whitespace,unexpected-line-ending-format,superfluous-parens, 136 | # ------------------------------------ 137 | # Classes 138 | # ------------------------------------ 139 | access-member-before-definition,method-hidden,assigning-non-slot,duplicate-bases,inconsistent-mro,inherit-non-class, 140 | invalid-slots,invalid-slots-object,no-method-argument,no-self-argument,unexpected-special-method-signature, 141 | non-iterator-returned,invalid-length-returned,protected-access,attribute-defined-outside-init,abstract-method, 142 | bad-staticmethod-argument,non-parent-init-called,super-init-not-called,no-classmethod-decorator, 143 | no-staticmethod-decorator,no-self-use,bad-classmethod-argument,bad-mcs-classmethod-argument,bad-mcs-method-argument, 144 | method-check-failed,invalid-bool-returned,invalid-index-returned,invalid-repr-returned,invalid-str-returned, 145 | invalid-bytes-returned,invalid-hash-returned,invalid-length-hint-returned,invalid-format-returned, 146 | invalid-getnewargs-returned,invalid-getnewargs-ex-returned,super-with-arguments 147 | """ 148 | 149 | [build-system] 150 | requires = ["poetry-core>=1.0.0"] 151 | build-backend = "poetry.core.masonry.api" 152 | --------------------------------------------------------------------------------