├── models ├── __init__.py ├── model_cnn.py └── model_rnn.py ├── tests ├── __init__.py ├── test_genetic.py ├── common4testing.py ├── test_plot_augmentation.py ├── test_modules.py └── test_search_space.py ├── gnas ├── common │ ├── __init__.py │ ├── bit_utils.py │ ├── result.py │ └── graph_draw.py ├── search_space │ ├── __init__.py │ ├── mutation.py │ ├── individual.py │ ├── search_space.py │ ├── factory.py │ ├── cross_over.py │ └── operation_space.py ├── genetic_algorithm │ ├── __init__.py │ ├── ga_results.py │ ├── population_dict.py │ └── genetic.py ├── modules │ ├── drop_path.py │ ├── __init__.py │ ├── operation_factory.py │ ├── sub_graph_module.py │ ├── cnn_block.py │ ├── rnn_layer.py │ ├── module_generator.py │ └── node_module.py └── __init__.py ├── modules ├── __init__.py ├── identity.py ├── se_block.py ├── cut_out.py ├── drop_module.py ├── cosine_annealing.py └── weight_drop.py ├── images ├── search_result_cifar10.png └── search_result_cifar100.png ├── .idea └── vcs.xml ├── configs ├── config_cnn_final_cifar10.json ├── config_cnn_final_cifar100.json ├── config_cnn_search_cifar10.json └── config_cnn_search_cifar100.json ├── common.py ├── LICENSE ├── gif_creator.py ├── cnn_utils.py ├── README.md ├── config.py ├── data.py ├── rnn_utils.py ├── plot_result.py └── main.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gnas/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gnas/search_space/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gnas/genetic_algorithm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gnas/modules/drop_path.py: -------------------------------------------------------------------------------- 1 | from modules.drop_module import DropModule, DropModuleControl 2 | -------------------------------------------------------------------------------- /images/search_result_cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihabi/GeneticNAS/HEAD/images/search_result_cifar10.png -------------------------------------------------------------------------------- /images/search_result_cifar100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haihabi/GeneticNAS/HEAD/images/search_result_cifar100.png -------------------------------------------------------------------------------- /gnas/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from gnas.modules.rnn_layer import RnnSearchModule 2 | from gnas.modules.cnn_block import CnnSearchModule 3 | 
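# Usage sketch: a minimal, hypothetical example of driving CnnSearchModule,
# adapted from tests/test_modules.py in this repository; the channel count and
# tensor shapes below are illustrative, not required values.
#
#   import torch
#   import gnas
#   from modules.drop_module import DropModuleControl
#
#   dp_control = DropModuleControl(1)
#   ss = gnas.get_gnas_cnn_search_space(4, dp_control, gnas.SearchSpaceType.CNNSingleCell)
#   cell = gnas.modules.CnnSearchModule(n_channels=64, ss=ss)
#   cell.set_individual(ss.generate_individual())  # pick an architecture to evaluate
#   x = torch.randn(32, 64, 16, 16)                # current feature map
#   x_prev = torch.randn(32, 64, 16, 16)           # bypass input from the previous cell
#   out = cell(x, x_prev)                          # output keeps the input shape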
-------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /modules/identity.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Identity(nn.Module): 5 | def __init__(self): 6 | super(Identity, self).__init__() 7 | 8 | def forward(self, x): 9 | return x 10 | -------------------------------------------------------------------------------- /gnas/common/bit_utils.py: -------------------------------------------------------------------------------- 1 | def vector_bits2int(arr): 2 | n = arr.shape[0] # number of bits in the input vector 3 | a = arr[0] << n - 1 4 | 5 | for j in range(1, n): 6 | # "overlay" with the shifted value of the next bit 7 | a |= arr[j] << n - 1 - j 8 | return a 9 | -------------------------------------------------------------------------------- /gnas/__init__.py: -------------------------------------------------------------------------------- 1 | from gnas.search_space.factory import get_gnas_cnn_search_space, get_gnas_rnn_search_space, SearchSpaceType 2 | from gnas.genetic_algorithm.genetic import genetic_algorithm_searcher 3 | from gnas.common.result import ResultAppender 4 | from gnas import modules 5 | from gnas.common.graph_draw import draw_network 6 | -------------------------------------------------------------------------------- /gnas/genetic_algorithm/ga_results.py: -------------------------------------------------------------------------------- 1 | class GenetricResult(object): 2 | def __init__(self): 3 | self.population_list = [] 4 | self.fitness_list = [] 5 | self.fitness_full_list = [] 6 | self.population_full_list = [] 7 | 8 | def add_generation_result(self, fitness, population): 9 | self.fitness_list.append(fitness) 10 | self.population_list.append(population) 11 | 12 | def add_population_result(self, fitness, population): 13 | self.fitness_full_list.append(fitness) 14 | self.population_full_list.append(population) 15 | -------------------------------------------------------------------------------- /gnas/modules/operation_factory.py: -------------------------------------------------------------------------------- 1 | from gnas.search_space.operation_space import RnnNodeConfig, RnnInputNodeConfig, CnnNodeConfig 2 | from gnas.modules.node_module import RnnNodeModule, RnnInputNodeModule, ConvNodeModule 3 | 4 | __module_dict__ = {RnnNodeConfig: RnnNodeModule, 5 | RnnInputNodeConfig: RnnInputNodeModule, 6 | CnnNodeConfig: ConvNodeModule} 7 | 8 | 9 | def get_module(node_config, config_dict): 10 | m = __module_dict__.get(type(node_config)) 11 | if m is None: 12 | raise Exception('Can\'t find module named: ' + str(node_config)) 13 | return m(node_config, config_dict) 14 | -------------------------------------------------------------------------------- /modules/se_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SEBlock(nn.Module): 6 | def __init__(self, n_channels, rate): 7 | super(SEBlock, self).__init__() 8 | self.se_block = nn.Sequential(nn.Linear(n_channels, int(n_channels / rate)), 9 | nn.ReLU(), 10 | nn.Linear(int(n_channels / rate), n_channels), 11 | nn.Sigmoid()) 12 | 13 | def forward(self, x): 14 | x_gp = torch.mean(torch.mean(x, dim=-1), dim=-1) 15 | att =
self.se_block(x_gp).unsqueeze(dim=-1).unsqueeze(dim=-1) 16 | return x * att 17 | -------------------------------------------------------------------------------- /configs/config_cnn_final_cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "batch_size_val": 1000, 4 | "n_epochs": 630, 5 | "n_blocks": 4, 6 | "n_block_type": 3, 7 | "n_nodes": 5, 8 | "n_channels": 20 , 9 | "generation_size": 20, 10 | "generation_per_epoch": 2, 11 | "full_dataset": false, 12 | "population_size": 20, 13 | "keep_size": 0, 14 | "mutation_p": 0.02, 15 | "p_cross_over": 1.0, 16 | "cross_over_type": "Block", 17 | "learning_rate": 0.05, 18 | "lr_min": 0.0001, 19 | "weight_decay": 0.0001, 20 | "dropout": 0.2, 21 | "drop_path_keep_prob": 0.9, 22 | "drop_path_start_epoch": 0, 23 | "cutout": true, 24 | "n_holes": 1, 25 | "length": 16, 26 | "LRType": "MultiStepLR", 27 | "num_class": 10, 28 | "momentum": 0.9, 29 | "aux_loss": false, 30 | "aux_scale": 0.4 31 | } -------------------------------------------------------------------------------- /configs/config_cnn_final_cifar100.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 50, 3 | "batch_size_val": 1000, 4 | "n_epochs": 630, 5 | "n_blocks": 4, 6 | "n_block_type": 3, 7 | "n_nodes": 5, 8 | "n_channels": 48 , 9 | "generation_size": 20, 10 | "generation_per_epoch": 2, 11 | "full_dataset": false, 12 | "population_size": 20, 13 | "keep_size": 0, 14 | "mutation_p": 0.01, 15 | "p_cross_over": 1.0, 16 | "cross_over_type": "Block", 17 | "learning_rate": 0.05, 18 | "lr_min": 0.0001, 19 | "weight_decay": 0.0001, 20 | "dropout": 0.2, 21 | "drop_path_keep_prob": 0.9, 22 | "drop_path_start_epoch": 0, 23 | "cutout": true, 24 | "n_holes": 1, 25 | "length": 16, 26 | "LRType": "MultiStepLR", 27 | "num_class": 10, 28 | "momentum": 0.9, 29 | "aux_loss": false, 30 | "aux_scale": 0.4 31 | } -------------------------------------------------------------------------------- /configs/config_cnn_search_cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 128, 3 | "batch_size_val": 1000, 4 | "n_epochs": 310, 5 | "n_blocks": 2, 6 | "n_block_type": 3, 7 | "n_nodes": 5, 8 | "n_channels": 20 , 9 | "generation_size": 20, 10 | "generation_per_epoch": 2, 11 | "full_dataset": false, 12 | "population_size": 20, 13 | "keep_size": 0, 14 | "mutation_p": 0.02, 15 | "p_cross_over": 1.0, 16 | "cross_over_type": "Block", 17 | "learning_rate": 0.1, 18 | "lr_min": 0.0001, 19 | "weight_decay": 0.0001, 20 | "dropout": 0.2, 21 | "drop_path_keep_prob": 1.0, 22 | "drop_path_start_epoch": 50, 23 | "cutout": true, 24 | "n_holes": 1, 25 | "length": 16, 26 | "LRType": "MultiStepLR", 27 | "num_class": 10, 28 | "momentum": 0.9, 29 | "aux_loss": false, 30 | "aux_scale": 0.4 31 | } -------------------------------------------------------------------------------- /configs/config_cnn_search_cifar100.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "batch_size_val": 1000, 4 | "n_epochs": 310, 5 | "n_blocks": 2, 6 | "n_block_type": 3, 7 | "n_nodes": 5, 8 | "n_channels": 48 , 9 | "generation_size": 20, 10 | "generation_per_epoch": 2, 11 | "full_dataset": false, 12 | "population_size": 20, 13 | "keep_size": 0, 14 | "mutation_p": 0.01, 15 | "p_cross_over": 1.0, 16 | "cross_over_type": "Block", 17 | "learning_rate": 0.05, 18 | "lr_min": 0.0001, 19 | "weight_decay": 0.0001, 
20 | "dropout": 0.2, 21 | "drop_path_keep_prob": 1.0, 22 | "drop_path_start_epoch": 50, 23 | "cutout": true, 24 | "n_holes": 1, 25 | "length": 16, 26 | "LRType": "MultiStepLR", 27 | "num_class": 10, 28 | "momentum": 0.9, 29 | "aux_loss": false, 30 | "aux_scale": 0.4 31 | } -------------------------------------------------------------------------------- /tests/test_genetic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import unittest 3 | import gnas 4 | 5 | 6 | class TestGenetic(unittest.TestCase): 7 | 8 | def test_search_cnn_space(self): 9 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNSingleCell) 10 | ga = gnas.genetic_algorithm_searcher(ss, population_size=20, generation_size=20, p_cross_over=0.8) 11 | for i in range(30): 12 | for ind in ga.get_current_generation(): 13 | ga.sample_child() 14 | ga.update_current_individual_fitness(ind, np.random.rand(1)) 15 | ga.update_population() 16 | self.assertTrue(len(ga.max_dict) <= 200) 17 | self.assertTrue(len(ga.generation)) 18 | 19 | 20 | if __name__ == '__main__': 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import datetime 4 | from enum import Enum 5 | 6 | 7 | class ModelType(Enum): 8 | CNN = 0 9 | RNN = 1 10 | 11 | 12 | def make_log_dir(config): 13 | log_dir = os.path.join('.', 'logs', datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")) 14 | os.makedirs(log_dir, exist_ok=True) 15 | return log_dir 16 | 17 | 18 | def load_final(model, search_dir): 19 | ind_file = os.path.join(search_dir, 'best_individual.pickle') 20 | ind = pickle.load(open(ind_file, "rb")) 21 | model.set_individual(ind) 22 | return ind 23 | 24 | 25 | def get_model_type(dataset_name): 26 | if dataset_name in ['CIFAR10', 'CIFAR100']: 27 | return ModelType.CNN 28 | elif dataset_name == 'PTB': 29 | return ModelType.RNN 30 | else: 31 | raise Exception('unknown model for dataset: ' + dataset_name) 32 | -------------------------------------------------------------------------------- /gnas/common/result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | class ResultAppender(object): 6 | def __init__(self): 7 | self.result_dict = dict() 8 | 9 | def add_epoch_result(self, result_name: str, result_var: float): 10 | if self.result_dict.get(result_name) is None: 11 | self.result_dict.update({result_name: [result_var]}) 12 | else: 13 | self.result_dict.get(result_name).append(result_var) 14 | 15 | def add_result(self, result_name: str, result_array): 16 | self.result_dict.update({result_name: result_array}) 17 | 18 | def save_result(self, input_path): 19 | pickle.dump(self, open(os.path.join(input_path, 'ga_result.pickle'), "wb")) 20 | 21 | @staticmethod 22 | def load_result(input_path): 23 | return pickle.load(open(os.path.join(input_path, 'ga_result.pickle'), "rb")) 24 | -------------------------------------------------------------------------------- /tests/common4testing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gnas.search_space.operation_space import RnnInputNodeConfig, RnnNodeConfig, CnnNodeConfig 3 | from gnas.search_space.search_space import SearchSpace 4 | from modules.drop_module import DropModuleControl 5 | import gnas 6 | 7 | 8 | def
generate_ss(): 9 | nll = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid'] 10 | node_config_list = [RnnInputNodeConfig(2, [0, 1], nll)] 11 | for i in range(12): 12 | node_config_list.append(RnnNodeConfig(3 + i, list(np.linspace(2, 2 + i, 1 + i).astype('int')), nll)) 13 | ss = SearchSpace(node_config_list) 14 | return ss 15 | 16 | 17 | def generate_ss_cnn(): 18 | nll = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid'] 19 | op = ['Conv3x3', 'Dw3x3', 'Conv5x5', 'Dw5x5'] 20 | dp_control = DropModuleControl(1) 21 | node_config_list = [CnnNodeConfig(2, [0, 1], op, dp_control)] 22 | for i in range(3): 23 | node_config_list.append(CnnNodeConfig(3 + i, list(np.linspace(0, 2 + i, 3 + i).astype('int')), op, dp_control)) 24 | ss = SearchSpace(node_config_list) 25 | return ss 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 HVH 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /gnas/search_space/mutation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gnas.search_space.individual import Individual, MultipleBlockIndividual 3 | 4 | 5 | def flip_max_value(current_value, max_value, p): 6 | flip = np.floor(np.random.rand(current_value.shape[0]) + p).astype('int') 7 | sign = (2 * (np.round(np.random.rand(current_value.shape[0])) - 0.5)).astype('int') 8 | new_dna = current_value + flip * sign 9 | new_dna[new_dna > max_value] = 0 10 | new_dna[new_dna < 0] = max_value[new_dna < 0] 11 | return new_dna 12 | 13 | 14 | def _individual_flip_mutation(individual_a, p) -> Individual: 15 | max_values = individual_a.ss.get_max_values_vector(index=individual_a.index) 16 | new_iv = [] 17 | for m, iv in zip(max_values, individual_a.iv): 18 | new_iv.append(flip_max_value(iv, m, p)) 19 | return individual_a.update_individual(new_iv) 20 | 21 | 22 | def individual_flip_mutation(individual_a, p): 23 | if isinstance(individual_a, Individual): 24 | return _individual_flip_mutation(individual_a, p) 25 | else: 26 | return MultipleBlockIndividual([_individual_flip_mutation(inv, p) for inv in individual_a.individual_list]) 27 | -------------------------------------------------------------------------------- /gif_creator.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import glob 3 | import os 4 | from PIL import Image 5 | from PIL import ImageFont 6 | from PIL import ImageDraw 7 | import numpy as np 8 | 9 | images = dict() 10 | epoch_dict = dict() 11 | log_dir = '/data/projects/GNAS/logs/2019_02_17_20_25_42' 12 | 13 | font = ImageFont.truetype("/home/haih/Downloads/Untitled Folder/Microsoft Sans Serif.ttf", 16) 14 | for filename in glob.glob(log_dir + "/*.png"): 15 | layer_index = int(str.split(filename, '_')[-1].split('.png')[0]) 16 | epoch = int(str.split(filename, '_')[-2]) 17 | img = imageio.imread(filename) 18 | img_pil = Image.fromarray(img) 19 | draw = ImageDraw.Draw(img_pil) 20 | 21 | draw.text((0, 0), 'Epoch:' + str(epoch), (0, 0, 0), font=font) 22 | img = np.array(img_pil.getdata()).reshape(img_pil.size[1], img_pil.size[0], 4) 23 | if images.get(layer_index) is None: 24 | images.update({layer_index: [img]}) 25 | epoch_dict.update({layer_index: [epoch]}) 26 | else: 27 | images.get(layer_index).append(img) 28 | epoch_dict.get(layer_index).append(epoch) 29 | 30 | # Reorder 31 | for k, v in epoch_dict.items(): 32 | index = np.argsort(np.asarray(v)).astype('int') 33 | images.update({k: [np.asarray(images.get(k))[i] for i in index]}) 34 | 35 | for k, v in images.items(): 36 | imageio.mimsave(os.path.join(log_dir, 'cell_movie_' + str(k) + '.gif'), v, duration=0.5) 37 | -------------------------------------------------------------------------------- /modules/cut_out.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | # copy from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 6 | class Cutout(object): 7 | """Randomly mask out one or more patches from an image. 8 | Args: 9 | n_holes (int): Number of patches to cut out of each image. 10 | length (int): The length (in pixels) of each square patch. 
11 | """ 12 | 13 | def __init__(self, n_holes, length): 14 | self.n_holes = n_holes 15 | self.length = length 16 | 17 | def __call__(self, img): 18 | """ 19 | Args: 20 | img (Tensor): Tensor image of size (C, H, W). 21 | Returns: 22 | Tensor: Image with n_holes of dimension length x length cut out of it. 23 | """ 24 | h = img.size(1) 25 | w = img.size(2) 26 | 27 | mask = np.ones((h, w), np.float32) 28 | 29 | for n in range(self.n_holes): 30 | y = np.random.randint(h) 31 | x = np.random.randint(w) 32 | 33 | y1 = np.clip(y - self.length // 2, 0, h) 34 | y2 = np.clip(y + self.length // 2, 0, h) 35 | x1 = np.clip(x - self.length // 2, 0, w) 36 | x2 = np.clip(x + self.length // 2, 0, w) 37 | 38 | mask[y1: y2, x1: x2] = 0. 39 | 40 | mask = torch.from_numpy(mask) 41 | mask = mask.expand_as(img) 42 | img = img * mask 43 | 44 | return img 45 | -------------------------------------------------------------------------------- /modules/drop_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.cuda 3 | import torch.nn as nn 4 | from random import random 5 | from torch.autograd import Variable 6 | 7 | 8 | class DropModuleControl(object): 9 | def __init__(self, drop_prob=0.9): 10 | self.drop_prob = drop_prob 11 | self.status = False 12 | 13 | def enable(self): 14 | self.status = True 15 | 16 | 17 | class DropModule(nn.Module): 18 | def __init__(self, module, drop_control: DropModuleControl): 19 | super(DropModule, self).__init__() 20 | self.module = module 21 | self.shape = None 22 | self.drop_control = drop_control 23 | self.tensor_init = torch.FloatTensor 24 | 25 | def update_tensor_shape(self, *input): 26 | if self.shape is None: 27 | output_tensor = self.module(*input) 28 | self.shape = output_tensor.size() # fetch tensor shape 29 | if output_tensor.data.is_cuda: self.tensor_init = torch.cuda.FloatTensor 30 | 31 | def forward(self, *input): 32 | self.update_tensor_shape(*input) 33 | if self.training and self.drop_control.status: 34 | if random() <= self.drop_control.drop_prob: # forward module tensor 35 | return self.module(*input) / self.drop_control.drop_prob # Apply scaling 36 | else: # forward zero tensor 37 | return Variable(self.tensor_init(torch.Size([input[0].shape[0], *list(self.shape[1:])])).zero_()) 38 | else: # Inference 39 | return self.module(*input) 40 | -------------------------------------------------------------------------------- /cnn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.cuda 3 | 4 | 5 | def evaluate_single(input_individual, input_model, data_loader, device): 6 | correct = 0 7 | total = 0 8 | input_model = input_model.eval() 9 | input_model.set_individual(input_individual) 10 | with torch.no_grad(): 11 | for data in data_loader: 12 | images, labels = data 13 | images = images.to(device) 14 | labels = labels.to(device) 15 | outputs = input_model(images) 16 | _, predicted = torch.max(outputs[0].data, 1) 17 | total += labels.size(0) 18 | correct += (predicted == labels).sum().item() 19 | return 100 * correct / total 20 | 21 | 22 | def evaluate_individual_list(input_individual_list, ga, input_model, data_loader, device): 23 | correct = 0 24 | total = 0 25 | input_model = input_model.eval() 26 | i = 0 27 | with torch.no_grad(): 28 | while len(input_individual_list) > i: 29 | for data in data_loader: 30 | if len(input_individual_list) <= i: 31 | pass 32 | else: 33 | ind = input_individual_list[i] 34 | 
input_model.set_individual(ind) 35 | images, labels = data 36 | images = images.to(device) 37 | labels = labels.to(device) 38 | outputs = input_model(images) 39 | _, predicted = torch.max(outputs[0].data, 1) 40 | total += labels.size(0) 41 | correct += (predicted == labels).sum().item() 42 | acc = 100 * correct / total 43 | ga.update_current_individual_fitness(ind, acc) 44 | i += 1 45 | -------------------------------------------------------------------------------- /tests/test_plot_augmentation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from modules.cut_out import Cutout 6 | 7 | from matplotlib import pyplot as plt 8 | import numpy as np 9 | 10 | 11 | class MyTestCase(unittest.TestCase): 12 | def test_something(self): 13 | train_transform = transforms.Compose([]) 14 | # normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 15 | # std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 16 | train_transform.transforms.append(transforms.RandomCrop(32, padding=4)) 17 | train_transform.transforms.append(transforms.RandomHorizontalFlip()) 18 | train_transform.transforms.append(transforms.ToTensor()) 19 | # train_transform.transforms.append(normalize) 20 | train_transform.transforms.append(Cutout(n_holes=1, length=16)) 21 | 22 | trainset = torchvision.datasets.CIFAR10(root='./dataset', train=True, 23 | download=True, transform=train_transform) 24 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, 25 | shuffle=True, num_workers=4) 26 | ds_train, l = next(iter(trainloader)) 27 | 28 | 29 | fig = plt.figure(figsize=(8, 8)); 30 | columns = 4 31 | rows = 2 32 | for i in range(1, columns * rows + 1): 33 | img_xy = np.random.randint(len(ds_train)); 34 | img = ds_train[img_xy][:, :, :].numpy() 35 | img=np.transpose(img,[1,2,0]) 36 | fig.add_subplot(rows, columns, i) 37 | # plt.title(labels_map[int(ds_train[img_xy][1].numpy())]) 38 | plt.axis('off') 39 | plt.imshow(img) 40 | plt.show() 41 | 42 | 43 | if __name__ == '__main__': 44 | unittest.main() 45 | -------------------------------------------------------------------------------- /gnas/search_space/individual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Individual(object): 5 | def __init__(self, individual_vector, max_inputs, search_space, index=0): 6 | self.iv = individual_vector 7 | self.mi = max_inputs 8 | self.ss = search_space 9 | # Generate config when generating individual 10 | self.index = index 11 | self.config_list = [oc.parse_config(iv) for iv, oc in zip(self.iv, self.ss.get_opeartion_config(self.index))] 12 | self.code = np.concatenate(self.iv, axis=0) 13 | 14 | def get_length(self): 15 | return len(self.code) 16 | 17 | def get_n_op(self): 18 | return len(self.iv) 19 | 20 | def copy(self): 21 | return Individual(self.iv, self.mi, self.ss, index=self.index) 22 | 23 | def generate_node_config(self): 24 | return self.config_list 25 | 26 | def update_individual(self, individual_vector): 27 | return Individual(individual_vector, self.mi, self.ss, index=self.index) 28 | 29 | def __eq__(self, other): 30 | return np.array_equal(self.code, other.code) 31 | 32 | def __str__(self): 33 | return "code:" + str(self.code) 34 | 35 | def __hash__(self): 36 | return hash(str(self)) 37 | 38 | 39 | class MultipleBlockIndividual(object): 40 | def __init__(self, individual_list): 41 | 
self.individual_list = individual_list 42 | self.code = np.concatenate([i.code for i in self.individual_list]) 43 | 44 | def get_individual(self, index): 45 | return self.individual_list[index] 46 | 47 | def generate_node_config(self, index): 48 | return self.individual_list[index].generate_node_config() 49 | 50 | def update_individual(self, individual_vector): 51 | raise NotImplementedError 52 | 53 | def __eq__(self, other): 54 | return np.array_equal(self.code, other.code) 55 | 56 | def __str__(self): 57 | return "code:" + str(self.code) 58 | 59 | def __hash__(self): 60 | return hash(str(self)) 61 | -------------------------------------------------------------------------------- /modules/cosine_annealing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.optim as optim 3 | 4 | 5 | # copied from: https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py 6 | # The modification list: 7 | # 1. add T_mul to the init function 8 | # 2. change get_lr to multiply T_max by T_mul at each restart. 9 | 10 | class CosineAnnealingLR(optim.lr_scheduler._LRScheduler): 11 | r"""Set the learning rate of each parameter group using a cosine annealing 12 | schedule, where :math:`\eta_{max}` is set to the initial lr and 13 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 14 | 15 | .. math:: 16 | 17 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 18 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 19 | 20 | When last_epoch=-1, sets initial lr as lr. 21 | 22 | It has been proposed in 23 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only 24 | implements the cosine annealing part of SGDR, and not the restarts. 25 | 26 | Args: 27 | optimizer (Optimizer): Wrapped optimizer. 28 | T_max (int): Maximum number of iterations. 29 | T_mul (int): Factor by which T_max is multiplied at each restart. 30 | eta_min (float): Minimum learning rate. Default: 0. 31 | last_epoch (int): The index of last epoch. Default: -1. 32 | 33 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 34 | https://arxiv.org/abs/1608.03983 35 | """ 36 | 37 | def __init__(self, optimizer, T_max, T_mul, eta_min=0, last_epoch=-1): 38 | self.T_max = T_max 39 | self.T_mul = T_mul 40 | self.eta_min = eta_min 41 | super(CosineAnnealingLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | lr = [self.eta_min + (base_lr - self.eta_min) * 45 | (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 46 | for base_lr in self.base_lrs] 47 | if self.last_epoch != 0 and self.last_epoch % self.T_max == 0: 48 | self.T_max = self.T_mul * self.T_max 49 | self.last_epoch = 0 50 | return lr 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Genetic Neural Architecture Search (GeneticNAS) 2 | The genetic neural architecture search (GeneticNAS) is a neural architecture search method based on a genetic algorithm that utilizes weight sharing across all candidate networks.
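A minimal sketch of the search loop, adapted from `tests/test_genetic.py` in this repository (the random value here is a stand-in for a real validation-accuracy fitness):

```python
import numpy as np
import gnas

# Build a single-cell CNN search space with 5 nodes and wrap it in the genetic searcher.
ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNSingleCell)
ga = gnas.genetic_algorithm_searcher(ss, population_size=20, generation_size=20, p_cross_over=0.8)

for generation in range(30):
    for ind in ga.get_current_generation():
        ga.sample_child()
        # Replace the random value with the measured fitness of `ind`.
        ga.update_current_individual_fitness(ind, np.random.rand(1))
    ga.update_population()
```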
The project paper: https://arxiv.org/abs/1907.02871 3 | 4 | Includes code for CIFAR-10 and CIFAR-100 image classification. 5 | 6 | # Installation 7 | The first step is to install all of the following prerequisites using conda: 8 | * pytorch 9 | * graphviz 10 | * pygraphviz 11 | * numpy 12 | 13 | ```bash 14 | conda install graphviz 15 | conda install pytorch torchvision cudatoolkit=9.0 -c pytorch 16 | conda install pygraphviz 17 | conda install numpy 18 | ``` 19 | 20 | # Examples: Run Search 21 | This section provides examples of how to run an architecture search on the CIFAR10 and CIFAR100 datasets; at the end of the search, a log folder is created under the current folder. 22 | #### CIFAR 10 23 | ```bash 24 | python main.py --dataset_name CIFAR10 --config_file ./configs/config_cnn_search_cifar10.json 25 | ``` 26 | #### CIFAR 100 27 | ```bash 28 | python main.py --dataset_name CIFAR100 --config_file ./configs/config_cnn_search_cifar100.json 29 | ``` 30 | 31 | # Examples: Run Final Training 32 | This section provides examples of how to run the final training on the CIFAR10 and CIFAR100 datasets, where $LOG_DIR is the log folder of the search result. 33 | #### CIFAR 10 34 | ```bash 35 | python main.py --dataset_name CIFAR10 --final 1 --serach_dir $LOG_DIR --config_file ./configs/config_cnn_final_cifar10.json 36 | ``` 37 | #### CIFAR 100 38 | ```bash 39 | python main.py --dataset_name CIFAR100 --final 1 --serach_dir $LOG_DIR --config_file ./configs/config_cnn_final_cifar100.json 40 | ``` 41 | 42 | # Results 43 | 44 | ## CIFAR10 Convolution Cell 45 | ![Screenshot](images/search_result_cifar10.png) 46 | 47 | 48 | ## CIFAR100 Convolution Cell 49 | ![Screenshot](images/search_result_cifar100.png) 50 | 51 | ## Convolution cell final results 52 | | Dataset | Accuracy[%] | 53 | | --- | --- | 54 | | CIFAR10 | 96% | 55 | | CIFAR100 | 80.1% | 56 | -------------------------------------------------------------------------------- /gnas/modules/sub_graph_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from gnas.search_space.individual import Individual 5 | from gnas.modules.operation_factory import get_module 6 | 7 | 8 | class SubGraphModule(nn.Module): 9 | def __init__(self, search_space, config_dict, individual_index=0): 10 | super(SubGraphModule, self).__init__() 11 | self.ss = search_space 12 | self.config_dict = config_dict 13 | self.individual_index = individual_index 14 | if self.ss.single_block: 15 | self.block_modules = [get_module(oc, config_dict) for oc in self.ss.ocl] 16 | else: 17 | self.block_modules = [get_module(oc, config_dict) for oc in 18 | self.ss.ocl[individual_index]] 19 | # 20 | [self.add_module('Node' + str(i), n) for i, n in enumerate(self.block_modules)] 21 | 22 | def forward(self, *input_list): 23 | # input list at start is h_n and h_(n-1) 24 | net = list(input_list) 25 | for nm in self.block_modules: # loop over all blocks 26 | net.append(nm(net)) # call each block in the sub graph 27 | return net # output list of all blocks in the sub graph 28 | 29 | def set_individual(self, individual: Individual): 30 | if not self.ss.single_block: 31 | individual = individual.get_individual(self.individual_index) 32 | si_list = [] 33 | for nc, nm in zip(individual.generate_node_config(), self.block_modules): 34 | nm.set_current_node_config(nc) 35 | if nm.__dict__.get('select_index') is not None: 36 | si_list.append(nm.select_index) 37 | 38 | current_node_list =
np.unique(si_list) 39 | if self.ss.single_block: 40 | self.avg_index = np.asarray([n.node_id for n in self.ss.ocl if n.node_id not in current_node_list]).astype( 41 | 'int') 42 | else: 43 | self.avg_index = np.asarray( 44 | [n.node_id for n in self.ss.ocl[self.individual_index] if n.node_id not in current_node_list]).astype( 45 | 'int') 46 | -------------------------------------------------------------------------------- /modules/weight_drop.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.nn import Parameter 4 | 5 | # This module is copy from https://github.com/salesforce/awd-lstm-lm/blob/master/weight_drop.py 6 | class WeightDrop(torch.nn.Module): 7 | def __init__(self, module, weights, dropout=0, variational=False): 8 | super(WeightDrop, self).__init__() 9 | self.module = module 10 | self.weights = weights 11 | self.dropout = dropout 12 | self.variational = variational 13 | self._setup() 14 | 15 | def widget_demagnetizer_y2k_edition(*args, **kwargs): 16 | # We need to replace flatten_parameters with a nothing function 17 | # It must be a function rather than a lambda as otherwise pickling explodes 18 | # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! 19 | # (╯°□°)╯︵ ┻━┻ 20 | return 21 | 22 | def _setup(self): 23 | # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN 24 | if issubclass(type(self.module), torch.nn.RNNBase): 25 | self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition 26 | 27 | for name_w in self.weights: 28 | print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) 29 | w = getattr(self.module, name_w) 30 | del self.module._parameters[name_w] 31 | self.module.register_parameter(name_w + '_raw', Parameter(w.data)) 32 | 33 | def _setweights(self): 34 | for name_w in self.weights: 35 | raw_w = getattr(self.module, name_w + '_raw') 36 | w = None 37 | if self.variational: 38 | mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) 39 | if raw_w.is_cuda: mask = mask.cuda() 40 | mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True) 41 | w = mask.expand_as(raw_w) * raw_w 42 | else: 43 | w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training) 44 | setattr(self.module, name_w, w) 45 | 46 | def forward(self, *args): 47 | self._setweights() 48 | return self.module.forward(*args) 49 | -------------------------------------------------------------------------------- /gnas/search_space/search_space.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gnas.search_space.individual import Individual, MultipleBlockIndividual 3 | 4 | 5 | class SearchSpace(object): 6 | def __init__(self, operation_config_list: list, single_block=True): 7 | self.single_block = single_block 8 | self.ocl = operation_config_list 9 | if single_block: 10 | self.n_elements = sum([len(self.generate_vector(o.max_values_vector(i))) for i, o in enumerate(self.ocl)]) 11 | else: 12 | self.n_elements = sum( 13 | [sum([len(self.generate_vector(o.max_values_vector(i))) for i, o in enumerate(block)]) for block in 14 | self.ocl]) 15 | 16 | def get_operation_configs(self): 17 | return self.ocl 18 | 19 | def get_n_nodes(self): 20 | if self.single_block: 21 | return len(self.ocl) 22 | else: 23 | return [len(ocl) for ocl in self.ocl] 24 | 25 | def get_max_values_vector(self, index=0): 26 | if self.single_block: 27 | return [o.max_values_vector(i) for i, o in 
enumerate(self.ocl)] 28 | else: 29 | return [o.max_values_vector(i) for i, o in enumerate(self.ocl[index])] 30 | 31 | def get_opeartion_config(self, index=0): 32 | if self.single_block: 33 | return self.ocl 34 | else: 35 | return self.ocl[index] 36 | 37 | def generate_vector(self, max_values): 38 | return np.asarray([np.random.randint(0, mv + 1) for mv in max_values]) 39 | 40 | def _generate_individual_single(self, ocl, index=0): 41 | operation_vector = [self.generate_vector(o.max_values_vector(i)) for i, o in enumerate(ocl)] 42 | max_inputs = [i for i, _ in enumerate(ocl)] 43 | return Individual(operation_vector, max_inputs, self, index=index) 44 | 45 | def generate_individual(self): 46 | if self.single_block: 47 | return self._generate_individual_single(self.ocl) 48 | else: 49 | return MultipleBlockIndividual( 50 | [self._generate_individual_single(ocl, index=i) for i, ocl in enumerate(self.ocl)]) 51 | 52 | def generate_population(self, size): 53 | return [self.generate_individual() for _ in range(size)] 54 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from common import ModelType 4 | 5 | 6 | def save_config(path_dir, config): 7 | with open(os.path.join(path_dir, 'config.json'), 'w') as outfile: 8 | json.dump(config, outfile) 9 | 10 | 11 | def load_config(path_dir): 12 | with open(path_dir, 'r') as json_file: 13 | data = json.load(json_file) 14 | return data 15 | 16 | 17 | def get_config(model_type): 18 | if ModelType.CNN == model_type: 19 | return default_config_cnn() 20 | elif ModelType.RNN == model_type: 21 | return default_config_rnn() 22 | else: 23 | raise Exception('unknown model type: ' + str(model_type)) 24 | 25 | 26 | def default_config_rnn(): 27 | return {'batch_size': 64, 28 | 'batch_size_val': 10, 29 | 'bptt': 35, 30 | 'n_epochs': 310, 31 | 'n_blocks': 2, 32 | 'n_nodes': 12, 33 | 'n_channels': 200, 34 | 'clip': 0.25, 35 | 'generation_size': 20, 36 | 'population_size': 20, 37 | 'keep_size': 0, 38 | 'mutation_p': 0.02, 39 | 'p_cross_over': 1.0, 40 | 'cross_over_type': 'Block', 41 | 'learning_rate': 20.0, 42 | 'weight_decay': 0.0001, 43 | 'dropout': 0.2, 44 | 'LRType': 'ExponentialLR', 45 | 'gamma': 0.96} 46 | 47 | 48 | def default_config_cnn(): 49 | return {'batch_size': 128, 50 | 'batch_size_val': 1000, 51 | 'n_epochs': 310, 52 | 'n_blocks': 2, 53 | 'n_block_type': 3, 54 | 'n_nodes': 5, 55 | 'n_channels': 20, 56 | 'generation_size': 20, 57 | 'generation_per_epoch': 2, 58 | 'full_dataset': False, 59 | 'population_size': 20, 60 | 'keep_size': 0, 61 | 'mutation_p': 0.02, 62 | 'p_cross_over': 1.0, 63 | 'cross_over_type': 'Block', 64 | 'learning_rate': 0.1, 65 | 'lr_min': 0.0001, 66 | 'weight_decay': 0.0001, 67 | 'dropout': 0.2, 68 | 'drop_path_keep_prob': 1.0, 69 | 'drop_path_start_epoch': 50, 70 | 'cutout': True, 71 | 'n_holes': 1, 72 | 'length': 16, 73 | 'LRType': 'MultiStepLR', 74 | 'num_class': 10, 75 | 'momentum': 0.9, 76 | 'aux_loss': False, 77 | 'aux_scale': 0.4} 78 | -------------------------------------------------------------------------------- /gnas/search_space/factory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gnas.search_space.search_space import SearchSpace 3 | from gnas.search_space.operation_space import CnnNodeConfig, RnnNodeConfig, RnnInputNodeConfig 4 | from enum import Enum 5 | 6 | CNN_OP = ['Dw3x3', 'Identity', 'Dw5x5',
'Avg3x3', 'Max3x3'] 7 | RNN_OP = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid'] 8 | 9 | 10 | class SearchSpaceType(Enum): 11 | CNNSingleCell = 0 12 | CNNDualCell = 1 13 | CNNTripleCell = 2 14 | 15 | 16 | def _two_input_cell(n_nodes, drop_path_control): 17 | node_config_list = [CnnNodeConfig(2, [0, 1], CNN_OP, drop_path_control=drop_path_control)] 18 | for i in range(n_nodes - 1): 19 | node_config_list.append( 20 | CnnNodeConfig(3 + i, list(np.linspace(0, 2 + i, 3 + i).astype('int')), CNN_OP, 21 | drop_path_control=drop_path_control)) 22 | return node_config_list 23 | 24 | 25 | def _one_input_cell(n_nodes, drop_path_control): 26 | node_config_list = [CnnNodeConfig(1, [0], CNN_OP, drop_path_control=drop_path_control)] 27 | for i in range(n_nodes - 1): 28 | node_config_list.append( 29 | CnnNodeConfig(2 + i, list(np.linspace(0, 1 + i, 2 + i).astype('int')), CNN_OP, 30 | drop_path_control=drop_path_control)) 31 | return node_config_list 32 | 33 | 34 | def get_gnas_cnn_search_space(n_nodes, drop_path_control, n_cell_type: SearchSpaceType) -> SearchSpace: 35 | node_config_list_a = _two_input_cell(n_nodes, drop_path_control) 36 | if n_cell_type == SearchSpaceType.CNNSingleCell: 37 | return SearchSpace(node_config_list_a) 38 | elif n_cell_type == SearchSpaceType.CNNDualCell: 39 | node_config_list_b = _two_input_cell(n_nodes, drop_path_control) 40 | return SearchSpace([node_config_list_a, node_config_list_b], single_block=False) 41 | elif n_cell_type == SearchSpaceType.CNNTripleCell: 42 | node_config_list_b = _two_input_cell(n_nodes, drop_path_control) 43 | node_config_list_c = _one_input_cell(n_nodes, drop_path_control) 44 | return SearchSpace([node_config_list_a, node_config_list_b, node_config_list_c], single_block=False) 45 | 46 | 47 | def get_gnas_rnn_search_space(n_nodes) -> SearchSpace: 48 | node_config_list = [RnnInputNodeConfig(2, [0, 1], RNN_OP)] 49 | for i in range(n_nodes - 1): 50 | node_config_list.append(RnnNodeConfig(3 + i, list(np.linspace(2, 2 + i, 1 + i).astype('int')), RNN_OP)) 51 | return SearchSpace(node_config_list) 52 | -------------------------------------------------------------------------------- /gnas/search_space/cross_over.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gnas.search_space.individual import Individual, MultipleBlockIndividual 3 | 4 | def _individual_uniform_crossover(individual_a: Individual, individual_b: Individual): 5 | n = individual_a.get_length() 6 | selection = np.random.randint(0, 2, n) 7 | i = 0 8 | iv_a = [] 9 | iv_b = [] 10 | for a, b in zip(individual_a.iv, individual_b.iv): 11 | current_selection = selection[i:i + len(a)] 12 | iv_a.append(a * current_selection + b * (1 - current_selection)) 13 | iv_b.append(a * (1 - current_selection) + b * current_selection) 14 | i += len(a) 15 | return Individual(iv_a, individual_a.mi, individual_a.ss, index=individual_a.index), \ 16 | Individual(iv_b, individual_b.mi, individual_b.ss, index=individual_b.index) 17 | 18 | 19 | def _individual_block_crossover(individual_a: Individual, individual_b: Individual): 20 | n = individual_a.get_n_op() 21 | selection = np.random.randint(0, 2, n) 22 | iv_a = [] 23 | iv_b = [] 24 | for i, (a, b) in enumerate(zip(individual_a.iv, individual_b.iv)): 25 | # current_selection = selection[i] 26 | iv_a.append(a * selection[i] + b * (1 - selection[i])) 27 | iv_b.append(a * (1 - selection[i]) + b * selection[i]) 28 | 29 | return Individual(iv_a, individual_a.mi, individual_a.ss, index=individual_a.index), \ 30 | 
Individual(iv_b, individual_b.mi, individual_b.ss, index=individual_b.index) 31 | 32 | 33 | def individual_uniform_crossover(individual_a, individual_b, p_c): 34 | if np.random.rand(1) < p_c: 35 | if isinstance(individual_a, Individual): 36 | return _individual_uniform_crossover(individual_a, individual_b) 37 | else: 38 | pairs = [_individual_uniform_crossover(inv_a, inv_b) for inv_a, inv_b in 39 | zip(individual_a.individual_list, individual_b.individual_list)] 40 | return MultipleBlockIndividual([p[0] for p in pairs]), MultipleBlockIndividual([p[1] for p in pairs]) 41 | else: 42 | return individual_a, individual_b 43 | 44 | 45 | def individual_block_crossover(individual_a, individual_b, p_c): 46 | if np.random.rand(1) < p_c: 47 | if isinstance(individual_a, Individual): 48 | return _individual_block_crossover(individual_a, individual_b) 49 | else: 50 | pairs = [_individual_block_crossover(inv_a, inv_b) for inv_a, inv_b in 51 | zip(individual_a.individual_list, individual_b.individual_list)] 52 | return MultipleBlockIndividual([p[0] for p in pairs]), MultipleBlockIndividual([p[1] for p in pairs]) 53 | else: 54 | return individual_a, individual_b 55 | -------------------------------------------------------------------------------- /gnas/genetic_algorithm/population_dict.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from operator import itemgetter 3 | import sys 4 | import copy 5 | 6 | 7 | class PopulationDict(object): 8 | def __init__(self, values_dict: OrderedDict = None, index_dict: OrderedDict = None, 9 | current_index=0): 10 | if values_dict is None: 11 | self.values_dict = OrderedDict({}) 12 | else: 13 | self.values_dict = values_dict 14 | if index_dict is None: 15 | self.index_dict = OrderedDict({}) 16 | else: 17 | self.index_dict = index_dict 18 | self.i = current_index 19 | 20 | def copy(self): 21 | return PopulationDict(copy.deepcopy(self.values_dict), copy.deepcopy(self.index_dict), self.i) 22 | 23 | def __len__(self): 24 | return len(self.values_dict) 25 | 26 | def __str__(self): 27 | return str(self.values_dict) 28 | 29 | def items(self): 30 | return self.values_dict.items() 31 | 32 | def values(self): 33 | return self.values_dict.values() 34 | 35 | def keys(self): 36 | return self.values_dict.keys() 37 | 38 | def update(self, input_dict: dict): 39 | self.values_dict.update(input_dict) 40 | for k, v in input_dict.items(): 41 | self.index_dict.update({k: self.i}) 42 | self.i += 1 43 | 44 | def filter_top_n(self, n=sys.maxsize, min_max=True): 45 | values_dict = OrderedDict({}) 46 | index_dict = OrderedDict({}) 47 | for i, (key, value) in enumerate(sorted(self.values_dict.items(), key=itemgetter(1), reverse=min_max)): 48 | if i < n: 49 | values_dict.update({key: value}) 50 | index_dict.update({key: self.index_dict.get(key)}) 51 | return PopulationDict(values_dict, index_dict, self.i) 52 | 53 | def filter_last_n(self, n=sys.maxsize): 54 | values_dict = OrderedDict({}) 55 | index_dict = OrderedDict({}) 56 | for i, (key, index) in enumerate(sorted(self.index_dict.items(), key=itemgetter(1), reverse=True)): 57 | if i < n: 58 | values_dict.update({key: self.values_dict.get(key)}) 59 | index_dict.update({key: index}) 60 | return PopulationDict(values_dict, index_dict, self.i) 61 | 62 | def merge(self, other): 63 | values_dict = self.values_dict.copy() 64 | index_dict = self.index_dict.copy() 65 | values_dict.update(other.values_dict) 66 | index_dict.update(other.index_dict) 67 | return 
PopulationDict(values_dict, index_dict, self.i) 68 | 69 | def get_n_diff(self, other): 70 | n = sum([1 for k in other.keys() if k not in list(self.keys())]) 71 | return n 72 | -------------------------------------------------------------------------------- /gnas/modules/cnn_block.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | from gnas.search_space.individual import Individual 6 | from gnas.modules.sub_graph_module import SubGraphModule 7 | from torch.nn import functional as F 8 | from modules.se_block import SEBlock 9 | from modules.identity import Identity 10 | 11 | 12 | class CnnSearchModule(nn.Module): 13 | def __init__(self, n_channels, ss, individual_index=0, se_block=True): 14 | super(CnnSearchModule, self).__init__() 15 | 16 | self.ss = ss 17 | self.n_channels = n_channels 18 | self.config_dict = {'n_channels': n_channels} 19 | self.sub_graph_module = SubGraphModule(ss, self.config_dict, 20 | individual_index=individual_index) 21 | if ss.single_block: 22 | self.n_inputs = len(ss.ocl[0].inputs) 23 | else: 24 | self.n_inputs = len(ss.ocl[individual_index][0].inputs) 25 | if se_block: 26 | self.se_block = SEBlock(n_channels, 8) 27 | else: 28 | self.se_block = Identity() 29 | 30 | self.bn = nn.BatchNorm2d(n_channels) 31 | self.relu = nn.ReLU() 32 | if self.ss.single_block: 33 | self.weights = [Parameter(torch.Tensor(n_channels, n_channels, 1, 1)) for _ in range(len(ss.ocl))] 34 | else: 35 | self.weights = [Parameter(torch.Tensor(n_channels, n_channels, 1, 1)) for _ in 36 | range(len(ss.ocl[individual_index]))] 37 | [self.register_parameter('w_' + str(i), w) for i, w in enumerate(self.weights)] 38 | self.register_parameter('bias', None) 39 | self.reset_parameters() 40 | 41 | def reset_parameters(self): 42 | n = self.n_channels * len(self.weights) 43 | stdv = 1. 
/ math.sqrt(n) 44 | for w in self.weights: 45 | w.data.uniform_(-stdv, stdv) 46 | 47 | def forward(self, inputs_tensor, bypass_input): 48 | if self.n_inputs == 1: 49 | net = self.sub_graph_module(inputs_tensor) 50 | elif self.n_inputs == 2: 51 | net = self.sub_graph_module(inputs_tensor, bypass_input) 52 | 53 | net = torch.cat([net[i] for i in self.sub_graph_module.avg_index if i > 1], dim=1) 54 | w = torch.cat([self.weights[i - 2] for i in self.sub_graph_module.avg_index if i > 1], dim=1) 55 | net = self.bn(F.conv2d(self.relu(net), w, self.bias, 1, 0, 1, 1)) 56 | return self.se_block(net) + inputs_tensor 57 | 58 | def set_individual(self, individual: Individual): 59 | self.sub_graph_module.set_individual(individual) 60 | 61 | def parameters(self): 62 | for name, param in self.named_parameters(): 63 | yield param 64 | -------------------------------------------------------------------------------- /gnas/search_space/operation_space.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gnas.common.bit_utils import vector_bits2int 3 | 4 | 5 | class RnnInputNodeConfig(object): 6 | def __init__(self, node_id, inputs: list, non_linear_list): 7 | self.node_id = node_id 8 | self.inputs = inputs 9 | self.non_linear_list = non_linear_list 10 | 11 | def get_n_bits(self, max_inputs): 12 | return np.log2(len(self.non_linear_list)).astype('int') 13 | 14 | def max_values_vector(self, max_inputs): 15 | return np.ones(self.get_n_bits(0)) 16 | 17 | def get_n_inputs(self): 18 | return len(self.inputs) 19 | 20 | def parse_config(self, oc): 21 | return vector_bits2int(oc) 22 | 23 | 24 | class RnnNodeConfig(object): 25 | def __init__(self, node_id, inputs: list, non_linear_list): 26 | self.node_id = node_id 27 | self.inputs = inputs 28 | self.non_linear_list = non_linear_list 29 | 30 | def get_n_bits(self, max_inputs): 31 | return np.log2(len(self.non_linear_list)).astype('int') + (max_inputs > 1) 32 | 33 | def max_values_vector(self, max_inputs): 34 | op_bits = np.ones(self.get_n_bits(0)) 35 | if max_inputs > 1: 36 | return np.concatenate([np.asarray(max_inputs - 1).reshape(1), op_bits]) 37 | return op_bits 38 | 39 | def get_n_inputs(self): 40 | return len(self.inputs) 41 | 42 | def parse_config(self, oc): 43 | if len(self.inputs) == 1: 44 | return self.inputs[0], 0, vector_bits2int(oc) 45 | else: 46 | return self.inputs[oc[0]], oc[0], vector_bits2int(oc[1:]) 47 | 48 | 49 | class CnnNodeConfig(object): 50 | def __init__(self, node_id, inputs: list, op_list, drop_path_control): 51 | self.node_id = node_id 52 | self.inputs = inputs 53 | self.op_list = op_list 54 | self.drop_path_control = drop_path_control 55 | 56 | def max_values_vector(self, max_inputs): 57 | max_inputs = len(self.inputs) 58 | if max_inputs > 1: 59 | return np.asarray([max_inputs - 1, max_inputs - 1, len(self.op_list) - 1, len(self.op_list) - 1]) 60 | return np.asarray([len(self.op_list) - 1, len(self.op_list) - 1]) 61 | 62 | def get_n_inputs(self): 63 | return len(self.inputs) 64 | 65 | def parse_config(self, oc): 66 | if len(self.inputs) == 1: 67 | op_a = oc[0] 68 | op_b = oc[1] 69 | input_index_a = 0 70 | input_index_b = 0 71 | else: 72 | input_index_a = oc[0] 73 | input_index_b = oc[1] 74 | op_a = oc[2] 75 | op_b = oc[3] 76 | input_a = self.inputs[input_index_a] 77 | input_b = self.inputs[input_index_b] 78 | return input_a, input_b, input_index_a, input_index_b, op_a, op_b 79 | -------------------------------------------------------------------------------- 
/gnas/modules/rnn_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from gnas.search_space.individual import Individual 5 | from gnas.modules.sub_graph_module import SubGraphModule 6 | 7 | 8 | class RnnSearchModule(nn.Module): 9 | def __init__(self, in_channels, n_channels, working_device, ss): 10 | super(RnnSearchModule, self).__init__() 11 | 12 | self.ss = ss 13 | self.in_channels = in_channels 14 | self.n_channels = n_channels 15 | self.working_device = working_device 16 | self.config_dict = {'in_channels': self.in_channels, 17 | 'n_channels': self.n_channels} 18 | self.sub_graph_module = SubGraphModule(ss, self.config_dict) 19 | 20 | self.reset_parameters() 21 | 22 | def forward(self, inputs_tensor, state_tensor): 23 | # input size [Time step,Batch,features] 24 | 25 | state = state_tensor[0, :, :] 26 | outputs = [] 27 | 28 | for i in torch.split(inputs_tensor, split_size_or_sections=1, dim=0): # Loop over time steps 29 | output, state = self.cell(i, state) 30 | # state_norm = state.norm(dim=-1) 31 | # max_norm = 25.0 32 | # if torch.any(state_norm > max_norm).item(): 33 | # clip_select = state_norm > max_norm 34 | # clip_norms = state_norm[clip_select] 35 | # 36 | # mask = torch.ones(state.size(), device=self.working_device) 37 | # normalizer = max_norm / clip_norms 38 | # mask[clip_select, :] = normalizer.unsqueeze(dim=-1) 39 | # mask = mask.detach() 40 | # state *= mask 41 | # print(np.max(state.norm(dim=-1).detach().cpu().numpy())) 42 | # print("Max Norm pass") 43 | # state = state / state.norm(dim=-1) 44 | outputs.append(output) 45 | output = torch.stack(outputs, dim=0) 46 | 47 | return output, state.unsqueeze(dim=0) 48 | 49 | def cell(self, x, state): 50 | net = self.sub_graph_module(x.squeeze(dim=0), state) 51 | output, state = torch.mean(torch.stack([net[i] for i in self.sub_graph_module.avg_index], dim=-1), dim=-1), net[ 52 | -1] 53 | return output, output 54 | 55 | def set_individual(self, individual: Individual): 56 | self.sub_graph_module.set_individual(individual) 57 | 58 | def init_state(self, batch_size=1): # model init state 59 | weight = next(self.parameters()) 60 | return weight.new_zeros(1, batch_size, self.n_channels) 61 | 62 | def parameters(self): 63 | for name, param in self.named_parameters(): 64 | yield param 65 | 66 | def reset_parameters(self): 67 | init_range = 0.025 68 | for param in self.parameters(): 69 | param.data.uniform_(-init_range, init_range) 70 | -------------------------------------------------------------------------------- /tests/test_modules.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import gnas 4 | import time 5 | from tests.common4testing import generate_ss, generate_ss_cnn 6 | from gnas.modules.sub_graph_module import SubGraphModule 7 | from modules.drop_module import DropModuleControl 8 | 9 | class TestModules(unittest.TestCase): 10 | def test_sub_graph_build_rnn(self): 11 | ss = generate_ss() 12 | sgm = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128}) 13 | 14 | sgm.set_individual(ss.generate_individual()) 15 | 16 | def test_run_sub_module(self): 17 | ss = generate_ss() 18 | sgm = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128}) 19 | y = torch.randn(25, 128, dtype=torch.float) 20 | for i in range(100): 21 | sgm.set_individual(ss.generate_individual()) 22 | x = torch.randn(25, 32, dtype=torch.float) 23 | y = sgm(x, y) 24 | y = y[-1] 25 | 26 | def 
test_cnn_sub_module(self): 27 | ss = generate_ss_cnn() 28 | sgm = SubGraphModule(ss, {'n_channels': 64}) 29 | 30 | for i in range(100): 31 | sgm.set_individual(ss.generate_individual()) 32 | y = torch.randn(32, 64, 16, 16, dtype=torch.float) 33 | x = torch.randn(32, 64, 16, 16, dtype=torch.float) 34 | res = sgm(x, y) 35 | 36 | def test_cnn_module(self): 37 | batch_size = 64 38 | h, w = 16, 16 39 | channels = 64 40 | input = torch.randn(batch_size, channels, h, w, dtype=torch.float) 41 | input_b = torch.randn(batch_size, channels, h, w, dtype=torch.float) 42 | dp_control = DropModuleControl(1) 43 | ss = gnas.get_gnas_cnn_search_space(4, dp_control, gnas.SearchSpaceType.CNNSingleCell) 44 | cnn = gnas.modules.CnnSearchModule(n_channels=channels, 45 | ss=ss) 46 | cnn.set_individual(ss.generate_individual()) 47 | 48 | s = time.time() 49 | output = cnn(input, input_b) 50 | print(time.time() - s) 51 | self.assertTrue(output.shape[0] == batch_size) 52 | self.assertTrue(output.shape[1] == channels) 53 | self.assertTrue(output.shape[2] == h) 54 | self.assertTrue(output.shape[3] == w) 55 | 56 | def test_rnn_module(self): 57 | batch_size = 64 58 | in_channels = 300 59 | out_channels = 128 60 | time_steps = 35 61 | input = torch.randn(time_steps, batch_size, in_channels, dtype=torch.float) 62 | 63 | ss = gnas.get_gnas_rnn_search_space(12) 64 | rnn = gnas.modules.RnnSearchModule(in_channels=in_channels, n_channels=out_channels, working_device='cpu', 65 | ss=ss) 66 | rnn.set_individual(ss.generate_individual()) 67 | 68 | state = rnn.init_state(batch_size) 69 | s = time.time() 70 | output, state = rnn(input, state) 71 | print(time.time() - s) 72 | self.assertTrue(output.shape[1] == batch_size) 73 | self.assertTrue(output.shape[0] == time_steps) 74 | self.assertTrue(output.shape[2] == out_channels) 75 | 76 | 77 | if __name__ == '__main__': 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /gnas/modules/module_generator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from modules.identity import Identity 3 | 4 | __nl_dict__ = {'Tanh': nn.Tanh, 5 | 'ReLU': nn.ReLU, 6 | 'ReLU6': nn.ReLU6, 7 | 'SELU': nn.SELU, 8 | 'LeakyReLU': nn.LeakyReLU, 9 | 'Sigmoid': nn.Sigmoid} 10 | 11 | 12 | def generate_dw_conv(in_channels, out_channels, kernel): 13 | padding = int((kernel - 1) / 2) 14 | conv1 = nn.Sequential(nn.ReLU(), 15 | nn.Conv2d(in_channels, in_channels, kernel, padding=padding, groups=in_channels, bias=False), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(), 18 | nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False), 19 | nn.BatchNorm2d(out_channels)) 20 | conv2 = nn.Sequential(nn.ReLU(), 21 | nn.Conv2d(in_channels, in_channels, kernel, padding=padding, groups=in_channels, bias=False), 22 | nn.BatchNorm2d(out_channels), 23 | nn.ReLU(), 24 | nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False), 25 | nn.BatchNorm2d(out_channels)) 26 | return nn.Sequential(conv1, conv2) 27 | 28 | 29 | def max_pool3x3(in_channels, out_channels): 30 | return nn.Sequential(nn.ReLU(), 31 | nn.MaxPool2d(3, stride=1, padding=1)) 32 | 33 | 34 | def avg_pool3x3(in_channels, out_channels): 35 | return nn.Sequential(nn.ReLU(), 36 | nn.AvgPool2d(3, stride=1, padding=1)) 37 | 38 | 39 | def conv3x3(in_channels, out_channels): 40 | return nn.Sequential(nn.ReLU(), 41 | nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False), 42 | nn.BatchNorm2d(out_channels)) 43 | 44 | 45 | def
dw_conv3x3(in_channels, out_channels): 46 | return generate_dw_conv(in_channels, out_channels, 3) 47 | 48 | 49 | def dw_conv1x3(in_channels, out_channels): 50 | return nn.Sequential(nn.ReLU(), 51 | nn.Conv2d(in_channels, out_channels, (1, 3), padding=(0, 1), groups=out_channels, bias=False), 52 | nn.BatchNorm2d(out_channels)) 53 | 54 | 55 | def dw_conv3x1(in_channels, out_channels): 56 | return nn.Sequential(nn.ReLU(), 57 | nn.Conv2d(in_channels, out_channels, (3, 1), padding=(1, 0), groups=out_channels, bias=False), 58 | nn.BatchNorm2d(out_channels)) 59 | 60 | 61 | def conv5x5(in_channels, out_channels): 62 | return nn.Sequential(nn.ReLU(), 63 | nn.Conv2d(in_channels, out_channels, 5, padding=2, bias=False), 64 | nn.BatchNorm2d(out_channels)) 65 | 66 | 67 | def dw_conv5x5(in_channels, out_channels): 68 | return generate_dw_conv(in_channels, out_channels, 5) 69 | 70 | 71 | def identity(in_channels, out_channels): 72 | return Identity() 73 | 74 | 75 | __op_dict__ = {'Conv3x3': conv3x3, 76 | 'Dw3x3': dw_conv3x3, 77 | 'Dw3x1': dw_conv3x1, 78 | 'Dw1x3': dw_conv1x3, 79 | 'Conv5x5': conv5x5, 80 | 'Dw5x5': dw_conv5x5, 81 | 'Identity': identity, 82 | 'Max3x3': max_pool3x3, 83 | 'Avg3x3': avg_pool3x3, } 84 | 85 | 86 | def generate_non_linear(non_linear_list): 87 | return [__nl_dict__.get(nl)() for nl in non_linear_list] 88 | 89 | 90 | def generate_op(op_list, in_channels, out_channels): 91 | return [__op_dict__.get(nl)(in_channels, out_channels) for nl in op_list] 92 | -------------------------------------------------------------------------------- /gnas/common/graph_draw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from gnas.search_space.individual import Individual, MultipleBlockIndividual 5 | from gnas.search_space.operation_space import RnnNodeConfig, RnnInputNodeConfig, CnnNodeConfig 6 | 7 | 8 | def add_node(graph, node_id, label, shape='box', style='filled'): 9 | if label.startswith('Input') or label.startswith('x'): 10 | color = 'skyblue' 11 | elif label.startswith('Output'): 12 | color = 'pink' 13 | elif 'Tanh' in label or 'Add' in label or 'Concat' in label: 14 | color = 'yellow' 15 | elif 'ReLU' in label or 'Dw' in label or 'Conv' in label: 16 | color = 'orange' 17 | elif 'Sigmoid' in label or 'Identity' in label: 18 | color = 'greenyellow' 19 | elif label == 'avg': 20 | color = 'seagreen3' 21 | else: 22 | color = 'white' 23 | 24 | if not any(label.startswith(word) for word in ['x', 'avg', 'h']): 25 | label = f"{label}\n({node_id})" 26 | 27 | graph.add_node( 28 | node_id, label=label, color='black', fillcolor=color, 29 | shape=shape, style=style, 30 | ) 31 | 32 | 33 | def _draw_individual(ocl, individual, path=None): 34 | import pygraphviz as pgv 35 | graph = pgv.AGraph(directed=True, layout='dot') # not work? 
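# Note on the '# not work?' remark above: in pygraphviz, extra keyword
# arguments passed to AGraph() are stored as plain Graphviz graph attributes,
# so layout='dot' does not by itself position any nodes. The layout engine is
# only invoked by the graph.layout(prog='dot') call at the end of this
# function, which is why the constructor argument appears to have no effect.
# A minimal sketch of the working pattern (assumes pygraphviz and Graphviz
# are installed; the file name is illustrative):
#   g = pgv.AGraph(directed=True)
#   g.add_edge('a', 'b')
#   g.layout(prog='dot')   # computes node positions
#   g.draw('example.png')  # renders using the computed layout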
36 | 37 | ofset = len(ocl[0].inputs) 38 | for i in range(len(ocl[0].inputs)): 39 | add_node(graph, i, 'x[' + str(i) + ']') 40 | 41 | input_list = [] 42 | 43 | for i, (oc, op) in enumerate(zip(individual.generate_node_config(), ocl)): 44 | if isinstance(op, CnnNodeConfig): 45 | input_a = oc[0] 46 | input_b = oc[1] 47 | input_list.append(input_a) 48 | input_list.append(input_b) 49 | op_a = oc[4] 50 | op_b = oc[5] 51 | add_node(graph, (i + ofset) * 10, ocl[i].op_list[op_a]) 52 | add_node(graph, (i + ofset) * 10 + 1, ocl[i].op_list[op_b]) 53 | graph.add_edge(input_a, (i + ofset) * 10) 54 | graph.add_edge(input_b, (i + ofset) * 10 + 1) 55 | add_node(graph, (i + ofset), 'Add') 56 | graph.add_edge((i + ofset) * 10, (i + ofset)) 57 | graph.add_edge((i + ofset) * 10 + 1, (i + ofset)) 58 | c = graph.add_subgraph([(i + ofset) * 10, (i + ofset) * 10 + 1, (i + ofset)], 59 | name='cluster_block:' + str(i), label='Block ' + str(i)) 60 | # c.attr(label='block:'+str(i)) 61 | 62 | elif isinstance(op, RnnInputNodeConfig): 63 | op_type = op.non_linear_list[oc] 64 | add_node(graph, (i + ofset), op_type) 65 | graph.add_edge(0, (i + ofset)) 66 | graph.add_edge(1, (i + ofset)) 67 | input_list.append(0) 68 | input_list.append(1) 69 | 70 | elif isinstance(op, RnnNodeConfig): 71 | op_type = op.non_linear_list[oc[-1]] 72 | add_node(graph, (i + ofset), op_type) 73 | graph.add_edge(oc[0], (i + ofset)) 74 | input_list.append(oc[0]) 75 | else: 76 | raise Exception('unkown node type') 77 | input_list = np.unique(input_list) 78 | op_inputs = [int(i) for i in np.linspace(ofset, ofset + individual.get_n_op() - 1, individual.get_n_op()) if 79 | i not in input_list] 80 | concat_node = i + 1 + ofset 81 | add_node(graph, concat_node, 'Concat') 82 | for i in op_inputs: 83 | graph.add_edge(i, concat_node) 84 | graph.layout(prog='dot') 85 | if path is not None: 86 | graph.draw(path + '.png') 87 | 88 | 89 | 90 | def draw_cell(ocl, individual): 91 | _draw_individual(ocl, individual, path=None) 92 | 93 | 94 | def draw_network(ss, individual, path): 95 | os.makedirs(os.path.dirname(path), exist_ok=True) 96 | if isinstance(individual, Individual): 97 | _draw_individual(ss.ocl, individual, path) 98 | elif isinstance(individual, MultipleBlockIndividual): 99 | [_draw_individual(ocl, inv, path + str(i)) for i, (inv, ocl) in 100 | enumerate(zip(individual.individual_list, ss.ocl))] 101 | -------------------------------------------------------------------------------- /models/model_cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import gnas 3 | import torch 4 | 5 | 6 | class RepeatBlock(nn.Module): 7 | def __init__(self, n_blocks, n_channels, ss, individual_index=0, first_block=None): 8 | super(RepeatBlock, self).__init__() 9 | if first_block is None: first_block = individual_index 10 | self.block_list = [gnas.modules.CnnSearchModule(n_channels, ss, 11 | individual_index=first_block if i == 0 else individual_index) 12 | for i in 13 | range(n_blocks)] 14 | [self.add_module('block_' + str(i), n) for i, n in enumerate(self.block_list)] 15 | 16 | def forward(self, x, x_prev): 17 | for b in self.block_list: 18 | x_new = b(x, x_prev) 19 | x_prev = x 20 | x = x_new 21 | return x, x_prev 22 | 23 | def set_individual(self, individual): 24 | [b.set_individual(individual) for b in self.block_list] 25 | 26 | 27 | class Net(nn.Module): 28 | def __init__(self, n_blocks, n_channels, n_classes, dropout, ss, aux=False): 29 | n_block_types = len(ss.ocl) 30 | normal_block_index = 0 31 | 
reduce_block_index = 0 32 | first_block_index = 0 33 | if n_block_types >= 2: 34 | normal_block_index = 1 35 | first_block_index = 1 36 | if n_block_types == 3: first_block_index = 2 37 | super(Net, self).__init__() 38 | self.conv1 = nn.Conv2d(3, n_channels, 3, stride=1, padding=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(n_channels) 40 | 41 | self.block_1 = RepeatBlock(n_blocks, n_channels, ss, 42 | individual_index=normal_block_index, first_block=first_block_index) 43 | 44 | self.avg = nn.AvgPool2d(2) 45 | self.conv2 = nn.Conv2d(n_channels, 2 * n_channels, 1, stride=1, padding=1, bias=False) 46 | self.bn2 = nn.BatchNorm2d(2 * n_channels) 47 | 48 | self.conv2_prev = nn.Conv2d(n_channels, 2 * n_channels, 1, stride=1, padding=1, bias=False) 49 | self.bn2_prev = nn.BatchNorm2d(2 * n_channels) 50 | # self 51 | 52 | self.block_2_reduce = gnas.modules.CnnSearchModule(2 * n_channels, ss, individual_index=reduce_block_index) 53 | self.block_2 = RepeatBlock(n_blocks, 2 * n_channels, ss, 54 | individual_index=normal_block_index) 55 | 56 | self.conv3 = nn.Conv2d(2 * n_channels, 4 * n_channels, 1, stride=1, padding=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(4 * n_channels) 58 | 59 | self.conv3_prev = nn.Conv2d(2 * n_channels, 4 * n_channels, 1, stride=1, padding=1, bias=False) 60 | self.bn3_prev = nn.BatchNorm2d(4 * n_channels) 61 | 62 | self.block_3_reduce = gnas.modules.CnnSearchModule(4 * n_channels, ss, individual_index=reduce_block_index) 63 | self.block_3 = RepeatBlock(n_blocks, 4 * n_channels, ss, 64 | individual_index=normal_block_index) 65 | 66 | self.relu = nn.ReLU() 67 | self.dp = nn.Dropout(p=dropout) 68 | self.fc1 = nn.Sequential(nn.ReLU(), 69 | nn.Linear(4 * n_channels, n_classes)) 70 | self.aux = aux 71 | if aux: 72 | self.fc2 = nn.Sequential(nn.ReLU(), 73 | nn.Linear(2 * n_channels, n_classes)) 74 | self.reset_param() 75 | 76 | def reset_param(self): 77 | for p in self.parameters(): 78 | if len(p.shape) == 4: 79 | nn.init.kaiming_normal_(p) 80 | 81 | def forward(self, x): 82 | x_prev = self.bn1(self.conv1(x)) 83 | 84 | x, x_prev = self.block_1(x_prev, x_prev) 85 | 86 | # reduce dim 87 | x = self.bn2(self.conv2(self.avg(x))) 88 | x_prev = self.bn2_prev(self.conv2_prev(self.avg(x_prev))) 89 | 90 | x, x_prev = self.block_2(self.block_2_reduce(x, x_prev), x) 91 | 92 | 93 | if self.aux: x2 = torch.mean(torch.mean(x, dim=-1), dim=-1) 94 | 95 | 96 | x = self.bn3(self.conv3(self.avg(x))) 97 | x_prev = self.bn3_prev(self.conv3_prev(self.avg(x_prev))) 98 | 99 | x, x_prev = self.block_3(self.block_3_reduce(x, x_prev), x) 100 | 101 | 102 | # Global pooling 103 | x = torch.mean(torch.mean(x, dim=-1), dim=-1) 104 | if self.aux: 105 | return [self.fc1(self.dp(x)), self.fc2(self.dp(x2))] 106 | else: 107 | return [self.fc1(self.dp(x))] 108 | 109 | def set_individual(self, individual): 110 | self.block_1.set_individual(individual) 111 | self.block_2.set_individual(individual) 112 | self.block_2_reduce.set_individual(individual) 113 | self.block_3_reduce.set_individual(individual) 114 | self.block_3.set_individual(individual) 115 | -------------------------------------------------------------------------------- /gnas/modules/node_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from gnas.modules.module_generator import generate_non_linear, generate_op 3 | from modules.weight_drop import WeightDrop 4 | from modules.drop_module import DropModule 5 | 6 | 7 | class RnnInputNodeModule(nn.Module): 8 | def __init__(self, node_config, 
config_dict): 9 | super(RnnInputNodeModule, self).__init__() 10 | if node_config.get_n_inputs() != 2: raise Exception('aaa') 11 | self.nc = node_config 12 | dropout = 0.5 # TODO:change to input config 13 | self.in_channels = config_dict.get('in_channels') 14 | self.n_channels = config_dict.get('n_channels') 15 | self.nl_module = generate_non_linear(self.nc.non_linear_list) 16 | 17 | self.x_linear_list = [nn.Linear(self.in_channels, self.n_channels) for _ in range(2)] 18 | [self.add_module('c_linear' + str(i), m) for i, m in enumerate(self.x_linear_list)] 19 | self.h_linear_list = [WeightDrop(nn.Linear(self.n_channels, self.n_channels), ['weight'], dropout, True) for _ 20 | in 21 | range(2)] 22 | [self.add_module('h_linear' + str(i), m) for i, m in enumerate(self.h_linear_list)] 23 | self.sigmoid = nn.Sigmoid() 24 | 25 | self.non_linear = None 26 | self.node_config = None 27 | 28 | def forward(self, inputs): 29 | c = self.sigmoid(self.x_linear_list[0](inputs[0]) + self.h_linear_list[0](inputs[1])) 30 | output = c * self.non_linear(self.x_linear_list[1](inputs[0]) + self.h_linear_list[1](inputs[1])) + (1 - c) * \ 31 | inputs[1] 32 | return output 33 | 34 | def set_current_node_config(self, current_config): 35 | nl_index = current_config 36 | self.cc = current_config 37 | self.non_linear = self.nl_module[nl_index] 38 | 39 | 40 | class RnnNodeModule(nn.Module): 41 | def __init__(self, node_config, config_dict): 42 | super(RnnNodeModule, self).__init__() 43 | self.nc = node_config 44 | if node_config.get_n_inputs() < 1: raise Exception('aaa') 45 | 46 | self.n_channels = config_dict.get('n_channels') 47 | self.nl_module = generate_non_linear(self.nc.non_linear_list) 48 | # self.bn = nn.BatchNorm1d(self.n_channels) 49 | 50 | self.x_linear_list = [nn.Linear(self.n_channels, self.n_channels) for _ in range(node_config.get_n_inputs())] 51 | [self.add_module('c_linear' + str(i), m) for i, m in enumerate(self.x_linear_list)] 52 | self.h_linear_list = [nn.Linear(self.n_channels, self.n_channels) for _ in range(node_config.get_n_inputs())] 53 | [self.add_module('h_linear' + str(i), m) for i, m in enumerate(self.h_linear_list)] 54 | self.sigmoid = nn.Sigmoid() 55 | # self.bn = nn.BatchNorm1d(self.n_channels) 56 | self.non_linear = None 57 | self.node_config = None 58 | 59 | def forward(self, inputs): 60 | x = inputs[self.select_index] 61 | c = self.sigmoid(self.x_linear(x)) 62 | return c * self.non_linear(self.h_linear(x)) + (1 - c) * x 63 | 64 | def set_current_node_config(self, current_config): 65 | self.select_index, op_index, nl_index = current_config 66 | self.cc = current_config 67 | self.non_linear = self.nl_module[nl_index] 68 | self.x_linear = self.x_linear_list[op_index] 69 | self.h_linear = self.h_linear_list[op_index] 70 | 71 | 72 | class ConvNodeModule(nn.Module): 73 | def __init__(self, node_config, config_dict): 74 | super(ConvNodeModule, self).__init__() 75 | self.nc = node_config 76 | 77 | self.n_channels = config_dict.get('n_channels') 78 | self.conv_module = [] 79 | for j in range(node_config.get_n_inputs()): 80 | op_list = [DropModule(op, node_config.drop_path_control) for op in 81 | generate_op(self.nc.op_list, self.n_channels, self.n_channels)] 82 | self.conv_module.append(op_list) 83 | [self.add_module('conv_op_' + str(i) + '_in_' + str(j), m) for i, m in enumerate(self.conv_module[-1])] 84 | 85 | self.non_linear_a = None 86 | self.non_linear_b = None 87 | self.input_a = None 88 | self.input_b = None 89 | self.cc = None 90 | self.op_a = None 91 | self.op_b = None 92 | 93 | def 
forward(self, inputs): 94 | net_a = inputs[self.input_a] 95 | net_b = inputs[self.input_b] 96 | return self.op_a(net_a) + self.op_b(net_b) 97 | 98 | def set_current_node_config(self, current_config): 99 | input_a, input_b, input_index_a, input_index_b, op_a, op_b = current_config 100 | self.select_index = [input_a, input_b] 101 | self.cc = current_config 102 | self.input_a = input_a 103 | self.input_b = input_b 104 | self.op_a = self.conv_module[input_index_a][op_a] 105 | self.op_b = self.conv_module[input_index_b][op_b] 106 | #### set grad false 107 | for p in self.parameters(): 108 | p.requires_grad = False 109 | #### set grad true 110 | for p in self.op_b.parameters(): 111 | p.requires_grad = True 112 | for p in self.op_a.parameters(): 113 | p.requires_grad = True 114 | -------------------------------------------------------------------------------- /tests/test_search_space.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import os 4 | import inspect 5 | 6 | from gnas.search_space.operation_space import RnnInputNodeConfig, RnnNodeConfig 7 | from gnas.search_space.search_space import SearchSpace 8 | from gnas.search_space.individual import Individual 9 | from gnas.search_space.cross_over import individual_uniform_crossover 10 | from gnas.common.graph_draw import draw_network 11 | from gnas.search_space.mutation import individual_flip_mutation 12 | import gnas 13 | 14 | 15 | class TestSearchSpace(unittest.TestCase): 16 | @staticmethod 17 | def generate_block(): 18 | nll = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid'] 19 | node_config_list = [RnnInputNodeConfig(0, [], nll)] 20 | for i in range(12): 21 | node_config_list.append(RnnNodeConfig(i + 1, [0], nll)) 22 | return node_config_list 23 | 24 | @staticmethod 25 | def generate_ss(): 26 | 27 | ss = SearchSpace(TestSearchSpace.generate_block()) 28 | return ss 29 | 30 | @staticmethod 31 | def generate_ss_multiple_blocks(): 32 | block_list = [TestSearchSpace.generate_block(), TestSearchSpace.generate_block()] 33 | ss = SearchSpace(block_list, single_block=False) 34 | return ss 35 | 36 | def test_basic(self): 37 | ss = self.generate_ss() 38 | individual = ss.generate_individual() 39 | self._test_individual(individual, ss.get_n_nodes()) 40 | 41 | def test_individual(self): 42 | ss = self.generate_ss() 43 | individual_a = ss.generate_individual() 44 | individual_b = ss.generate_individual() 45 | individual_a_tag = individual_a.copy() 46 | dict2test = dict() 47 | self.assertFalse(individual_a == individual_b) 48 | self.assertTrue(individual_a_tag == individual_a) 49 | dict2test.update({individual_a: 40}) 50 | dict2test.update({individual_b: 80}) 51 | dict2test.update({individual_a_tag: 90}) 52 | self.assertTrue(dict2test.get(individual_a) == 90) 53 | self.assertTrue(dict2test.get(individual_a_tag) == 90) 54 | self.assertTrue(dict2test.get(individual_b) == 80) 55 | self.assertTrue(len(dict2test) == 2) 56 | # res_dict 57 | 58 | def test_basic_multiple(self): 59 | ss = self.generate_ss_multiple_blocks() 60 | individual = ss.generate_individual() 61 | self._test_individual(individual, ss.get_n_nodes()) 62 | 63 | def test_mutation(self): 64 | ss = self.generate_ss() 65 | for i in range(100): 66 | individual_a = ss.generate_individual() 67 | individual_c = individual_flip_mutation(individual_a, 1 / 10) 68 | ce = 0 69 | te = 0 70 | for a, c in zip(individual_a.iv, individual_c.iv): 71 | for ia, ic in zip(a, c): 72 | te += 1 73 | ce += ia != ic 74 | 75 | 
self._test_individual(individual_c, ss.get_n_nodes()) 76 | self.assertTrue(ce != te) 77 | 78 | def test_cross_over(self): 79 | ss = self.generate_ss() 80 | ca = 0 81 | cb = 0 82 | cc = 0 83 | for i in range(100): 84 | individual_a = ss.generate_individual() 85 | individual_b = ss.generate_individual() 86 | individual_c, individual_d = individual_uniform_crossover(individual_a, individual_b, 1) 87 | 88 | for a, b, c in zip(individual_a.iv, individual_b.iv, individual_c.iv): 89 | for ia, ib, ic in zip(a, b, c): 90 | cc += 1 91 | ca += ia == ic 92 | cb += ib == ic 93 | self.assertTrue(ia == ic or ib == ic) 94 | 95 | self._test_individual(individual_c, ss.get_n_nodes()) 96 | self.assertTrue(cc != cb) 97 | self.assertTrue(cc != ca) 98 | 99 | def test_plot_individual_rnn(self): 100 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 101 | 102 | ss = gnas.get_gnas_rnn_search_space(12) 103 | ind = ss.generate_individual() 104 | draw_network(ss, ind, os.path.join(current_path, 'graph')) 105 | 106 | def test_plot_individual(self): 107 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 108 | 109 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNSingleCell) 110 | ind = ss.generate_individual() 111 | draw_network(ss, ind, os.path.join(current_path, 'graph')) 112 | 113 | def test_plot_individual_dual(self): 114 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 115 | 116 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNDualCell) 117 | ind = ss.generate_individual() 118 | draw_network(ss, ind, os.path.join(current_path, 'graph')) 119 | 120 | def test_plot_individual_triple(self): 121 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 122 | 123 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNTripleCell) 124 | ind = ss.generate_individual() 125 | draw_network(ss, ind, os.path.join(current_path, 'graph')) 126 | 127 | def _test_individual(self, individual, n_nodes): 128 | individual_flip_mutation(individual, 0.2) 129 | if isinstance(individual, Individual): 130 | self.assertTrue(len(individual.iv) == n_nodes) 131 | for c in individual.iv: 132 | if len(c) == 2: 133 | self.assertFalse(np.any(c > 1)) 134 | else: 135 | self.assertFalse(np.any(c[1:] > 1)) 136 | individual.generate_node_config() 137 | else: 138 | [self.assertTrue(len(ind.iv) == n_nodes[i]) for i, ind in enumerate(individual.individual_list)] 139 | individual.generate_node_config(0) 140 | 141 | 142 | if __name__ == '__main__': 143 | unittest.main() 144 | -------------------------------------------------------------------------------- /models/model_rnn.py: -------------------------------------------------------------------------------- 1 | import gnas 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | 7 | 8 | class EmbeddingDropout(torch.nn.Embedding): 9 | """Class for dropping out embeddings by zero'ing out parameters in the 10 | embedding matrix. 11 | 12 | This is equivalent to dropping out particular words, e.g., in the sentence 13 | 'the quick brown fox jumps over the lazy dog', dropping out 'the' would 14 | lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the 15 | embedding vector space). 16 | 17 | See 'A Theoretically Grounded Application of Dropout in Recurrent Neural 18 | Networks', (Gal and Ghahramani, 2016). 
19 | """ 20 | 21 | def __init__(self, 22 | num_embeddings, 23 | embedding_dim, 24 | max_norm=None, 25 | norm_type=2, 26 | scale_grad_by_freq=False, 27 | sparse=False, 28 | dropout=0.1, 29 | scale=None): 30 | """Embedding constructor. 31 | 32 | Args: 33 | dropout: Dropout probability. 34 | scale: Used to scale parameters of embedding weight matrix that are 35 | not dropped out. Note that this is _in addition_ to the 36 | `1/(1 - dropout)` scaling. 37 | 38 | See `torch.nn.Embedding` for remaining arguments. 39 | """ 40 | torch.nn.Embedding.__init__(self, 41 | num_embeddings=num_embeddings, 42 | embedding_dim=embedding_dim, 43 | max_norm=max_norm, 44 | norm_type=norm_type, 45 | scale_grad_by_freq=scale_grad_by_freq, 46 | sparse=sparse) 47 | self.dropout = dropout 48 | assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' 49 | 'and < 1.0') 50 | self.scale = scale 51 | 52 | def forward(self, inputs): # pylint:disable=arguments-differ 53 | """Embeds `inputs` with the dropped out embedding weight matrix.""" 54 | if self.training: 55 | dropout = self.dropout 56 | else: 57 | dropout = 0 58 | 59 | if dropout: 60 | mask = self.weight.data.new(self.weight.size(0), 1) 61 | mask.bernoulli_(1 - dropout) 62 | mask = mask.expand_as(self.weight) 63 | mask = mask / (1 - dropout) 64 | masked_weight = self.weight * Variable(mask) 65 | else: 66 | masked_weight = self.weight 67 | if self.scale and self.scale != 1: 68 | masked_weight = masked_weight * self.scale 69 | 70 | return F.embedding(inputs, 71 | masked_weight, 72 | max_norm=self.max_norm, 73 | norm_type=self.norm_type, 74 | scale_grad_by_freq=self.scale_grad_by_freq, 75 | sparse=self.sparse) 76 | 77 | 78 | class LockedDropout(nn.Module): 79 | # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py 80 | def __init__(self, dropout): 81 | super().__init__() 82 | self.dropout = dropout 83 | 84 | def forward(self, x): 85 | # The input is of size [T,N,F] where T is the number of steps, N is the batch size 86 | # and F is the number of features. 87 | # The random variable generate for the locked dropout don't change over the T axis 88 | if not self.training or not self.dropout: 89 | return x 90 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout) 91 | mask = Variable(m, requires_grad=False) / (1 - self.dropout) 92 | mask = mask.expand_as(x) 93 | return mask * x 94 | 95 | 96 | class RNNModel(nn.Module): 97 | """Container module with an encoder, a recurrent module, and a decoder.""" 98 | 99 | def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False, ss=None): 100 | super(RNNModel, self).__init__() 101 | self.drop_input = LockedDropout(0.65) # TODO:change to input config 102 | self.drop_end = LockedDropout(0.4) # TODO:change to input config 103 | self.encoder = EmbeddingDropout(ntoken, ninp, dropout=0.1) # TODO:change to input config 104 | self.ss = ss 105 | self.rnn = gnas.modules.RnnSearchModule(in_channels=ninp, n_channels=nhid, 106 | working_device='cuda', 107 | ss=self.ss) 108 | self.decoder = nn.Linear(nhid, ntoken) 109 | 110 | # Optionally tie weights as in: 111 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 112 | # https://arxiv.org/abs/1608.05859 113 | # and 114 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 
2016) 115 | # https://arxiv.org/abs/1611.01462 116 | if tie_weights: 117 | if nhid != ninp: 118 | raise ValueError('When using the tied flag, nhid must be equal to emsize') 119 | self.decoder.weight = self.encoder.weight 120 | 121 | self.init_weights() 122 | 123 | self.nhid = nhid 124 | self.nlayers = nlayers 125 | 126 | def set_individual(self, individual): 127 | self.rnn.set_individual(individual) 128 | 129 | def init_weights(self): 130 | initrange = 0.1 131 | self.encoder.weight.data.uniform_(-initrange, initrange) 132 | self.decoder.bias.data.zero_() 133 | self.decoder.weight.data.uniform_(-initrange, initrange) 134 | 135 | def forward(self, input, hidden): 136 | emb = self.drop_input(self.encoder(input)) 137 | output, hidden = self.rnn(emb, hidden) 138 | output = self.drop_end(output) 139 | decoded = self.decoder(output.contiguous().view(output.size(0) * output.size(1), output.size(2))) 140 | return decoded.contiguous().view(output.size(0), output.size(1), decoded.size(1)), hidden 141 | 142 | def init_hidden(self, bsz): 143 | return self.rnn.init_state(bsz) 144 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from modules.cut_out import Cutout 6 | 7 | 8 | def get_dataset(config): 9 | dataset_name = config.get('dataset_name') 10 | data_path = config.get('data_path') 11 | if dataset_name == 'CIFAR10': 12 | return get_cifar(config, os.path.join(data_path, 'CIFAR10')) 13 | elif dataset_name == 'CIFAR100': 14 | return get_cifar(config, os.path.join(data_path, 'CIFAR100'), dataset_name='CIFAR100') 15 | elif dataset_name == 'PTB': 16 | corpus = Corpus(os.path.join(data_path, 'ptb')) 17 | batch_size_train = config.get('batch_size') 18 | batch_size_val = config.get('batch_size_val') 19 | device = config.get('working_device') 20 | # train_data, val_data, test_data = corpus.batchify(config.get('batch_size'), config.get('working_device')) 21 | return corpus.single_batchify(corpus.train, batch_size_train, device), corpus.single_batchify(corpus.valid, 22 | batch_size_val, 23 | device), len( 24 | corpus.dictionary) 25 | else: 26 | raise Exception('unkown dataset type') 27 | 28 | 29 | def get_cifar(config, data_path, dataset_name='CIFAR10'): 30 | train_transform = transforms.Compose([]) 31 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 32 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 33 | train_transform.transforms.append(transforms.RandomCrop(32, padding=4)) 34 | train_transform.transforms.append(transforms.RandomHorizontalFlip()) 35 | train_transform.transforms.append(transforms.ToTensor()) 36 | train_transform.transforms.append(normalize) 37 | if config.get('cutout'): 38 | train_transform.transforms.append(Cutout(n_holes=config.get('n_holes'), length=config.get('length'))) 39 | 40 | transform = transforms.Compose([ 41 | transforms.ToTensor(), 42 | normalize]) 43 | trainloader, testloader, n_class = None, None, None 44 | if dataset_name == 'CIFAR10': 45 | trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, 46 | download=True, transform=train_transform) 47 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.get('batch_size'), 48 | shuffle=True, num_workers=4) 49 | 50 | testset = torchvision.datasets.CIFAR10(root=data_path, train=False, 51 | download=True, transform=transform) 52 | testloader = 
torch.utils.data.DataLoader(testset, batch_size=config.get('batch_size_val'), 53 | shuffle=False, num_workers=4) 54 | n_class = 10 55 | elif dataset_name == 'CIFAR100': 56 | trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, 57 | download=True, transform=train_transform) 58 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.get('batch_size'), 59 | shuffle=True, num_workers=4) 60 | 61 | testset = torchvision.datasets.CIFAR100(root=data_path, train=False, 62 | download=True, transform=transform) 63 | testloader = torch.utils.data.DataLoader(testset, batch_size=config.get('batch_size_val'), 64 | shuffle=False, num_workers=4) 65 | n_class = 100 66 | else: 67 | raise Exception('unkown dataset' + dataset_name) 68 | 69 | return trainloader, testloader, n_class 70 | 71 | 72 | class BatchIterator(object): 73 | def __init__(self, data): 74 | pass 75 | 76 | 77 | class Dictionary(object): 78 | def __init__(self): 79 | self.word2idx = {} 80 | self.idx2word = [] 81 | 82 | def add_word(self, word): 83 | if word not in self.word2idx: 84 | self.idx2word.append(word) 85 | self.word2idx[word] = len(self.idx2word) - 1 86 | return self.word2idx[word] 87 | 88 | def __len__(self): 89 | return len(self.idx2word) 90 | 91 | 92 | class Corpus(object): 93 | def __init__(self, path): 94 | # Starting from sequential dataset, batchify arranges the dataset into columns. 95 | # For instance, with the alphabet as the sequence and batch size 4, we'd get 96 | # ┌ a g m s ┐ 97 | # │ b h n t │ 98 | # │ c i o u │ 99 | # │ d j p v │ 100 | # │ e k q w │ 101 | # └ f l r x ┘. 102 | # These columns are treated as independent by the model, which means that the 103 | # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient 104 | # batch processing. 105 | self.dictionary = Dictionary() 106 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 107 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 108 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 109 | 110 | @staticmethod 111 | def single_batchify(data, bsz, input_device): 112 | # Work out how cleanly we can divide the dataset into bsz parts. 113 | nbatch = data.size(0) // bsz 114 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 115 | data = data.narrow(0, 0, nbatch * bsz) 116 | # Evenly divide the dataset across the bsz batches. 
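# A worked example for single_batchify, matching the alphabet diagram in
# Corpus.__init__ above (the numbers are illustrative): with 26 tokens and
# bsz = 4, nbatch = 26 // 4 = 6, so narrow() trims the data to 24 tokens;
# view(bsz, -1) gives a [4, 6] tensor whose rows are a..f, g..l, m..r and
# s..x, and .t() then yields the [6, 4] layout shown in the diagram, with one
# independent sequence per column.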
117 | data = data.view(bsz, -1).t().contiguous() 118 | return data.to(input_device) 119 | 120 | def batchify(self, bsz, device): 121 | return self.single_batchify(self.train, bsz, device), self.single_batchify(self.valid, bsz, 122 | device), self.single_batchify( 123 | self.test, bsz, device) 124 | 125 | def tokenize(self, path): 126 | """Tokenizes a text file.""" 127 | assert os.path.exists(path) 128 | # Add words to the dictionary 129 | with open(path, 'r', encoding="utf8") as f: 130 | tokens = 0 131 | for line in f: 132 | words = line.split() + ['<eos>'] 133 | tokens += len(words) 134 | for word in words: 135 | self.dictionary.add_word(word) 136 | 137 | # Tokenize file content 138 | with open(path, 'r', encoding="utf8") as f: 139 | ids = torch.LongTensor(tokens) 140 | token = 0 141 | for line in f: 142 | words = line.split() + ['<eos>'] 143 | for word in words: 144 | ids[token] = self.dictionary.word2idx[word] 145 | token += 1 146 | 147 | return ids 148 | -------------------------------------------------------------------------------- /rnn_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import math 5 | 6 | 7 | def get_batch(source, i, bptt): 8 | # get_batch subdivides the source dataset into chunks of length args.bptt. 9 | # If source is equal to the example output of the batchify function, with 10 | # a bptt-limit of 2, we'd get the following two Variables for i = 0: 11 | # ┌ a g m s ┐ ┌ b h n t ┐ 12 | # └ b h n t ┘ └ c i o u ┘ 13 | # Note that despite the name of the function, the subdivision of the dataset is not 14 | # done along the batch dimension (i.e. dimension 1), since that was handled 15 | # by the batchify function. The chunks are along dimension 0, corresponding 16 | # to the seq_len dimension in the LSTM. 17 | seq_len = min(bptt, len(source) - 1 - i) 18 | data = source[i:i + seq_len] 19 | target = source[i + 1:i + 1 + seq_len].view(-1) 20 | return data, target 21 | 22 | 23 | def rnn_genetic_evaluate(ga, input_model, input_criterion, data_source, ntokens, batch_size, bptt): 24 | input_model.eval() # Turn on evaluation mode which disables dropout. 25 | hidden = input_model.init_hidden(batch_size) 26 | with torch.no_grad(): 27 | for ind in ga.get_current_generation(): 28 | input_model.set_individual(ind) 29 | total_loss = 0 30 | for i in range(0, data_source.size(0) - 1, bptt): 31 | data, targets = get_batch(data_source, i, bptt) 32 | output, hidden = input_model(data, hidden) 33 | output_flat = output.view(-1, ntokens) 34 | total_loss += len(data) * input_criterion(output_flat, targets).item() 35 | hidden = repackage_hidden(hidden) 36 | ga.update_current_individual_fitness(ind, total_loss / (len(data_source) - 1)) 37 | return ga.update_population() 38 | 39 | 40 | def rnn_evaluate(input_model, input_criterion, data_source, ntokens, batch_size, bptt): 41 | input_model.eval() # Turn on evaluation mode which disables dropout. 
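# The loop below accumulates cross-entropy weighted by the number of time
# steps in each chunk, so the value returned is the average loss per token;
# word-level perplexity is then math.exp(loss). A hedged usage sketch (model,
# criterion, val_data and ntokens are assumed to be defined by the caller):
#   val_loss = rnn_evaluate(model, criterion, val_data, ntokens, batch_size=20, bptt=35)
#   print('valid ppl {:8.2f}'.format(math.exp(val_loss)))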
42 | hidden = input_model.init_hidden(batch_size) 43 | with torch.no_grad(): 44 | total_loss = 0 45 | for i in range(0, data_source.size(0) - 1, bptt): 46 | data, targets = get_batch(data_source, i, bptt) 47 | output, hidden = input_model(data, hidden) 48 | output_flat = output.view(-1, ntokens) 49 | total_loss += len(data) * input_criterion(output_flat, targets).item() 50 | hidden = repackage_hidden(hidden) 51 | return total_loss / (len(data_source) - 1) 52 | 53 | 54 | def train_genetic_rnn(ga, train_data, input_model, input_optimizer, input_criterion, ntokens, batch_size, bptt, 55 | grad_clip, 56 | log_interval, final): 57 | # Turn on training mode which enables dropout. 58 | input_model.train() 59 | total_loss = 0. 60 | cur_loss = 0 61 | start_time = time.time() 62 | hidden = input_model.init_hidden(batch_size) 63 | for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)): 64 | data, targets = get_batch(train_data, i, bptt) 65 | # Starting each batch, we detach the hidden state from how it was previously produced. 66 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 67 | hidden = repackage_hidden(hidden) 68 | input_optimizer.zero_grad() # zero old gradients for the next back propgation 69 | if not final: input_model.set_individual(ga.sample_child()) # updating 70 | 71 | output, hidden = input_model(data, hidden) 72 | loss = input_criterion(output.view(-1, ntokens), targets) 73 | 74 | loss.backward() 75 | 76 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 77 | torch.nn.utils.clip_grad_norm_(input_model.parameters(), grad_clip) 78 | input_optimizer.step() 79 | 80 | total_loss += loss.item() 81 | 82 | if batch % log_interval == 0 and batch > 0: 83 | cur_loss += total_loss 84 | elapsed = time.time() - start_time 85 | print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | ' 86 | 'loss {:5.2f} | ppl {:8.2f}'.format(batch, len(train_data) // bptt, elapsed * 1000 / log_interval, 87 | total_loss / log_interval, math.exp(total_loss / log_interval))) 88 | total_loss = 0 89 | start_time = time.time() 90 | return (bptt * cur_loss) / len(train_data) 91 | 92 | 93 | def repackage_hidden(h): 94 | """Wraps hidden states in new Tensors, to detach them from their history.""" 95 | if isinstance(h, torch.Tensor): 96 | return h.detach() 97 | else: 98 | return tuple(repackage_hidden(v) for v in h) 99 | 100 | 101 | class Dictionary(object): 102 | def __init__(self): 103 | self.word2idx = {} 104 | self.idx2word = [] 105 | 106 | def add_word(self, word): 107 | if word not in self.word2idx: 108 | self.idx2word.append(word) 109 | self.word2idx[word] = len(self.idx2word) - 1 110 | return self.word2idx[word] 111 | 112 | def __len__(self): 113 | return len(self.idx2word) 114 | 115 | 116 | class Corpus(object): 117 | def __init__(self, path): 118 | self.dictionary = Dictionary() 119 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 120 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 121 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 122 | 123 | @staticmethod 124 | def single_batchify(data, bsz, input_device): 125 | # Work out how cleanly we can divide the dataset into bsz parts. 126 | nbatch = data.size(0) // bsz 127 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 128 | data = data.narrow(0, 0, nbatch * bsz) 129 | # Evenly divide the dataset across the bsz batches. 
130 | data = data.view(bsz, -1).t().contiguous() 131 | return data.to(input_device) 132 | 133 | def batchify(self, bsz, device): 134 | return self.single_batchify(self.train, bsz, device), self.single_batchify(self.valid, bsz, 135 | device), self.single_batchify( 136 | self.test, bsz, device) 137 | 138 | def tokenize(self, path): 139 | """Tokenizes a text file.""" 140 | assert os.path.exists(path) 141 | # Add words to the dictionary 142 | with open(path, 'r', encoding="utf8") as f: 143 | tokens = 0 144 | for line in f: 145 | words = line.split() + ['<eos>'] 146 | tokens += len(words) 147 | for word in words: 148 | self.dictionary.add_word(word) 149 | 150 | # Tokenize file content 151 | with open(path, 'r', encoding="utf8") as f: 152 | ids = torch.LongTensor(tokens) 153 | token = 0 154 | for line in f: 155 | words = line.split() + ['<eos>'] 156 | for word in words: 157 | ids[token] = self.dictionary.word2idx[word] 158 | token += 1 159 | 160 | return ids 161 | -------------------------------------------------------------------------------- /gnas/genetic_algorithm/genetic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from random import choices 3 | from gnas.search_space.search_space import SearchSpace 4 | from gnas.search_space.cross_over import individual_uniform_crossover, individual_block_crossover 5 | from gnas.search_space.mutation import individual_flip_mutation 6 | from gnas.genetic_algorithm.ga_results import GenetricResult 7 | from gnas.genetic_algorithm.population_dict import PopulationDict 8 | 9 | 10 | def genetic_algorithm_searcher(search_space: SearchSpace, generation_size=20, population_size=300, keep_size=0, 11 | min_objective=True, mutation_p=None, p_cross_over=None, cross_over_type='Bit'): 12 | if mutation_p is None: mutation_p = 1 / search_space.n_elements 13 | if p_cross_over is None: p_cross_over = 1 14 | print('p mutation: ' + str(mutation_p)) 15 | 16 | def population_initializer(p_size): 17 | return search_space.generate_population(p_size) 18 | 19 | def mutation_function(x): 20 | return individual_flip_mutation(x, mutation_p) 21 | 22 | if cross_over_type == 'Bit': 23 | print("Bit based cross over") 24 | 25 | def cross_over_function(x0, x1): 26 | return individual_uniform_crossover(x0, x1, p_cross_over) 27 | elif cross_over_type == 'Block': 28 | print("Block based cross over") 29 | 30 | def cross_over_function(x0, x1): 31 | return individual_block_crossover(x0, x1, p_cross_over) 32 | else: 33 | raise Exception('unknown cross over type: ' + str(cross_over_type)) 34 | 35 | def selection_function(p): 36 | couples = choices(population=list(range(len(p))), weights=p, 37 | k=generation_size) 38 | return np.reshape(np.asarray(couples), [-1, 2]) 39 | 40 | return GeneticAlgorithms(population_initializer, mutation_function, cross_over_function, selection_function, 41 | min_objective=min_objective, generation_size=generation_size, 42 | population_size=population_size, keep_size=keep_size) 43 | 44 | 45 | class GeneticAlgorithms(object): 46 | def __init__(self, population_initializer, mutation_function, cross_over_function, selection_function, 47 | population_size=300, generation_size=20, keep_size=20, min_objective=False): 48 | #################################################################### 49 | # Functions 50 | #################################################################### 51 | self.population_initializer = population_initializer 52 | self.mutation_function = mutation_function 53 | self.cross_over_function = cross_over_function 54 | 
self.selection_function = selection_function 55 | #################################################################### 56 | # parameters 57 | #################################################################### 58 | self.population_size = population_size 59 | self.generation_size = generation_size 60 | self.keep_size = keep_size 61 | self.min_objective = min_objective 62 | #################################################################### 63 | # status 64 | #################################################################### 65 | self.max_dict = PopulationDict() 66 | self.ga_result = GenetricResult() 67 | self.current_dict = dict() 68 | 69 | self.generation = self._create_random_generation() 70 | 71 | self.i = 0 72 | self.best_individual = None 73 | 74 | def _create_random_generation(self): 75 | return self.population_initializer(self.generation_size) 76 | 77 | def _create_new_generation(self, population, population_fitness): 78 | p = population_fitness / np.nansum(population_fitness) 79 | if self.min_objective: p = 1 - p 80 | couples = self.selection_function(p) # selection 81 | child = [cc for c in couples for cc in 82 | self.cross_over_function(population[c[0]], population[c[1]])] # cross-over 83 | new_generation = np.asarray([self.mutation_function(c) for c in child]) # mutation 84 | 85 | p_array = np.asarray([p.code for p in new_generation]) 86 | b = np.ascontiguousarray(p_array).view(np.dtype((np.void, p_array.dtype.itemsize * p_array.shape[1]))) 87 | _, idx = np.unique(b, return_index=True) 88 | if len(idx) == self.generation_size: 89 | generation = new_generation 90 | else: 91 | n = self.generation_size - len(idx) 92 | p_new = self.population_initializer(n) 93 | generation = np.asarray([*[new_generation[i] for i in idx], *p_new]) 94 | return generation 95 | 96 | def update_population(self): 97 | self.i += 1 98 | 99 | generation_fitness = np.asarray(list(self.current_dict.values())) 100 | generation = list(self.current_dict.keys()) 101 | self.ga_result.add_generation_result(generation_fitness, generation) 102 | 103 | f_mean = np.mean(generation_fitness) 104 | f_var = np.var(generation_fitness) 105 | f_max = np.max(generation_fitness) 106 | f_min = np.min(generation_fitness) 107 | total_dict = self.max_dict.copy() 108 | total_dict.update(self.current_dict) 109 | # last_dict = None 110 | # if self.keep_size > 0: 111 | # last_dict = total_dict.filter_last_n(self.keep_size) 112 | # if self.population_size - self.keep_size > 0: 113 | 114 | best_max_dict = total_dict.filter_top_n(self.population_size,min_max=not self.min_objective) 115 | n_diff = self.max_dict.get_n_diff(best_max_dict) 116 | self.max_dict = best_max_dict 117 | # 118 | # if self.keep_size > 0: 119 | # best_max_dict = best_max_dict.merge(last_dict) 120 | # else: 121 | # best_max_dict = last_dict 122 | 123 | 124 | 125 | 126 | self.current_dict = dict() 127 | population_fitness = np.asarray(list(self.max_dict.values())).flatten() 128 | population = np.asarray(list(self.max_dict.keys())).flatten() 129 | self.best_individual = population[np.argmax(population_fitness)] 130 | fp_mean = np.mean(population_fitness) 131 | fp_var = np.var(population_fitness) 132 | fp_max = np.max(population_fitness) 133 | fp_min = np.min(population_fitness) 134 | self.ga_result.add_population_result(population_fitness, population) 135 | self.generation = self._create_new_generation(population, population_fitness) 136 | 137 | 138 | print( 139 | "Update generation | mean fitness: {:5.2f} | var fitness {:5.2f} | max fitness: {:5.2f} | min fitness 
{:5.2f} |population size {:d}|".format( 140 | f_mean, f_var, f_max, f_min, len(population))) 141 | print( 142 | "population results | mean fitness: {:5.2f} | var fitness {:5.2f} | max fitness: {:5.2f} | min fitness {:5.2f} |".format( 143 | fp_mean, fp_var, fp_max, fp_min)) 144 | return f_mean, f_var, f_max, f_min, n_diff 145 | 146 | def get_current_generation(self): 147 | return self.generation 148 | 149 | def update_current_individual_fitness(self, individual, individual_fitness): 150 | self.current_dict.update({individual: individual_fitness}) 151 | 152 | def sample_child(self): 153 | if len(list(self.max_dict.keys())) == 0: # if not population exist generate random indivaul 154 | return self.population_initializer(1)[0] 155 | else: 156 | couples = choices(list(self.max_dict.keys()), k=2) # random select two indivuals from population 157 | child = self.cross_over_function(couples[0], couples[1]) # prefome cross over 158 | return self.mutation_function(child[0]) # select the first then mutation 159 | -------------------------------------------------------------------------------- /plot_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | from common import load_final, make_log_dir, get_model_type, ModelType 6 | from config import get_config, load_config, save_config 7 | import gnas 8 | from modules.drop_module import DropModuleControl 9 | from gnas.common.graph_draw import draw_cell, draw_network 10 | import matplotlib.image as mpimg 11 | 12 | # Popultation size compare 13 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00", 14 | '/data/projects/gnas_results/population_size/2019_01_31_15_45_42', 15 | '/data/projects/gnas_results/population_size/2019_02_01_03_26_01', 16 | '/data/projects/gnas_results/population_size/2019_02_01_04_23_06', 17 | '/data/projects/gnas_results/population_size/2019_02_01_04_44_45', 18 | '/data/projects/gnas_results/population_size/2019_02_01_15_39_37', 19 | '/data/projects/gnas_results/population_size/2019_02_03_17_25_25', 20 | # '/data/projects/gnas_results/population_size/2019_02_03_17_25_27', 21 | '/data/projects/gnas_results/population_size/2019_02_03_17_25_28'] 22 | 23 | # p mutation 24 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_23_21_20_33", 25 | "/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00", 26 | "/data/projects/gnas_results/p_mutation/2019_01_25_08_46_39", 27 | "/data/projects/gnas_results/p_mutation/2019_01_26_13_18_17"] 28 | 29 | # # LR Compare 30 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00", 31 | "/data/projects/gnas_results/lr_compare/2019_02_04_19_17_59", 32 | "/data/projects/gnas_results/lr_compare/2019_02_04_19_18_00"] 33 | 34 | # # Bit Vs Block 35 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00", 36 | "/data/projects/GNAS/logs/2019_02_11_06_15_10"] 37 | # 38 | # # Plot CIFAR10 - Search Result 39 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00"] 40 | # # Plot CIFAR100 - Search Result 41 | file_list = ["/data/projects/GNAS/logs/2019_02_17_20_25_42"] 42 | 43 | # CIFAR10 Final 44 | file_list = ['/data/projects/gnas_results/new_log/2019_02_09_16_43_02'] 45 | # file_list=['/data/projects/gnas_results/new_log/2019_02_07_18_34_45', 46 | # '/data/projects/gnas_results/new_log/2019_02_09_02_23_52', 47 | # '/data/projects/gnas_results/new_log/2019_02_09_16_43_02', 48 | # 
'/data/projects/gnas_results/new_log/2019_02_14_18_15_48'] 49 | 50 | 51 | plot_arc = False 52 | # file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00", ] 53 | if plot_arc: 54 | ind_file = os.path.join(file_list[0], 'best_individual.pickle') 55 | config_file = os.path.join(file_list[0], 'config.json') 56 | ind = pickle.load(open(ind_file, "rb")) 57 | 58 | config = get_config(ModelType.CNN) 59 | print("Loading config file:" + config_file) 60 | config.update(load_config(config_file)) 61 | 62 | dp_control = DropModuleControl(config.get('drop_path_keep_prob')) 63 | n_cell_type = gnas.SearchSpaceType(config.get('n_block_type') - 1) 64 | ss = gnas.get_gnas_cnn_search_space(config.get('n_nodes'), dp_control, n_cell_type) 65 | draw_network(ss, ind, './') 66 | title_list = ['Reduce Cell', ' Normal Cell', ' Input Cell'] 67 | for i in range(len(ss.ocl)): 68 | plt.subplot(1, len(ss.ocl), i + 1) 69 | img = mpimg.imread(os.path.join('./', str(i) + '.png')) 70 | plt.imshow(img) 71 | plt.axis('off') 72 | plt.title(title_list[i]) 73 | plt.show() 74 | # draw_cell(ss.ocl[0], ind.individual_list[0]) 75 | # plt.show() 76 | # print("a") 77 | 78 | if len(file_list) == 1 and True: 79 | data = pickle.load(open(os.path.join(file_list[0], 'ga_result.pickle'), "rb")) 80 | config = load_config(os.path.join(file_list[0], 'config.json')) 81 | if data.result_dict.get('Fitness') is None: 82 | plt.plot(np.asarray(data.result_dict.get('Training Accuracy')), label='Training Accuracy') 83 | plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), label='Validation Accuracy') 84 | plt.xlabel('Epoch') 85 | plt.legend() 86 | plt.ylabel('Accuracy[%]') 87 | plt.grid() 88 | plt.show() 89 | else: 90 | fitness = np.stack(data.result_dict.get('Fitness')) 91 | fitness_p = np.stack(data.result_dict.get('Fitness-Population')) 92 | fitness_p = fitness_p[0:-1:2, :] 93 | 94 | epochs = np.linspace(0, fitness_p.shape[0] - 1, fitness_p.shape[0]) 95 | plt.plot(epochs, np.mean(fitness_p, axis=1), '*--', 96 | label='Population mean accuracy') 97 | plt.plot(epochs, np.max(fitness_p, axis=1), label='Max accuracy') 98 | plt.plot(np.asarray(data.result_dict.get('Best')), '--', label='Best') 99 | plt.grid() 100 | plt.legend() 101 | plt.xlabel('Epoch') 102 | plt.ylabel('Accuracy') 103 | plt.show() 104 | 105 | plt.errorbar(epochs, np.mean(fitness_p, axis=1), np.std(fitness_p, axis=1), fmt='*--', 106 | label='Population mean accuracy') 107 | plt.plot(epochs, np.min(fitness_p, axis=1), label='Min accuracy') 108 | plt.plot(epochs, np.max(fitness_p, axis=1), label='Max accuracy') 109 | plt.grid() 110 | plt.legend() 111 | plt.title('Population accuracy on the validation set') 112 | plt.xlabel('Epoch') 113 | plt.ylabel('Accuracy') 114 | plt.show() 115 | 116 | plt.plot(np.asarray(data.result_dict.get('Training Accuracy')), label='Training') 117 | plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), '--', label='Validation') 118 | plt.plot(np.asarray(data.result_dict.get('Best')), '*-', label='Best') 119 | plt.title('Training vs Validation Accuracy') 120 | plt.xlabel('Epoch') 121 | plt.legend() 122 | plt.ylabel('Accuracy[%]') 123 | plt.grid() 124 | plt.show() 125 | 126 | plt.plot(epochs, data.result_dict.get('N')) 127 | plt.title('Number of new individuals in Population') 128 | plt.xlabel('Epoch') 129 | plt.ylabel('N') 130 | plt.grid() 131 | plt.show() 132 | 133 | plt.plot(epochs, data.result_dict.get('Training Loss')) 134 | plt.title('Training Loss') 135 | plt.xlabel('Epoch') 136 | plt.ylabel('Loss') 137 | plt.grid() 
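# When this script runs on a machine without a display, plt.show() cannot
# open a window; one option (a sketch, with an illustrative file name) is to
# save the figure instead:
#   plt.savefig(os.path.join(file_list[0], 'training_loss.png'), dpi=150)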
138 | plt.show() 139 | 140 | else: 141 | ################ 142 | # Build legend 143 | ################ 144 | config_list = [] 145 | param_list = [] 146 | for f in file_list: 147 | config_list.append(load_config(os.path.join(f, 'config.json'))) 148 | for k in config_list[-1].keys(): 149 | param_list.append(k) 150 | param_list = np.unique(param_list) 151 | str_list = ['' for c in config_list] 152 | res_dict = dict() 153 | for p in param_list: 154 | if len(np.unique([c.get(p) for c in config_list if c.get(p) is not None])) > 1: 155 | for i, c in enumerate(config_list): 156 | str_list[i] = str_list[i] + ' ' + p + '=' + str(c.get(p)) 157 | if res_dict.get(p) is None: 158 | res_dict.update({p: [c.get(p)]}) 159 | else: 160 | res_dict.get(p).append(c.get(p)) 161 | elif len(np.unique([c.get(p) for c in config_list if c.get(p) is not None])) == 1: 162 | if len([c.get(p) for c in config_list if c.get(p) is None]) != 0: 163 | for i, c in enumerate(config_list): 164 | str_list[i] = str_list[i] + ' ' + p + '=' + str(c.get(p)) 165 | if len(res_dict.keys()) == 1: 166 | param_array = np.asarray(res_dict.get(list(res_dict.keys())[0])) 167 | res_list = [] 168 | for i, f in enumerate(file_list): 169 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb")) 170 | res_list.append(np.max(np.asarray(data.result_dict.get('Best')))) 171 | index = np.argsort(param_array) 172 | res_list = np.asarray(res_list)[index] 173 | param_array = param_array[index] 174 | plt.plot(param_array, res_list) 175 | plt.grid() 176 | plt.xlabel(list(res_dict.keys())[0].replace('_', ' ')) 177 | plt.ylabel('Accuracy[%]') 178 | plt.show() 179 | print("a") 180 | ######################### 181 | # Plot Validation 182 | ######################### 183 | plt.subplot(2, 2, 1) 184 | for i, f in enumerate(file_list): 185 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb")) 186 | plt.plot(np.asarray(data.result_dict.get('Best')), label=str_list[i]) 187 | # plt.title() 188 | plt.legend() 189 | plt.grid() 190 | plt.subplot(2, 2, 2) 191 | for i, f in enumerate(file_list): 192 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb")) 193 | config = load_config(os.path.join(f, 'config.json')) 194 | plt.plot(np.asarray(data.result_dict.get('Training Accuracy')), label=str_list[i]) 195 | # plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), '*--', label='Validation ' + str_list[i]) 196 | plt.legend() 197 | plt.grid() 198 | plt.subplot(2, 2, 3) 199 | for i, f in enumerate(file_list): 200 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb")) 201 | config = load_config(os.path.join(f, 'config.json')) 202 | plt.plot(np.asarray(data.result_dict.get('Training Loss')), label=str_list[i]) 203 | # plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), '*--', label='Validation ' + str_list[i]) 204 | plt.legend() 205 | plt.grid() 206 | plt.show() 207 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch.nn as nn 3 | 4 | import torch 5 | import torch.optim as optim 6 | import os 7 | import pickle 8 | import argparse 9 | 10 | import gnas 11 | from models import model_cnn, model_rnn 12 | from cnn_utils import evaluate_single, evaluate_individual_list 13 | from rnn_utils import train_genetic_rnn, rnn_genetic_evaluate, rnn_evaluate 14 | from data import get_dataset 15 | from common import load_final, make_log_dir, get_model_type, ModelType 16 
| from config import get_config, load_config, save_config 17 | from modules.drop_module import DropModuleControl 18 | from modules.cosine_annealing import CosineAnnealingLR 19 | 20 | ####################################### 21 | # Constants 22 | ####################################### 23 | log_interval = 200 24 | ####################################### 25 | # User input 26 | ####################################### 27 | parser = argparse.ArgumentParser(description='PyTorch GNAS') 28 | parser.add_argument('--dataset_name', type=str, choices=['CIFAR10', 'CIFAR100', 'PTB'], help='the working data', 29 | default='CIFAR10') 30 | parser.add_argument('--config_file', type=str, help='location of the config file') 31 | parser.add_argument('--search_dir', type=str, help='the log dir of the search') 32 | parser.add_argument('--final', type=bool, help='location of the config file', default=False) 33 | parser.add_argument('--data_path', type=str, default='./dataset/', help='location of the dataset') 34 | args = parser.parse_args() 35 | ####################################### 36 | # Search Working Device 37 | ####################################### 38 | working_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | print(working_device) 40 | ####################################### 41 | # Set seed 42 | ####################################### 43 | model_type = get_model_type(dataset_name=args.dataset_name) 44 | print("Selected mode type:" + str(model_type)) 45 | ####################################### 46 | # Parameters 47 | ####################################### 48 | config = get_config(model_type) 49 | if args.config_file is not None: 50 | print("Loading config file:" + args.config_file) 51 | config.update(load_config(args.config_file)) 52 | config.update({'data_path': args.data_path, 'dataset_name': args.dataset_name, 'working_device': str(working_device)}) 53 | print(config) 54 | ###################################### 55 | # Read dataset and set augmentation 56 | ###################################### 57 | trainloader, testloader, n_param = get_dataset(config) 58 | ###################################### 59 | # Config model and search space 60 | ###################################### 61 | if model_type == ModelType.CNN: 62 | min_objective = False 63 | n_cell_type = gnas.SearchSpaceType(config.get('n_block_type') - 1) 64 | dp_control = DropModuleControl(config.get('drop_path_keep_prob')) 65 | ss = gnas.get_gnas_cnn_search_space(config.get('n_nodes'), dp_control, n_cell_type) 66 | 67 | net = model_cnn.Net(config.get('n_blocks'), config.get('n_channels'), n_param, 68 | config.get('dropout'), 69 | ss, aux=config.get('aux_loss')).to(working_device) 70 | ###################################### 71 | # Build Optimizer and Loss function 72 | ##################################### 73 | optimizer = optim.SGD(net.parameters(), lr=config.get('learning_rate'), momentum=config.get('momentum'), 74 | nesterov=True, 75 | weight_decay=config.get('weight_decay')) 76 | elif model_type == ModelType.RNN: 77 | min_objective = True 78 | ntokens = n_param 79 | ss = gnas.get_gnas_rnn_search_space(config.get('n_nodes')) 80 | net = model_rnn.RNNModel(ntokens, config.get('n_channels'), config.get('n_channels'), config.get('n_blocks'), 81 | config.get('dropout'), 82 | tie_weights=True, 83 | ss=ss).to( 84 | working_device) 85 | ###################################### 86 | # Build Optimizer and Loss function 87 | ##################################### 88 | optimizer = optim.SGD(net.parameters(), 
######################################
# Build genetic_algorithm_searcher
######################################
ga = gnas.genetic_algorithm_searcher(ss, generation_size=config.get('generation_size'),
                                     population_size=config.get('population_size'),
                                     keep_size=config.get('keep_size'), mutation_p=config.get('mutation_p'),
                                     p_cross_over=config.get('p_cross_over'),
                                     cross_over_type=config.get('cross_over_type'),
                                     min_objective=min_objective)
######################################
# Loss function
######################################
criterion = nn.CrossEntropyLoss()
######################################
# Select learning-rate schedule
######################################
if config.get('LRType') == 'CosineAnnealingLR':
    scheduler = CosineAnnealingLR(optimizer, 10, 2, config.get('lr_min'))  # see modules/cosine_annealing.py
elif config.get('LRType') == 'MultiStepLR':
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               [int(config.get('n_epochs') / 2), int(3 * config.get('n_epochs') / 4)])
elif config.get('LRType') == 'ExponentialLR':
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config.get('gamma'))
else:
    raise Exception('Unknown LRType: ' + config.get('LRType'))

##################################################
# Generate log dir and save params
##################################################
log_dir = make_log_dir(config)
save_config(log_dir, config)
#######################################
# Load individual
#######################################
if args.final:
    ind = load_final(net, args.search_dir)
##################################################
# Start epochs
##################################################
ra = gnas.ResultAppender()
if model_type == ModelType.CNN:
    best = 0
    print("Starting training with the CNN model")
    for epoch in range(config.get('n_epochs')):  # loop over the dataset multiple times
        running_loss = 0.0
        correct = 0
        total = 0

        scheduler.step()
        s = time.time()
        net = net.train()
        if epoch == config.get('drop_path_start_epoch'):
            dp_control.enable()  # start dropping paths once training has warmed up
        ############################################
        # Loop over batches and update weights
        ############################################
        for i, (inputs, labels) in enumerate(trainloader, 0):
            if not args.final:
                net.set_individual(ga.sample_child())  # sample a child architecture from the population

            inputs = inputs.to(working_device)
            labels = labels.to(working_device)

            optimizer.zero_grad()  # zero the parameter gradients
            outputs = net(inputs)  # forward

            _, predicted = torch.max(outputs[0], 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs[0], labels)
            if config.get('aux_loss'):
                loss += config.get('aux_scale') * criterion(outputs[1], labels)
            loss.backward()  # backward
            optimizer.step()  # optimize

            # accumulate statistics
            running_loss += loss.item()
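        # Fitness evaluation below reuses the weights trained above (weight
        # sharing): each candidate architecture is scored on the held-out
        # loader with no further gradient updates, which keeps it cheap.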
        ############################################
        # Update GA population
        ############################################
        if args.final:
            f_max = evaluate_single(ind, net, testloader, working_device)
            n_diff = 0
        else:
            if config.get('full_dataset'):
                for ind in ga.get_current_generation():
                    acc = evaluate_single(ind, net, testloader, working_device)
                    ga.update_current_individual_fitness(ind, acc)
                _, _, f_max, _, n_diff = ga.update_population()
                best_individual = ga.best_individual
            else:
                f_max = 0
                n_diff = 0
                for _ in range(config.get('generation_per_epoch')):
                    evaluate_individual_list(ga.get_current_generation(), ga, net, testloader,
                                             working_device)  # evaluate the next generation on the validation set
                    _, _, v_max, _, n_d = ga.update_population()  # replacement
                    n_diff += n_d
                    if v_max > f_max:
                        f_max = v_max
                best_individual = ga.best_individual
                f_max = evaluate_single(best_individual, net, testloader, working_device)  # evaluate the best individual
        if f_max > best:
            print("Update Best")
            best = f_max
            torch.save(net.state_dict(), os.path.join(log_dir, 'best_model.pt'))
            if not args.final:
                gnas.draw_network(ss, ga.best_individual, os.path.join(log_dir, 'best_graph_' + str(epoch) + '_'))
                pickle.dump(ga.best_individual, open(os.path.join(log_dir, 'best_individual.pickle'), "wb"))
        print('|Epoch: {:2d}|Time[min]: {:2.3f}|Loss: {:2.3f}|Accuracy: {:2.3f}%|Validation Accuracy: {:2.3f}%|'
              'LR: {:2.3f}|N Change: {:2d}|'.format(epoch, (time.time() - s) / 60, running_loss / (i + 1),
                                                    100 * correct / total, f_max, scheduler.get_lr()[-1], n_diff))
        ra.add_epoch_result('N', n_diff)
        ra.add_epoch_result('Best', best)
        ra.add_epoch_result('Validation Accuracy', f_max)
        ra.add_epoch_result('LR', scheduler.get_lr()[-1])
        ra.add_epoch_result('Training Loss', running_loss / (i + 1))
        ra.add_epoch_result('Training Accuracy', 100 * correct / total)
        if not args.final:
            ra.add_result('Fitness', ga.ga_result.fitness_list)
            ra.add_result('Fitness-Population', ga.ga_result.fitness_full_list)
        ra.save_result(log_dir)
elif model_type == ModelType.RNN:
    best = 1000
    for epoch in range(1, config.get('n_epochs') + 1):
        if epoch > 15:
            scheduler.step()  # keep the initial learning rate for the first 15 epochs
        epoch_start_time = time.time()
        train_loss = train_genetic_rnn(ga, trainloader, net, optimizer, criterion, ntokens, config.get('batch_size'),
                                       config.get('bptt'), config.get('clip'),
                                       log_interval, args.final)
        if args.final:
            min_loss = rnn_evaluate(net, criterion, testloader, ntokens, config.get('batch_size_val'),
                                    config.get('bptt'))
        else:
            val_loss, loss_var, max_loss, min_loss, n_diff = rnn_genetic_evaluate(ga, net, criterion, testloader,
                                                                                  ntokens,
                                                                                  config.get('batch_size_val'),
                                                                                  config.get('bptt'))

        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | lr {:02.2f} |'.format(
            epoch, (time.time() - epoch_start_time), min_loss, scheduler.get_lr()[-1]))
        print('-' * 89)
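        # Note: the reported quantity is the raw cross-entropy loss; for
        # language modeling the corresponding perplexity is exp(loss).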
        # Save the model if the validation loss is the best we've seen so far.
        if min_loss < best:
            print("Update Best")
            torch.save(net.state_dict(), os.path.join(log_dir, 'best_model.pt'))
            if not args.final:
                gnas.draw_network(ss, ga.best_individual, os.path.join(log_dir, 'best_graph_' + str(epoch) + '_'))
                pickle.dump(ga.best_individual, open(os.path.join(log_dir, 'best_individual.pickle'), "wb"))
            best = min_loss

        ra.add_epoch_result('Loss', train_loss)
        ra.add_epoch_result('LR', scheduler.get_lr()[-1])
        ra.add_epoch_result('Best', best)
        if not args.final:
            ra.add_result('Fitness', ga.ga_result.fitness_list)
        ra.save_result(log_dir)
print('Finished Training')
--------------------------------------------------------------------------------