├── models
│   ├── __init__.py
│   ├── model_cnn.py
│   └── model_rnn.py
├── tests
│   ├── __init__.py
│   ├── test_genetic.py
│   ├── common4testing.py
│   ├── test_plot_augmentation.py
│   ├── test_modules.py
│   └── test_search_space.py
├── gnas
│   ├── common
│   │   ├── __init__.py
│   │   ├── bit_utils.py
│   │   ├── result.py
│   │   └── graph_draw.py
│   ├── search_space
│   │   ├── __init__.py
│   │   ├── mutation.py
│   │   ├── individual.py
│   │   ├── search_space.py
│   │   ├── factory.py
│   │   ├── cross_over.py
│   │   └── operation_space.py
│   ├── genetic_algorithm
│   │   ├── __init__.py
│   │   ├── ga_results.py
│   │   ├── population_dict.py
│   │   └── genetic.py
│   ├── modules
│   │   ├── drop_path.py
│   │   ├── __init__.py
│   │   ├── operation_factory.py
│   │   ├── sub_graph_module.py
│   │   ├── cnn_block.py
│   │   ├── rnn_layer.py
│   │   ├── module_generator.py
│   │   └── node_module.py
│   └── __init__.py
├── modules
│   ├── __init__.py
│   ├── identity.py
│   ├── se_block.py
│   ├── cut_out.py
│   ├── drop_module.py
│   ├── cosine_annealing.py
│   └── weight_drop.py
├── images
│   ├── search_result_cifar10.png
│   └── search_result_cifar100.png
├── .idea
│   └── vcs.xml
├── configs
│   ├── config_cnn_final_cifar10.json
│   ├── config_cnn_final_cifar100.json
│   ├── config_cnn_search_cifar10.json
│   └── config_cnn_search_cifar100.json
├── common.py
├── LICENSE
├── gif_creator.py
├── cnn_utils.py
├── README.md
├── config.py
├── data.py
├── rnn_utils.py
├── plot_result.py
└── main.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gnas/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gnas/search_space/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gnas/genetic_algorithm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/gnas/modules/drop_path.py:
--------------------------------------------------------------------------------
1 | from modules.drop_module import DropModule, DropModuleControl
2 |
--------------------------------------------------------------------------------
/images/search_result_cifar10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haihabi/GeneticNAS/HEAD/images/search_result_cifar10.png
--------------------------------------------------------------------------------
/images/search_result_cifar100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haihabi/GeneticNAS/HEAD/images/search_result_cifar100.png
--------------------------------------------------------------------------------
/gnas/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from gnas.modules.rnn_layer import RnnSearchModule
2 | from gnas.modules.cnn_block import CnnSearchModule
3 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/modules/identity.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class Identity(nn.Module):
5 | def __init__(self):
6 | super(Identity, self).__init__()
7 |
8 | def forward(self, x):
9 | return x
10 |
--------------------------------------------------------------------------------
/gnas/common/bit_utils.py:
--------------------------------------------------------------------------------
1 | def vector_bits2int(arr):
2 |     n = arr.shape[0]  # number of bits in the input vector
3 | a = arr[0] << n - 1
4 |
5 | for j in range(1, n):
6 | # "overlay" with the shifted bits of the next column
7 | a |= arr[j] << n - 1 - j
8 | return a
9 |
--------------------------------------------------------------------------------
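A quick usage sketch (not part of the repository) showing how `vector_bits2int` packs a bit vector, most significant bit first, into a single integer:

```python
import numpy as np
from gnas.common.bit_utils import vector_bits2int

# [1, 0, 1] is read MSB-first: 0b101 == 5
bits = np.array([1, 0, 1])
assert vector_bits2int(bits) == 5
```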
/gnas/__init__.py:
--------------------------------------------------------------------------------
1 | from gnas.search_space.factory import get_gnas_cnn_search_space, get_gnas_rnn_search_space, SearchSpaceType
2 | from gnas.genetic_algorithm.genetic import genetic_algorithm_searcher
3 | from gnas.common.result import ResultAppender
4 | from gnas import modules
5 | from gnas.common.graph_draw import draw_network
6 |
--------------------------------------------------------------------------------
/gnas/genetic_algorithm/ga_results.py:
--------------------------------------------------------------------------------
1 | class GenetricResult(object):
2 | def __init__(self):
3 | self.population_list = []
4 | self.fitness_list = []
5 | self.fitness_full_list = []
6 | self.population_full_list = []
7 |
8 | def add_generation_result(self, fitness, population):
9 | self.fitness_list.append(fitness)
10 | self.population_list.append(population)
11 |
12 | def add_population_result(self, fitness, population):
13 | self.fitness_full_list.append(fitness)
14 | self.population_full_list.append(population)
15 |
--------------------------------------------------------------------------------
/gnas/modules/operation_factory.py:
--------------------------------------------------------------------------------
1 | from gnas.search_space.operation_space import RnnNodeConfig, RnnInputNodeConfig, CnnNodeConfig
2 | from gnas.modules.node_module import RnnNodeModule, RnnInputNodeModule, ConvNodeModule
3 |
4 | __module_dict__ = {RnnNodeConfig: RnnNodeModule,
5 | RnnInputNodeConfig: RnnInputNodeModule,
6 | CnnNodeConfig: ConvNodeModule}
7 |
8 |
9 | def get_module(node_config, config_dict):
10 | m = __module_dict__.get(type(node_config))
11 | if m is None:
12 | raise Exception('Can\'t find module named:' + node_config)
13 | return m(node_config, config_dict)
14 |
--------------------------------------------------------------------------------
/modules/se_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SEBlock(nn.Module):
6 | def __init__(self, n_channels, rate):
7 | super(SEBlock, self).__init__()
8 | self.se_block = nn.Sequential(nn.Linear(n_channels, int(n_channels / rate)),
9 | nn.ReLU(),
10 | nn.Linear(int(n_channels / rate), n_channels),
11 | nn.Sigmoid())
12 |
13 | def forward(self, x):
14 | x_gp = torch.mean(torch.mean(x, dim=-1), dim=-1)
15 | att = self.se_block(x_gp).unsqueeze(dim=-1).unsqueeze(dim=-1)
16 | return x * att
17 |
--------------------------------------------------------------------------------
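A minimal usage sketch (assuming a 4D feature map, as used by the CNN blocks elsewhere in the repo): `SEBlock` squeezes each channel with global average pooling, runs the bottleneck MLP, and re-weights the channels:

```python
import torch
from modules.se_block import SEBlock

se = SEBlock(n_channels=64, rate=8)   # bottleneck of 64/8 = 8 hidden units
x = torch.randn(4, 64, 16, 16)        # (batch, channels, H, W)
y = se(x)
assert y.shape == x.shape             # per-channel re-weighting keeps the shape
```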
/configs/config_cnn_final_cifar10.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size": 64,
3 | "batch_size_val": 1000,
4 | "n_epochs": 630,
5 | "n_blocks": 4,
6 | "n_block_type": 3,
7 | "n_nodes": 5,
8 | "n_channels": 20 ,
9 | "generation_size": 20,
10 | "generation_per_epoch": 2,
11 | "full_dataset": false,
12 | "population_size": 20,
13 | "keep_size": 0,
14 | "mutation_p": 0.02,
15 | "p_cross_over": 1.0,
16 | "cross_over_type": "Block",
17 | "learning_rate": 0.05,
18 | "lr_min": 0.0001,
19 | "weight_decay": 0.0001,
20 | "dropout": 0.2,
21 | "drop_path_keep_prob": 0.9,
22 | "drop_path_start_epoch": 0,
23 | "cutout": true,
24 | "n_holes": 1,
25 | "length": 16,
26 | "LRType": "MultiStepLR",
27 | "num_class": 10,
28 | "momentum": 0.9,
29 | "aux_loss": false,
30 | "aux_scale": 0.4
31 | }
--------------------------------------------------------------------------------
/configs/config_cnn_final_cifar100.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size": 50,
3 | "batch_size_val": 1000,
4 | "n_epochs": 630,
5 | "n_blocks": 4,
6 | "n_block_type": 3,
7 | "n_nodes": 5,
8 | "n_channels": 48 ,
9 | "generation_size": 20,
10 | "generation_per_epoch": 2,
11 | "full_dataset": false,
12 | "population_size": 20,
13 | "keep_size": 0,
14 | "mutation_p": 0.01,
15 | "p_cross_over": 1.0,
16 | "cross_over_type": "Block",
17 | "learning_rate": 0.05,
18 | "lr_min": 0.0001,
19 | "weight_decay": 0.0001,
20 | "dropout": 0.2,
21 | "drop_path_keep_prob": 0.9,
22 | "drop_path_start_epoch": 0,
23 | "cutout": true,
24 | "n_holes": 1,
25 | "length": 16,
26 | "LRType": "MultiStepLR",
27 | "num_class": 10,
28 | "momentum": 0.9,
29 | "aux_loss": false,
30 | "aux_scale": 0.4
31 | }
--------------------------------------------------------------------------------
/configs/config_cnn_search_cifar10.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size": 128,
3 | "batch_size_val": 1000,
4 | "n_epochs": 310,
5 | "n_blocks": 2,
6 | "n_block_type": 3,
7 | "n_nodes": 5,
8 | "n_channels": 20 ,
9 | "generation_size": 20,
10 | "generation_per_epoch": 2,
11 | "full_dataset": false,
12 | "population_size": 20,
13 | "keep_size": 0,
14 | "mutation_p": 0.02,
15 | "p_cross_over": 1.0,
16 | "cross_over_type": "Block",
17 | "learning_rate": 0.1,
18 | "lr_min": 0.0001,
19 | "weight_decay": 0.0001,
20 | "dropout": 0.2,
21 | "drop_path_keep_prob": 1.0,
22 | "drop_path_start_epoch": 50,
23 | "cutout": true,
24 | "n_holes": 1,
25 | "length": 16,
26 | "LRType": "MultiStepLR",
27 | "num_class": 10,
28 | "momentum": 0.9,
29 | "aux_loss": false,
30 | "aux_scale": 0.4
31 | }
--------------------------------------------------------------------------------
/configs/config_cnn_search_cifar100.json:
--------------------------------------------------------------------------------
1 | {
2 | "batch_size": 64,
3 | "batch_size_val": 1000,
4 | "n_epochs": 310,
5 | "n_blocks": 2,
6 | "n_block_type": 3,
7 | "n_nodes": 5,
8 | "n_channels": 48 ,
9 | "generation_size": 20,
10 | "generation_per_epoch": 2,
11 | "full_dataset": false,
12 | "population_size": 20,
13 | "keep_size": 0,
14 | "mutation_p": 0.01,
15 | "p_cross_over": 1.0,
16 | "cross_over_type": "Block",
17 | "learning_rate": 0.05,
18 | "lr_min": 0.0001,
19 | "weight_decay": 0.0001,
20 | "dropout": 0.2,
21 | "drop_path_keep_prob": 1.0,
22 | "drop_path_start_epoch": 50,
23 | "cutout": true,
24 | "n_holes": 1,
25 | "length": 16,
26 | "LRType": "MultiStepLR",
27 | "num_class": 10,
28 | "momentum": 0.9,
29 | "aux_loss": false,
30 | "aux_scale": 0.4
31 | }
--------------------------------------------------------------------------------
/tests/test_genetic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import unittest
3 | import gnas
4 |
5 |
6 | class TestGenetic(unittest.TestCase):
7 |
8 | def test_search_cnn_space(self):
9 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNSingleCell)
10 | ga = gnas.genetic_algorithm_searcher(ss, population_size=20, generation_size=20, p_cross_over=0.8)
11 |         for _ in range(30):
12 |             for ind in ga.get_current_generation():
13 |                 ga.sample_child()
14 |                 ga.update_current_individual_fitness(ind, np.random.rand(1))
15 | ga.update_population()
16 | self.assertTrue(len(ga.max_dict) <= 200)
17 | self.assertTrue(len(ga.generation))
18 |
19 |
20 | if __name__ == '__main__':
21 | unittest.main()
22 |
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import datetime
4 | from enum import Enum
5 |
6 |
7 | class ModelType(Enum):
8 | CNN = 0
9 | RNN = 1
10 |
11 |
12 | def make_log_dir(config):
13 | log_dir = os.path.join('.', 'logs', datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
14 | os.makedirs(log_dir, exist_ok=True)
15 | return log_dir
16 |
17 |
18 | def load_final(model, search_dir):
19 | ind_file = os.path.join(search_dir, 'best_individual.pickle')
20 | ind = pickle.load(open(ind_file, "rb"))
21 | model.set_individual(ind)
22 | return ind
23 |
24 |
25 | def get_model_type(dataset_name):
26 | if dataset_name in ['CIFAR10', 'CIFAR100']:
27 | return ModelType.CNN
28 | elif dataset_name == 'PTB':
29 | return ModelType.RNN
30 | else:
31 |         raise Exception('unknown model for dataset: ' + dataset_name)
32 |
--------------------------------------------------------------------------------
/gnas/common/result.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 |
5 | class ResultAppender(object):
6 | def __init__(self):
7 | self.result_dict = dict()
8 |
9 | def add_epoch_result(self, result_name: str, result_var: float):
10 | if self.result_dict.get(result_name) is None:
11 | self.result_dict.update({result_name: [result_var]})
12 | else:
13 | self.result_dict.get(result_name).append(result_var)
14 |
15 | def add_result(self, result_name: str, result_array):
16 | self.result_dict.update({result_name: result_array})
17 |
18 | def save_result(self, input_path):
19 | pickle.dump(self, open(os.path.join(input_path, 'ga_result.pickle'), "wb"))
20 |
21 | @staticmethod
22 | def load_result(input_path):
23 | return pickle.load(open(os.path.join(input_path, 'ga_result.pickle'), "rb"))
24 |
--------------------------------------------------------------------------------
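A small usage sketch of `ResultAppender` (the file name `ga_result.pickle` is fixed by the class):

```python
from gnas.common.result import ResultAppender

ra = ResultAppender()
for acc in [91.2, 92.7, 93.1]:                    # one value per epoch
    ra.add_epoch_result('validation_accuracy', acc)
ra.add_result('lr_schedule', [0.1, 0.05, 0.01])   # store a whole array at once
ra.save_result('.')                               # writes ./ga_result.pickle
restored = ResultAppender.load_result('.')
```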
/tests/common4testing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gnas.search_space.operation_space import RnnInputNodeConfig, RnnNodeConfig, CnnNodeConfig
3 | from gnas.search_space.search_space import SearchSpace
4 | from modules.drop_module import DropModuleControl
5 | import gnas
6 |
7 |
8 | def generate_ss():
9 | nll = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid']
10 | node_config_list = [RnnInputNodeConfig(2, [0, 1], nll)]
11 | for i in range(12):
12 | node_config_list.append(RnnNodeConfig(3 + i, list(np.linspace(2, 2 + i, 1 + i).astype('int')), nll))
13 | ss = SearchSpace(node_config_list)
14 | return ss
15 |
16 |
17 | def generate_ss_cnn():
18 | nll = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid']
19 | op = ['Conv3x3', 'Dw3x3', 'Conv5x5', 'Dw5x5']
20 | dp_control = DropModuleControl(1)
21 | node_config_list = [CnnNodeConfig(2, [0, 1], op, dp_control)]
22 | for i in range(3):
23 | node_config_list.append(CnnNodeConfig(3 + i, list(np.linspace(0, 2 + i, 3 + i).astype('int')), op, dp_control))
24 | ss = SearchSpace(node_config_list)
25 | return ss
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 HVH
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/gnas/search_space/mutation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gnas.search_space.individual import Individual, MultipleBlockIndividual
3 |
4 |
5 | def flip_max_value(current_value, max_value, p):
6 | flip = np.floor(np.random.rand(current_value.shape[0]) + p).astype('int')
7 | sign = (2 * (np.round(np.random.rand(current_value.shape[0])) - 0.5)).astype('int')
8 | new_dna = current_value + flip * sign
9 | new_dna[new_dna > max_value] = 0
10 | new_dna[new_dna < 0] = max_value[new_dna < 0]
11 | return new_dna
12 |
13 |
14 | def _individual_flip_mutation(individual_a, p) -> Individual:
15 | max_values = individual_a.ss.get_max_values_vector(index=individual_a.index)
16 | new_iv = []
17 | for m, iv in zip(max_values, individual_a.iv):
18 | new_iv.append(flip_max_value(iv, m, p))
19 | return individual_a.update_individual(new_iv)
20 |
21 |
22 | def individual_flip_mutation(individual_a, p):
23 | if isinstance(individual_a, Individual):
24 | return _individual_flip_mutation(individual_a, p)
25 | else:
26 | return MultipleBlockIndividual([_individual_flip_mutation(inv, p) for inv in individual_a.individual_list])
27 |
--------------------------------------------------------------------------------
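To see what `flip_max_value` does, here is a tiny demo (not part of the repo): with `p=1.0` every gene moves one step up or down, and values wrap around the `[0, max_value]` range rather than saturating:

```python
import numpy as np
from gnas.search_space.mutation import flip_max_value

genes = np.array([0, 1, 2, 3])
max_values = np.array([3, 3, 3, 3])
mutated = flip_max_value(genes, max_values, p=1.0)
# Every entry is now genes[i] +- 1; 3 + 1 wraps to 0 and 0 - 1 wraps to 3.
print(mutated)
```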
/gif_creator.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import glob
3 | import os
4 | from PIL import Image
5 | from PIL import ImageFont
6 | from PIL import ImageDraw
7 | import numpy as np
8 |
9 | images = dict()
10 | epoch_dict = dict()
11 | log_dir = '/data/projects/GNAS/logs/2019_02_17_20_25_42'
12 |
13 | font = ImageFont.truetype("/home/haih/Downloads/Untitled Folder/Microsoft Sans Serif.ttf", 16)
14 | for filename in glob.glob(log_dir + "/*.png"):
15 |     layer_index = int(filename.split('_')[-1].split('.png')[0])
16 |     epoch = int(filename.split('_')[-2])
17 | img = imageio.imread(filename)
18 | img_pil = Image.fromarray(img)
19 | draw = ImageDraw.Draw(img_pil)
20 |
21 | draw.text((0, 0), 'Epoch:' + str(epoch), (0, 0, 0), font=font)
22 | img = np.array(img_pil.getdata()).reshape(img_pil.size[1], img_pil.size[0], 4)
23 | if images.get(layer_index) is None:
24 | images.update({layer_index: [img]})
25 | epoch_dict.update({layer_index: [epoch]})
26 | else:
27 | images.get(layer_index).append(img)
28 | epoch_dict.get(layer_index).append(epoch)
29 |
30 | # Reorder
31 | for k, v in epoch_dict.items():
32 | index = np.argsort(np.asarray(v)).astype('int')
33 | images.update({k: [np.asarray(images.get(k))[i] for i in index]})
34 |
35 | for k, v in images.items():
36 | imageio.mimsave(os.path.join(log_dir, 'cell_movie_' + str(k) + '.gif'), v, duration=0.5)
37 |
--------------------------------------------------------------------------------
/modules/cut_out.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | # Copied from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
6 | class Cutout(object):
7 | """Randomly mask out one or more patches from an image.
8 | Args:
9 | n_holes (int): Number of patches to cut out of each image.
10 | length (int): The length (in pixels) of each square patch.
11 | """
12 |
13 | def __init__(self, n_holes, length):
14 | self.n_holes = n_holes
15 | self.length = length
16 |
17 | def __call__(self, img):
18 | """
19 | Args:
20 | img (Tensor): Tensor image of size (C, H, W).
21 | Returns:
22 | Tensor: Image with n_holes of dimension length x length cut out of it.
23 | """
24 | h = img.size(1)
25 | w = img.size(2)
26 |
27 | mask = np.ones((h, w), np.float32)
28 |
29 | for n in range(self.n_holes):
30 | y = np.random.randint(h)
31 | x = np.random.randint(w)
32 |
33 | y1 = np.clip(y - self.length // 2, 0, h)
34 | y2 = np.clip(y + self.length // 2, 0, h)
35 | x1 = np.clip(x - self.length // 2, 0, w)
36 | x2 = np.clip(x + self.length // 2, 0, w)
37 |
38 | mask[y1: y2, x1: x2] = 0.
39 |
40 | mask = torch.from_numpy(mask)
41 | mask = mask.expand_as(img)
42 | img = img * mask
43 |
44 | return img
45 |
--------------------------------------------------------------------------------
/modules/drop_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.cuda
3 | import torch.nn as nn
4 | from random import random
5 | from torch.autograd import Variable
6 |
7 |
8 | class DropModuleControl(object):
9 |     def __init__(self, drop_prob=0.9):
10 |         self.drop_prob = drop_prob  # note: used as the probability of *keeping* the module's output
11 | self.status = False
12 |
13 | def enable(self):
14 | self.status = True
15 |
16 |
17 | class DropModule(nn.Module):
18 | def __init__(self, module, drop_control: DropModuleControl):
19 | super(DropModule, self).__init__()
20 | self.module = module
21 | self.shape = None
22 | self.drop_control = drop_control
23 | self.tensor_init = torch.FloatTensor
24 |
25 | def update_tensor_shape(self, *input):
26 | if self.shape is None:
27 | output_tensor = self.module(*input)
28 | self.shape = output_tensor.size() # fetch tensor shape
29 | if output_tensor.data.is_cuda: self.tensor_init = torch.cuda.FloatTensor
30 |
31 | def forward(self, *input):
32 | self.update_tensor_shape(*input)
33 | if self.training and self.drop_control.status:
34 | if random() <= self.drop_control.drop_prob: # forward module tensor
35 | return self.module(*input) / self.drop_control.drop_prob # Apply scaling
36 | else: # forward zero tensor
37 | return Variable(self.tensor_init(torch.Size([input[0].shape[0], *list(self.shape[1:])])).zero_())
38 | else: # Inference
39 | return self.module(*input)
40 |
--------------------------------------------------------------------------------
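A usage sketch of the drop-path pair (names as defined above): one shared `DropModuleControl` gates every wrapped path, so drop-path can be switched on globally once training reaches `drop_path_start_epoch`:

```python
import torch
import torch.nn as nn
from modules.drop_module import DropModule, DropModuleControl

control = DropModuleControl(drop_prob=0.9)       # keep probability, see note above
path = DropModule(nn.Conv2d(8, 8, 3, padding=1), control)

x = torch.randn(2, 8, 16, 16)
path.train()
y = path(x)        # control disabled: identical to the wrapped conv
control.enable()
y = path(x)        # now zeroed ~10% of the time, scaled by 1/0.9 otherwise
```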
/cnn_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.cuda
3 |
4 |
5 | def evaluate_single(input_individual, input_model, data_loader, device):
6 | correct = 0
7 | total = 0
8 | input_model = input_model.eval()
9 | input_model.set_individual(input_individual)
10 | with torch.no_grad():
11 | for data in data_loader:
12 | images, labels = data
13 | images = images.to(device)
14 | labels = labels.to(device)
15 | outputs = input_model(images)
16 | _, predicted = torch.max(outputs[0].data, 1)
17 | total += labels.size(0)
18 | correct += (predicted == labels).sum().item()
19 | return 100 * correct / total
20 |
21 |
22 | def evaluate_individual_list(input_individual_list, ga, input_model, data_loader, device):
23 | correct = 0
24 | total = 0
25 | input_model = input_model.eval()
26 | i = 0
27 | with torch.no_grad():
28 | while len(input_individual_list) > i:
29 | for data in data_loader:
30 | if len(input_individual_list) <= i:
31 |                     break  # all individuals have been evaluated; skip the remaining batches
32 | else:
33 | ind = input_individual_list[i]
34 | input_model.set_individual(ind)
35 | images, labels = data
36 | images = images.to(device)
37 | labels = labels.to(device)
38 | outputs = input_model(images)
39 | _, predicted = torch.max(outputs[0].data, 1)
40 | total += labels.size(0)
41 | correct += (predicted == labels).sum().item()
42 | acc = 100 * correct / total
43 | ga.update_current_individual_fitness(ind, acc)
44 | i += 1
45 |
--------------------------------------------------------------------------------
/tests/test_plot_augmentation.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | from modules.cut_out import Cutout
6 |
7 | from matplotlib import pyplot as plt
8 | import numpy as np
9 |
10 |
11 | class MyTestCase(unittest.TestCase):
12 | def test_something(self):
13 | train_transform = transforms.Compose([])
14 | # normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
15 | # std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
16 | train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
17 | train_transform.transforms.append(transforms.RandomHorizontalFlip())
18 | train_transform.transforms.append(transforms.ToTensor())
19 | # train_transform.transforms.append(normalize)
20 | train_transform.transforms.append(Cutout(n_holes=1, length=16))
21 |
22 | trainset = torchvision.datasets.CIFAR10(root='./dataset', train=True,
23 | download=True, transform=train_transform)
24 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
25 | shuffle=True, num_workers=4)
26 | ds_train, l = next(iter(trainloader))
27 |
28 |
29 |         fig = plt.figure(figsize=(8, 8))
30 |         columns = 4
31 |         rows = 2
32 |         for i in range(1, columns * rows + 1):
33 |             img_xy = np.random.randint(len(ds_train))
34 |             img = ds_train[img_xy][:, :, :].numpy()
35 |             img = np.transpose(img, [1, 2, 0])
36 | fig.add_subplot(rows, columns, i)
37 | # plt.title(labels_map[int(ds_train[img_xy][1].numpy())])
38 | plt.axis('off')
39 | plt.imshow(img)
40 | plt.show()
41 |
42 |
43 | if __name__ == '__main__':
44 | unittest.main()
45 |
--------------------------------------------------------------------------------
/gnas/search_space/individual.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Individual(object):
5 | def __init__(self, individual_vector, max_inputs, search_space, index=0):
6 | self.iv = individual_vector
7 | self.mi = max_inputs
8 | self.ss = search_space
9 | # Generate config when generating individual
10 | self.index = index
11 | self.config_list = [oc.parse_config(iv) for iv, oc in zip(self.iv, self.ss.get_opeartion_config(self.index))]
12 | self.code = np.concatenate(self.iv, axis=0)
13 |
14 | def get_length(self):
15 | return len(self.code)
16 |
17 | def get_n_op(self):
18 | return len(self.iv)
19 |
20 | def copy(self):
21 | return Individual(self.iv, self.mi, self.ss, index=self.index)
22 |
23 | def generate_node_config(self):
24 | return self.config_list
25 |
26 | def update_individual(self, individual_vector):
27 | return Individual(individual_vector, self.mi, self.ss, index=self.index)
28 |
29 | def __eq__(self, other):
30 | return np.array_equal(self.code, other.code)
31 |
32 | def __str__(self):
33 | return "code:" + str(self.code)
34 |
35 | def __hash__(self):
36 | return hash(str(self))
37 |
38 |
39 | class MultipleBlockIndividual(object):
40 | def __init__(self, individual_list):
41 | self.individual_list = individual_list
42 | self.code = np.concatenate([i.code for i in self.individual_list])
43 |
44 | def get_individual(self, index):
45 | return self.individual_list[index]
46 |
47 | def generate_node_config(self, index):
48 | return self.individual_list[index].generate_node_config()
49 |
50 | def update_individual(self, individual_vector):
51 |         raise NotImplementedError
52 |
53 | def __eq__(self, other):
54 | return np.array_equal(self.code, other.code)
55 |
56 | def __str__(self):
57 | return "code:" + str(self.code)
58 |
59 | def __hash__(self):
60 | return hash(str(self))
61 |
--------------------------------------------------------------------------------
/modules/cosine_annealing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.optim as optim
3 |
4 |
5 | # Copied from: https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py
6 | # The modification list:
7 | # 1. add T_mul to the init function
8 | # 2. change get_lr to restart the schedule, multiplying T_max by T_mul at the end of each cycle.
9 |
10 | class CosineAnnealingLR(optim.lr_scheduler._LRScheduler):
11 | r"""Set the learning rate of each parameter group using a cosine annealing
12 | schedule, where :math:`\eta_{max}` is set to the initial lr and
13 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
14 |
15 | .. math::
16 |
17 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
18 | \cos(\frac{T_{cur}}{T_{max}}\pi))
19 |
20 | When last_epoch=-1, sets initial lr as lr.
21 |
22 |     It has been proposed in
23 |     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Unlike the stock
24 |     PyTorch version, this variant also restarts: T_max is multiplied by T_mul after each cycle.
25 |
26 | Args:
27 | optimizer (Optimizer): Wrapped optimizer.
28 | T_max (int): Maximum number of iterations.
29 | eta_min (float): Minimum learning rate. Default: 0.
30 | last_epoch (int): The index of last epoch. Default: -1.
31 |
32 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
33 | https://arxiv.org/abs/1608.03983
34 | """
35 |
36 | def __init__(self, optimizer, T_max, T_mul, eta_min=0, last_epoch=-1):
37 | self.T_max = T_max
38 | self.T_mul = T_mul
39 | self.eta_min = eta_min
40 | super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
41 |
42 | def get_lr(self):
43 | lr = [self.eta_min + (base_lr - self.eta_min) *
44 | (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
45 | for base_lr in self.base_lrs]
46 | if self.last_epoch != 0 and self.last_epoch % self.T_max == 0:
47 | self.T_max = self.T_mul * self.T_max
48 | self.last_epoch = 0
49 | return lr
50 |
--------------------------------------------------------------------------------
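A minimal sketch of how this scheduler is typically driven (the model and optimizer here are placeholders): with `T_max=10` and `T_mul=2` the first cosine cycle lasts 10 epochs, the next 20, then 40, and so on:

```python
import torch
import torch.optim as optim
from modules.cosine_annealing import CosineAnnealingLR

model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineAnnealingLR(optimizer, T_max=10, T_mul=2, eta_min=0.0001)

for epoch in range(70):        # covers cycles of length 10, 20 and 40 epochs
    optimizer.step()           # training steps would go here
    scheduler.step()
```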
/README.md:
--------------------------------------------------------------------------------
1 | # Genetic Neural Architecture Search (GeneticNAS)
2 | The genetic neural architecture search (GeneticNAS) is a neural architecture search method based on a genetic algorithm that utilizes weight sharing across all candidate networks. The project paper: https://arxiv.org/abs/1907.02871
3 | 
4 | Includes code for CIFAR-10 and CIFAR-100 image classification.
5 |
6 | # Installation
7 | First, install all of the following prerequisites using conda:
8 | * pytorch
9 | * graphviz
10 | * pygraphviz
11 | * numpy
12 |
13 | ```bash
14 | conda install graphviz
15 | conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
16 | conda install pygraphviz
17 | conda install numpy
18 | ```
19 |
20 | # Examples Run Search
21 | This section provides examples of how to run an architecture search on the CIFAR10 and CIFAR100 datasets. At the end of the search, a log folder is created under the current folder.
22 | #### CIFAR 10
23 | ```bash
24 | python main.py --dataset_name CIFAR10 --config_file ./configs/config_cnn_search_cifar10.json
25 | ```
26 | #### CIFAR 100
27 | ```bash
28 | python main.py --dataset_name CIFAR100 --config_file ./configs/config_cnn_search_cifar100.json
29 | ```
30 |
31 | # Examples Run Final Training
32 | This section provides examples of how to run the final training on the CIFAR10 and CIFAR100 datasets, where $LOG_DIR is the log folder of the search result.
33 | #### CIFAR 10
34 | ```bash
35 | python main.py --dataset_name CIFAR10 --final 1 --serach_dir $LOG_DIR --config_file ./configs/config_cnn_final_cifar10.json
36 | ```
37 | #### CIFAR 100
38 | ```bash
39 | python main.py --dataset_name CIFAR100 --final 1 --serach_dir $LOG_DIR --config_file ./configs/config_cnn_final_cifar100.json
40 | ```
41 |
42 | # Result
43 |
44 | ## CIFAR10 Convolution Cell
45 | 
46 |
47 |
48 | ## CIFAR100 Convolution Cell
49 | 
50 |
51 | ## Convolution cell final results
52 | | Dataset | Accuracy[%] |
53 | | --- | --- |
54 | | CIFAR10 | 96% |
55 | | CIFAR100 | 80.1% |
56 |
--------------------------------------------------------------------------------
/gnas/modules/sub_graph_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from gnas.search_space.individual import Individual
5 | from gnas.modules.operation_factory import get_module
6 |
7 |
8 | class SubGraphModule(nn.Module):
9 | def __init__(self, search_space, config_dict, individual_index=0):
10 | super(SubGraphModule, self).__init__()
11 | self.ss = search_space
12 | self.config_dict = config_dict
13 | self.individual_index = individual_index
14 | if self.ss.single_block:
15 | self.block_modules = [get_module(oc, config_dict) for oc in self.ss.ocl]
16 | else:
17 | self.block_modules = [get_module(oc, config_dict) for oc in
18 | self.ss.ocl[individual_index]]
19 | #
20 | [self.add_module('Node' + str(i), n) for i, n in enumerate(self.block_modules)]
21 |
22 | def forward(self, *input_list):
23 | # input list at start is h_n and h_(n-1)
24 | net = list(input_list)
25 | for nm in self.block_modules: # loop over all blocks
26 | net.append(nm(net)) # call each block in the sub graph
27 | return net # output list of all block in the sub graph
28 |
29 | def set_individual(self, individual: Individual):
30 | if not self.ss.single_block:
31 | individual = individual.get_individual(self.individual_index)
32 | si_list = []
33 | for nc, nm in zip(individual.generate_node_config(), self.block_modules):
34 | nm.set_current_node_config(nc)
35 | if nm.__dict__.get('select_index') is not None:
36 | si_list.append(nm.select_index)
37 |
38 | current_node_list = np.unique(si_list)
39 | if self.ss.single_block:
40 | self.avg_index = np.asarray([n.node_id for n in self.ss.ocl if n.node_id not in current_node_list]).astype(
41 | 'int')
42 | else:
43 | self.avg_index = np.asarray(
44 | [n.node_id for n in self.ss.ocl[self.individual_index] if n.node_id not in current_node_list]).astype(
45 | 'int')
46 |
--------------------------------------------------------------------------------
/modules/weight_drop.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.nn import Parameter
4 |
5 | # This module is copied from https://github.com/salesforce/awd-lstm-lm/blob/master/weight_drop.py
6 | class WeightDrop(torch.nn.Module):
7 | def __init__(self, module, weights, dropout=0, variational=False):
8 | super(WeightDrop, self).__init__()
9 | self.module = module
10 | self.weights = weights
11 | self.dropout = dropout
12 | self.variational = variational
13 | self._setup()
14 |
15 | def widget_demagnetizer_y2k_edition(*args, **kwargs):
16 | # We need to replace flatten_parameters with a nothing function
17 | # It must be a function rather than a lambda as otherwise pickling explodes
18 | # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
19 | # (╯°□°)╯︵ ┻━┻
20 | return
21 |
22 | def _setup(self):
23 | # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
24 | if issubclass(type(self.module), torch.nn.RNNBase):
25 | self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition
26 |
27 | for name_w in self.weights:
28 | print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
29 | w = getattr(self.module, name_w)
30 | del self.module._parameters[name_w]
31 | self.module.register_parameter(name_w + '_raw', Parameter(w.data))
32 |
33 | def _setweights(self):
34 | for name_w in self.weights:
35 | raw_w = getattr(self.module, name_w + '_raw')
36 | w = None
37 | if self.variational:
38 | mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
39 | if raw_w.is_cuda: mask = mask.cuda()
40 | mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
41 | w = mask.expand_as(raw_w) * raw_w
42 | else:
43 | w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
44 | setattr(self.module, name_w, w)
45 |
46 | def forward(self, *args):
47 | self._setweights()
48 | return self.module.forward(*args)
49 |
--------------------------------------------------------------------------------
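The canonical usage, mirroring the upstream awd-lstm-lm repository: wrap an RNN and name the weights to drop; the raw tensor is re-registered as `weight_hh_l0_raw` and a dropped copy is installed before every forward pass:

```python
import torch
from modules.weight_drop import WeightDrop

lstm = torch.nn.LSTM(10, 10)
wd_lstm = WeightDrop(lstm, ['weight_hh_l0'], dropout=0.5)

x = torch.randn(5, 3, 10)          # (seq_len, batch, features)
output, (h, c) = wd_lstm(x)        # hidden-to-hidden weights are dropped
```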
/gnas/search_space/search_space.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gnas.search_space.individual import Individual, MultipleBlockIndividual
3 |
4 |
5 | class SearchSpace(object):
6 | def __init__(self, operation_config_list: list, single_block=True):
7 | self.single_block = single_block
8 | self.ocl = operation_config_list
9 | if single_block:
10 | self.n_elements = sum([len(self.generate_vector(o.max_values_vector(i))) for i, o in enumerate(self.ocl)])
11 | else:
12 | self.n_elements = sum(
13 | [sum([len(self.generate_vector(o.max_values_vector(i))) for i, o in enumerate(block)]) for block in
14 | self.ocl])
15 |
16 | def get_operation_configs(self):
17 | return self.ocl
18 |
19 | def get_n_nodes(self):
20 | if self.single_block:
21 | return len(self.ocl)
22 | else:
23 | return [len(ocl) for ocl in self.ocl]
24 |
25 | def get_max_values_vector(self, index=0):
26 | if self.single_block:
27 | return [o.max_values_vector(i) for i, o in enumerate(self.ocl)]
28 | else:
29 | return [o.max_values_vector(i) for i, o in enumerate(self.ocl[index])]
30 |
31 | def get_opeartion_config(self, index=0):
32 | if self.single_block:
33 | return self.ocl
34 | else:
35 | return self.ocl[index]
36 |
37 | def generate_vector(self, max_values):
38 | return np.asarray([np.random.randint(0, mv + 1) for mv in max_values])
39 |
40 | def _generate_individual_single(self, ocl, index=0):
41 | operation_vector = [self.generate_vector(o.max_values_vector(i)) for i, o in enumerate(ocl)]
42 | max_inputs = [i for i, _ in enumerate(ocl)]
43 | return Individual(operation_vector, max_inputs, self, index=index)
44 |
45 | def generate_individual(self):
46 | if self.single_block:
47 | return self._generate_individual_single(self.ocl)
48 | else:
49 | return MultipleBlockIndividual(
50 | [self._generate_individual_single(ocl, index=i) for i, ocl in enumerate(self.ocl)])
51 |
52 | def generate_population(self, size):
53 | return [self.generate_individual() for _ in range(size)]
54 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from common import ModelType
4 |
5 |
6 | def save_config(path_dir, config):
7 | with open(os.path.join(path_dir, 'config.json'), 'w') as outfile:
8 | json.dump(config, outfile)
9 |
10 |
11 | def load_config(path_dir):
12 | with open(path_dir, 'r') as json_file:
13 | data = json.load(json_file)
14 | return data
15 |
16 |
17 | def get_config(model_type):
18 | if ModelType.CNN == model_type:
19 | return default_config_cnn()
20 | elif ModelType.RNN == model_type:
21 | return default_config_rnn()
22 | else:
23 |         raise Exception('unknown model type: ' + str(model_type))
24 |
25 |
26 | def default_config_rnn():
27 | return {'batch_size': 64,
28 | 'batch_size_val': 10,
29 | 'bptt': 35,
30 | 'n_epochs': 310,
31 | 'n_blocks': 2,
32 | 'n_nodes': 12,
33 | 'n_channels': 200,
34 | 'clip': 0.25,
35 | 'generation_size': 20,
36 | 'population_size': 20,
37 | 'keep_size': 0,
38 | 'mutation_p': 0.02,
39 | 'p_cross_over': 1.0,
40 | 'cross_over_type': 'Block',
41 | 'learning_rate': 20.0,
42 | 'weight_decay': 0.0001,
43 | 'dropout': 0.2,
44 | 'LRType': 'ExponentialLR',
45 | 'gamma': 0.96}
46 |
47 |
48 | def default_config_cnn():
49 | return {'batch_size': 128,
50 | 'batch_size_val': 1000,
51 | 'n_epochs': 310,
52 | 'n_blocks': 2,
53 | 'n_block_type': 3,
54 | 'n_nodes': 5,
55 | 'n_channels': 20,
56 | 'generation_size': 20,
57 | 'generation_per_epoch': 2,
58 | 'full_dataset': False,
59 | 'population_size': 20,
60 | 'keep_size': 0,
61 | 'mutation_p': 0.02,
62 | 'p_cross_over': 1.0,
63 | 'cross_over_type': 'Block',
64 | 'learning_rate': 0.1,
65 | 'lr_min': 0.0001,
66 | 'weight_decay': 0.0001,
67 | 'dropout': 0.2,
68 | 'drop_path_keep_prob': 1.0,
69 | 'drop_path_start_epoch': 50,
70 | 'cutout': True,
71 | 'n_holes': 1,
72 | 'length': 16,
73 | 'LRType': 'MultiStepLR',
74 | 'num_class': 10,
75 | 'momentum': 0.9,
76 | 'aux_loss': False,
77 | 'aux_scale': 0.4}
78 |
--------------------------------------------------------------------------------
/gnas/search_space/factory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gnas.search_space.search_space import SearchSpace
3 | from gnas.search_space.operation_space import CnnNodeConfig, RnnNodeConfig, RnnInputNodeConfig
4 | from enum import Enum
5 |
6 | CNN_OP = ['Dw3x3', 'Identity', 'Dw5x5', 'Avg3x3', 'Max3x3']
7 | RNN_OP = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid']
8 |
9 |
10 | class SearchSpaceType(Enum):
11 | CNNSingleCell = 0
12 | CNNDualCell = 1
13 | CNNTripleCell = 2
14 |
15 |
16 | def _two_input_cell(n_nodes, drop_path_control):
17 | node_config_list = [CnnNodeConfig(2, [0, 1], CNN_OP, drop_path_control=drop_path_control)]
18 | for i in range(n_nodes - 1):
19 | node_config_list.append(
20 | CnnNodeConfig(3 + i, list(np.linspace(0, 2 + i, 3 + i).astype('int')), CNN_OP,
21 | drop_path_control=drop_path_control))
22 | return node_config_list
23 |
24 |
25 | def _one_input_cell(n_nodes, drop_path_control):
26 | node_config_list = [CnnNodeConfig(1, [0], CNN_OP, drop_path_control=drop_path_control)]
27 | for i in range(n_nodes - 1):
28 | node_config_list.append(
29 | CnnNodeConfig(2 + i, list(np.linspace(0, 1 + i, 2 + i).astype('int')), CNN_OP,
30 | drop_path_control=drop_path_control))
31 | return node_config_list
32 |
33 |
34 | def get_gnas_cnn_search_space(n_nodes, drop_path_control, n_cell_type: SearchSpaceType) -> SearchSpace:
35 | node_config_list_a = _two_input_cell(n_nodes, drop_path_control)
36 | if n_cell_type == SearchSpaceType.CNNSingleCell:
37 | return SearchSpace(node_config_list_a)
38 | elif n_cell_type == SearchSpaceType.CNNDualCell:
39 | node_config_list_b = _two_input_cell(n_nodes, drop_path_control)
40 | return SearchSpace([node_config_list_a, node_config_list_b], single_block=False)
41 | elif n_cell_type == SearchSpaceType.CNNTripleCell:
42 | node_config_list_b = _two_input_cell(n_nodes, drop_path_control)
43 | node_config_list_c = _one_input_cell(n_nodes, drop_path_control)
44 | return SearchSpace([node_config_list_a, node_config_list_b, node_config_list_c], single_block=False)
45 |
46 |
47 | def get_gnas_rnn_search_space(n_nodes) -> SearchSpace:
48 | node_config_list = [RnnInputNodeConfig(2, [0, 1], RNN_OP)]
49 | for i in range(n_nodes - 1):
50 | node_config_list.append(RnnNodeConfig(3 + i, list(np.linspace(2, 2 + i, 1 + i).astype('int')), RNN_OP))
51 | return SearchSpace(node_config_list)
52 |
--------------------------------------------------------------------------------
/gnas/search_space/cross_over.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gnas.search_space.individual import Individual, MultipleBlockIndividual
3 |
4 | def _individual_uniform_crossover(individual_a: Individual, individual_b: Individual):
5 | n = individual_a.get_length()
6 | selection = np.random.randint(0, 2, n)
7 | i = 0
8 | iv_a = []
9 | iv_b = []
10 | for a, b in zip(individual_a.iv, individual_b.iv):
11 | current_selection = selection[i:i + len(a)]
12 | iv_a.append(a * current_selection + b * (1 - current_selection))
13 | iv_b.append(a * (1 - current_selection) + b * current_selection)
14 | i += len(a)
15 | return Individual(iv_a, individual_a.mi, individual_a.ss, index=individual_a.index), \
16 | Individual(iv_b, individual_b.mi, individual_b.ss, index=individual_b.index)
17 |
18 |
19 | def _individual_block_crossover(individual_a: Individual, individual_b: Individual):
20 | n = individual_a.get_n_op()
21 | selection = np.random.randint(0, 2, n)
22 | iv_a = []
23 | iv_b = []
24 | for i, (a, b) in enumerate(zip(individual_a.iv, individual_b.iv)):
25 | # current_selection = selection[i]
26 | iv_a.append(a * selection[i] + b * (1 - selection[i]))
27 | iv_b.append(a * (1 - selection[i]) + b * selection[i])
28 |
29 | return Individual(iv_a, individual_a.mi, individual_a.ss, index=individual_a.index), \
30 | Individual(iv_b, individual_b.mi, individual_b.ss, index=individual_b.index)
31 |
32 |
33 | def individual_uniform_crossover(individual_a, individual_b, p_c):
34 | if np.random.rand(1) < p_c:
35 | if isinstance(individual_a, Individual):
36 | return _individual_uniform_crossover(individual_a, individual_b)
37 | else:
38 | pairs = [_individual_uniform_crossover(inv_a, inv_b) for inv_a, inv_b in
39 | zip(individual_a.individual_list, individual_b.individual_list)]
40 | return MultipleBlockIndividual([p[0] for p in pairs]), MultipleBlockIndividual([p[1] for p in pairs])
41 | else:
42 | return individual_a, individual_b
43 |
44 |
45 | def individual_block_crossover(individual_a, individual_b, p_c):
46 | if np.random.rand(1) < p_c:
47 | if isinstance(individual_a, Individual):
48 | return _individual_block_crossover(individual_a, individual_b)
49 | else:
50 | pairs = [_individual_block_crossover(inv_a, inv_b) for inv_a, inv_b in
51 | zip(individual_a.individual_list, individual_b.individual_list)]
52 | return MultipleBlockIndividual([p[0] for p in pairs]), MultipleBlockIndividual([p[1] for p in pairs])
53 | else:
54 | return individual_a, individual_b
55 |
--------------------------------------------------------------------------------
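Both crossover variants reduce to the same mask arithmetic; this stand-alone numpy sketch (without the `Individual` plumbing) shows how a 0/1 selection vector swaps genes between two parents:

```python
import numpy as np

a = np.array([1, 2, 3, 4])             # parent A's genes
b = np.array([5, 6, 7, 8])             # parent B's genes
selection = np.random.randint(0, 2, a.shape[0])

child_a = a * selection + b * (1 - selection)
child_b = a * (1 - selection) + b * selection
# child_a takes A's gene where selection == 1 and B's otherwise;
# child_b always receives the complementary gene.
```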
/gnas/genetic_algorithm/population_dict.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from operator import itemgetter
3 | import sys
4 | import copy
5 |
6 |
7 | class PopulationDict(object):
8 | def __init__(self, values_dict: OrderedDict = None, index_dict: OrderedDict = None,
9 | current_index=0):
10 | if values_dict is None:
11 | self.values_dict = OrderedDict({})
12 | else:
13 | self.values_dict = values_dict
14 | if index_dict is None:
15 | self.index_dict = OrderedDict({})
16 | else:
17 | self.index_dict = index_dict
18 | self.i = current_index
19 |
20 | def copy(self):
21 | return PopulationDict(copy.deepcopy(self.values_dict), copy.deepcopy(self.index_dict), self.i)
22 |
23 | def __len__(self):
24 | return len(self.values_dict)
25 |
26 | def __str__(self):
27 | return str(self.values_dict)
28 |
29 | def items(self):
30 | return self.values_dict.items()
31 |
32 | def values(self):
33 | return self.values_dict.values()
34 |
35 | def keys(self):
36 | return self.values_dict.keys()
37 |
38 | def update(self, input_dict: dict):
39 | self.values_dict.update(input_dict)
40 | for k, v in input_dict.items():
41 | self.index_dict.update({k: self.i})
42 | self.i += 1
43 |
44 | def filter_top_n(self, n=sys.maxsize, min_max=True):
45 | values_dict = OrderedDict({})
46 | index_dict = OrderedDict({})
47 | for i, (key, value) in enumerate(sorted(self.values_dict.items(), key=itemgetter(1), reverse=min_max)):
48 | if i < n:
49 | values_dict.update({key: value})
50 | index_dict.update({key: self.index_dict.get(key)})
51 | return PopulationDict(values_dict, index_dict, self.i)
52 |
53 | def filter_last_n(self, n=sys.maxsize):
54 | values_dict = OrderedDict({})
55 | index_dict = OrderedDict({})
56 | for i, (key, index) in enumerate(sorted(self.index_dict.items(), key=itemgetter(1), reverse=True)):
57 | if i < n:
58 | values_dict.update({key: self.values_dict.get(key)})
59 | index_dict.update({key: index})
60 | return PopulationDict(values_dict, index_dict, self.i)
61 |
62 | def merge(self, other):
63 | values_dict = self.values_dict.copy()
64 | index_dict = self.index_dict.copy()
65 | values_dict.update(other.values_dict)
66 | index_dict.update(other.index_dict)
67 | return PopulationDict(values_dict, index_dict, self.i)
68 |
69 | def get_n_diff(self, other):
70 | n = sum([1 for k in other.keys() if k not in list(self.keys())])
71 | return n
72 |
--------------------------------------------------------------------------------
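A usage sketch of `PopulationDict` (plain strings stand in for the hashable `Individual` keys used in the real search):

```python
from gnas.genetic_algorithm.population_dict import PopulationDict

pd = PopulationDict()
pd.update({'ind_a': 91.0, 'ind_b': 95.5})
pd.update({'ind_c': 93.2})

best_two = pd.filter_top_n(2)    # two highest-fitness entries (min_max=True)
newest = pd.filter_last_n(1)     # most recently inserted entry
print(list(best_two.keys()))     # ['ind_b', 'ind_c']
```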
/gnas/modules/cnn_block.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn.parameter import Parameter
5 | from gnas.search_space.individual import Individual
6 | from gnas.modules.sub_graph_module import SubGraphModule
7 | from torch.nn import functional as F
8 | from modules.se_block import SEBlock
9 | from modules.identity import Identity
10 |
11 |
12 | class CnnSearchModule(nn.Module):
13 | def __init__(self, n_channels, ss, individual_index=0, se_block=True):
14 | super(CnnSearchModule, self).__init__()
15 |
16 | self.ss = ss
17 | self.n_channels = n_channels
18 | self.config_dict = {'n_channels': n_channels}
19 | self.sub_graph_module = SubGraphModule(ss, self.config_dict,
20 | individual_index=individual_index)
21 | if ss.single_block:
22 | self.n_inputs = len(ss.ocl[0].inputs)
23 | else:
24 | self.n_inputs = len(ss.ocl[individual_index][0].inputs)
25 | if se_block:
26 | self.se_block = SEBlock(n_channels, 8)
27 | else:
28 | self.se_block = Identity()
29 |
30 | self.bn = nn.BatchNorm2d(n_channels)
31 | self.relu = nn.ReLU()
32 | if self.ss.single_block:
33 | self.weights = [Parameter(torch.Tensor(n_channels, n_channels, 1, 1)) for _ in range(len(ss.ocl))]
34 | else:
35 | self.weights = [Parameter(torch.Tensor(n_channels, n_channels, 1, 1)) for _ in
36 | range(len(ss.ocl[individual_index]))]
37 | [self.register_parameter('w_' + str(i), w) for i, w in enumerate(self.weights)]
38 | self.register_parameter('bias', None)
39 | self.reset_parameters()
40 |
41 | def reset_parameters(self):
42 | n = self.n_channels * len(self.weights)
43 | stdv = 1. / math.sqrt(n)
44 | for w in self.weights:
45 | w.data.uniform_(-stdv, stdv)
46 |
47 | def forward(self, inputs_tensor, bypass_input):
48 | if self.n_inputs == 1:
49 | net = self.sub_graph_module(inputs_tensor)
50 | elif self.n_inputs == 2:
51 | net = self.sub_graph_module(inputs_tensor, bypass_input)
52 |
53 | net = torch.cat([net[i] for i in self.sub_graph_module.avg_index if i > 1], dim=1)
54 | w = torch.cat([self.weights[i - 2] for i in self.sub_graph_module.avg_index if i > 1], dim=1)
55 | net = self.bn(F.conv2d(self.relu(net), w, self.bias, 1, 0, 1, 1))
56 | return self.se_block(net) + inputs_tensor
57 |
58 | def set_individual(self, individual: Individual):
59 | self.sub_graph_module.set_individual(individual)
60 |
61 | def parameters(self):
62 | for name, param in self.named_parameters():
63 | yield param
64 |
--------------------------------------------------------------------------------
/gnas/search_space/operation_space.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gnas.common.bit_utils import vector_bits2int
3 |
4 |
5 | class RnnInputNodeConfig(object):
6 | def __init__(self, node_id, inputs: list, non_linear_list):
7 | self.node_id = node_id
8 | self.inputs = inputs
9 | self.non_linear_list = non_linear_list
10 |
11 | def get_n_bits(self, max_inputs):
12 | return np.log2(len(self.non_linear_list)).astype('int')
13 |
14 | def max_values_vector(self, max_inputs):
15 | return np.ones(self.get_n_bits(0))
16 |
17 | def get_n_inputs(self):
18 | return len(self.inputs)
19 |
20 | def parse_config(self, oc):
21 | return vector_bits2int(oc)
22 |
23 |
24 | class RnnNodeConfig(object):
25 | def __init__(self, node_id, inputs: list, non_linear_list):
26 | self.node_id = node_id
27 | self.inputs = inputs
28 | self.non_linear_list = non_linear_list
29 |
30 | def get_n_bits(self, max_inputs):
31 | return np.log2(len(self.non_linear_list)).astype('int') + (max_inputs > 1)
32 |
33 | def max_values_vector(self, max_inputs):
34 | op_bits = np.ones(self.get_n_bits(0))
35 | if max_inputs > 1:
36 | return np.concatenate([np.asarray(max_inputs - 1).reshape(1), op_bits])
37 | return op_bits
38 |
39 | def get_n_inputs(self):
40 | return len(self.inputs)
41 |
42 | def parse_config(self, oc):
43 | if len(self.inputs) == 1:
44 | return self.inputs[0], 0, vector_bits2int(oc)
45 | else:
46 | return self.inputs[oc[0]], oc[0], vector_bits2int(oc[1:])
47 |
48 |
49 | class CnnNodeConfig(object):
50 | def __init__(self, node_id, inputs: list, op_list, drop_path_control):
51 | self.node_id = node_id
52 | self.inputs = inputs
53 | self.op_list = op_list
54 | self.drop_path_control = drop_path_control
55 |
56 |     def max_values_vector(self, max_inputs):
57 |         max_inputs = len(self.inputs)  # the argument is ignored; CNN nodes derive it from self.inputs
58 | if max_inputs > 1:
59 | return np.asarray([max_inputs - 1, max_inputs - 1, len(self.op_list) - 1, len(self.op_list) - 1])
60 | return np.asarray([len(self.op_list) - 1, len(self.op_list) - 1])
61 |
62 | def get_n_inputs(self):
63 | return len(self.inputs)
64 |
65 | def parse_config(self, oc):
66 | if len(self.inputs) == 1:
67 | op_a = oc[0]
68 | op_b = oc[1]
69 | input_index_a = 0
70 | input_index_b = 0
71 | else:
72 | input_index_a = oc[0]
73 | input_index_b = oc[1]
74 | op_a = oc[2]
75 | op_b = oc[3]
76 | input_a = self.inputs[input_index_a]
77 | input_b = self.inputs[input_index_b]
78 | return input_a, input_b, input_index_a, input_index_b, op_a, op_b
79 |
--------------------------------------------------------------------------------
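To make the genome encoding concrete, here is a decoding sketch for one CNN node (passing `drop_path_control=None` is fine because `parse_config` never touches it): a node with more than one candidate input is described by four genes, `[input_index_a, input_index_b, op_a, op_b]`:

```python
from gnas.search_space.operation_space import CnnNodeConfig

ops = ['Conv3x3', 'Dw3x3', 'Conv5x5', 'Dw5x5']
node = CnnNodeConfig(node_id=3, inputs=[0, 1, 2], op_list=ops,
                     drop_path_control=None)

print(node.max_values_vector(0))       # [2 2 3 3]: upper bound per gene
print(node.parse_config([2, 0, 1, 3]))
# -> (2, 0, 2, 0, 1, 3): chosen input node ids, their indices, and both op indices
```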
/gnas/modules/rnn_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from gnas.search_space.individual import Individual
5 | from gnas.modules.sub_graph_module import SubGraphModule
6 |
7 |
8 | class RnnSearchModule(nn.Module):
9 | def __init__(self, in_channels, n_channels, working_device, ss):
10 | super(RnnSearchModule, self).__init__()
11 |
12 | self.ss = ss
13 | self.in_channels = in_channels
14 | self.n_channels = n_channels
15 | self.working_device = working_device
16 | self.config_dict = {'in_channels': self.in_channels,
17 | 'n_channels': self.n_channels}
18 | self.sub_graph_module = SubGraphModule(ss, self.config_dict)
19 |
20 | self.reset_parameters()
21 |
22 | def forward(self, inputs_tensor, state_tensor):
23 | # input size [Time step,Batch,features]
24 |
25 | state = state_tensor[0, :, :]
26 | outputs = []
27 |
28 | for i in torch.split(inputs_tensor, split_size_or_sections=1, dim=0): # Loop over time steps
29 | output, state = self.cell(i, state)
30 | # state_norm = state.norm(dim=-1)
31 | # max_norm = 25.0
32 | # if torch.any(state_norm > max_norm).item():
33 | # clip_select = state_norm > max_norm
34 | # clip_norms = state_norm[clip_select]
35 | #
36 | # mask = torch.ones(state.size(), device=self.working_device)
37 | # normalizer = max_norm / clip_norms
38 | # mask[clip_select, :] = normalizer.unsqueeze(dim=-1)
39 | # mask = mask.detach()
40 | # state *= mask
41 | # print(np.max(state.norm(dim=-1).detach().cpu().numpy()))
42 | # print("Max Norm pass")
43 | # state = state / state.norm(dim=-1)
44 | outputs.append(output)
45 | output = torch.stack(outputs, dim=0)
46 |
47 | return output, state.unsqueeze(dim=0)
48 |
49 | def cell(self, x, state):
50 | net = self.sub_graph_module(x.squeeze(dim=0), state)
51 |         # the averaged loose-end output also serves as the next hidden state
52 |         output = torch.mean(torch.stack([net[i] for i in self.sub_graph_module.avg_index], dim=-1), dim=-1)
53 |         return output, output
54 |
55 | def set_individual(self, individual: Individual):
56 | self.sub_graph_module.set_individual(individual)
57 |
58 | def init_state(self, batch_size=1): # model init state
59 | weight = next(self.parameters())
60 | return weight.new_zeros(1, batch_size, self.n_channels)
61 |
62 | def parameters(self):
63 | for name, param in self.named_parameters():
64 | yield param
65 |
66 | def reset_parameters(self):
67 | init_range = 0.025
68 | for param in self.parameters():
69 | param.data.uniform_(-init_range, init_range)
70 |
--------------------------------------------------------------------------------
/tests/test_modules.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import gnas
4 | import time
5 | from tests.common4testing import generate_ss, generate_ss_cnn
6 | from gnas.modules.sub_graph_module import SubGraphModule
7 | from modules.drop_module import DropModuleControl
8 |
9 | class TestModules(unittest.TestCase):
10 | def test_sub_graph_build_rnn(self):
11 | ss = generate_ss()
12 | sgm = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128})
13 |
14 | sgm.set_individual(ss.generate_individual())
15 |
16 | def test_run_sub_module(self):
17 | ss = generate_ss()
18 | sgm = SubGraphModule(ss, {'in_channels': 32, 'n_channels': 128})
19 | y = torch.randn(25, 128, dtype=torch.float)
20 | for i in range(100):
21 | sgm.set_individual(ss.generate_individual())
22 | x = torch.randn(25, 32, dtype=torch.float)
23 | y = sgm(x, y)
24 | y = y[-1]
25 |
26 | def test_cnn_sub_module(self):
27 | ss = generate_ss_cnn()
28 | sgm = SubGraphModule(ss, {'n_channels': 64})
29 |
30 | for i in range(100):
31 | sgm.set_individual(ss.generate_individual())
32 | y = torch.randn(32, 64, 16, 16, dtype=torch.float)
33 | x = torch.randn(32, 64, 16, 16, dtype=torch.float)
34 | res = sgm(x, y)
35 |
36 | def test_cnn_module(self):
37 | batch_size = 64
38 | h, w = 16, 16
39 | channels = 64
40 | input = torch.randn(batch_size, channels, h, w, dtype=torch.float)
41 | input_b = torch.randn(batch_size, channels, h, w, dtype=torch.float)
42 | dp_control = DropModuleControl(1)
43 | ss = gnas.get_gnas_cnn_search_space(4, dp_control, gnas.SearchSpaceType.CNNSingleCell)
44 |         cnn = gnas.modules.CnnSearchModule(n_channels=channels,
45 |                                            ss=ss)
46 |         cnn.set_individual(ss.generate_individual())
47 |
48 |         s = time.time()
49 |         output = cnn(input, input_b)
50 | print(time.time() - s)
51 | self.assertTrue(output.shape[0] == batch_size)
52 | self.assertTrue(output.shape[1] == channels)
53 | self.assertTrue(output.shape[2] == h)
54 | self.assertTrue(output.shape[3] == w)
55 |
56 | def test_rnn_module(self):
57 | batch_size = 64
58 | in_channels = 300
59 | out_channels = 128
60 | time_steps = 35
61 | input = torch.randn(time_steps, batch_size, in_channels, dtype=torch.float)
62 |
63 | ss = gnas.get_gnas_rnn_search_space(12)
64 | rnn = gnas.modules.RnnSearchModule(in_channels=in_channels, n_channels=out_channels, working_device='cpu',
65 | ss=ss)
66 | rnn.set_individual(ss.generate_individual())
67 |
68 | state = rnn.init_state(batch_size)
69 | s = time.time()
70 | output, state = rnn(input, state)
71 | print(time.time() - s)
72 | self.assertTrue(output.shape[1] == batch_size)
73 | self.assertTrue(output.shape[0] == time_steps)
74 | self.assertTrue(output.shape[2] == out_channels)
75 |
76 |
77 | if __name__ == '__main__':
78 | unittest.main()
79 |
--------------------------------------------------------------------------------
/gnas/modules/module_generator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from modules.identity import Identity
3 |
4 | __nl_dict__ = {'Tanh': nn.Tanh,
5 |                'ReLU': nn.ReLU,
6 |                'ReLU6': nn.ReLU6,
7 |                'SELU': nn.SELU,
8 |                'LeakyReLU': nn.LeakyReLU,
9 |                'Sigmoid': nn.Sigmoid}
10 |
11 |
12 | def generate_dw_conv(in_channels, out_channels, kernel):
13 |     padding = int((kernel - 1) / 2)
14 |     conv1 = nn.Sequential(nn.ReLU(),
15 |                           nn.Conv2d(in_channels, in_channels, kernel, padding=padding, groups=in_channels, bias=False),
16 |                           nn.BatchNorm2d(in_channels),
17 |                           nn.ReLU(),
18 |                           nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False),
19 |                           nn.BatchNorm2d(out_channels))
20 |     conv2 = nn.Sequential(nn.ReLU(),
21 |                           nn.Conv2d(out_channels, out_channels, kernel, padding=padding, groups=out_channels, bias=False),
22 |                           nn.BatchNorm2d(out_channels),
23 |                           nn.ReLU(),
24 |                           nn.Conv2d(out_channels, out_channels, 1, padding=0, bias=False),
25 |                           nn.BatchNorm2d(out_channels))
26 |     return nn.Sequential(conv1, conv2)
27 |
28 |
29 | def max_pool3x3(in_channels, out_channels):
30 | return nn.Sequential(nn.ReLU(),
31 | nn.MaxPool2d(3, stride=1, padding=1))
32 |
33 |
34 | def avg_pool3x3(in_channels, out_channels):
35 | return nn.Sequential(nn.ReLU(),
36 | nn.AvgPool2d(3, stride=1, padding=1))
37 |
38 |
39 | def conv3x3(in_channels, out_channels):
40 | return nn.Sequential(nn.ReLU(),
41 | nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
42 | nn.BatchNorm2d(out_channels))
43 |
44 |
45 | def dw_conv3x3(in_channels, out_channels):
46 | return generate_dw_conv(in_channels, out_channels, 3)
47 |
48 |
49 | def dw_conv1x3(in_channels, out_channels):
50 |     return nn.Sequential(nn.ReLU(),
51 |                          nn.Conv2d(in_channels, out_channels, (1, 3), padding=(0, 1), groups=in_channels, bias=False),
52 |                          nn.BatchNorm2d(out_channels))
53 |
54 |
55 | def dw_conv3x1(in_channels, out_channels):
56 |     return nn.Sequential(nn.ReLU(),
57 |                          nn.Conv2d(in_channels, out_channels, (3, 1), padding=(1, 0), groups=in_channels, bias=False),
58 |                          nn.BatchNorm2d(out_channels))
59 |
60 |
61 | def conv5x5(in_channels, out_channels):
62 | return nn.Sequential(nn.ReLU(),
63 | nn.Conv2d(in_channels, out_channels, 5, padding=2, bias=False),
64 | nn.BatchNorm2d(out_channels))
65 |
66 |
67 | def dw_conv5x5(in_channels, out_channels):
68 | return generate_dw_conv(in_channels, out_channels, 5)
69 |
70 |
71 | def identity(in_channels, out_channels):
72 | return Identity()
73 |
74 |
75 | __op_dict__ = {'Conv3x3': conv3x3,
76 | 'Dw3x3': dw_conv3x3,
77 | 'Dw3x1': dw_conv3x1,
78 | 'Dw1x3': dw_conv1x3,
79 | 'Conv5x5': conv5x5,
80 | 'Dw5x5': dw_conv5x5,
81 | 'Identity': identity,
82 | 'Max3x3': max_pool3x3,
83 | 'Avg3x3': avg_pool3x3, }
84 |
85 |
86 | def generate_non_linear(non_linear_list):
87 | return [__nl_dict__.get(nl)() for nl in non_linear_list]
88 |
89 |
90 | def generate_op(op_list, in_channels, out_channels):
91 | return [__op_dict__.get(nl)(in_channels, out_channels) for nl in op_list]
92 |
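93 |
94 | # Usage sketch (an illustrative check, not part of the library): build every candidate
95 | # op for a 64-channel node and verify that each one preserves the spatial map size,
96 | # which is what lets the search space mix them freely inside a cell.
97 | if __name__ == '__main__':
98 |     import torch
99 |     ops = generate_op(list(__op_dict__.keys()), 64, 64)
100 |     x = torch.randn(2, 64, 16, 16)
101 |     for name, op in zip(__op_dict__.keys(), ops):
102 |         assert op(x).shape == x.shape, name  # every candidate op is shape-preserving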
--------------------------------------------------------------------------------
/gnas/common/graph_draw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from gnas.search_space.individual import Individual, MultipleBlockIndividual
5 | from gnas.search_space.operation_space import RnnNodeConfig, RnnInputNodeConfig, CnnNodeConfig
6 |
7 |
8 | def add_node(graph, node_id, label, shape='box', style='filled'):
9 | if label.startswith('Input') or label.startswith('x'):
10 | color = 'skyblue'
11 | elif label.startswith('Output'):
12 | color = 'pink'
13 | elif 'Tanh' in label or 'Add' in label or 'Concat' in label:
14 | color = 'yellow'
15 | elif 'ReLU' in label or 'Dw' in label or 'Conv' in label:
16 | color = 'orange'
17 | elif 'Sigmoid' in label or 'Identity' in label:
18 | color = 'greenyellow'
19 | elif label == 'avg':
20 | color = 'seagreen3'
21 | else:
22 | color = 'white'
23 |
24 | if not any(label.startswith(word) for word in ['x', 'avg', 'h']):
25 | label = f"{label}\n({node_id})"
26 |
27 | graph.add_node(
28 | node_id, label=label, color='black', fillcolor=color,
29 | shape=shape, style=style,
30 | )
31 |
32 |
33 | def _draw_individual(ocl, individual, path=None):
34 | import pygraphviz as pgv
35 |     graph = pgv.AGraph(directed=True, layout='dot')  # the 'dot' layout is actually applied below via graph.layout(prog='dot')
36 |
37 | ofset = len(ocl[0].inputs)
38 | for i in range(len(ocl[0].inputs)):
39 | add_node(graph, i, 'x[' + str(i) + ']')
40 |
41 | input_list = []
42 |
43 | for i, (oc, op) in enumerate(zip(individual.generate_node_config(), ocl)):
44 | if isinstance(op, CnnNodeConfig):
45 | input_a = oc[0]
46 | input_b = oc[1]
47 | input_list.append(input_a)
48 | input_list.append(input_b)
49 | op_a = oc[4]
50 | op_b = oc[5]
51 | add_node(graph, (i + ofset) * 10, ocl[i].op_list[op_a])
52 | add_node(graph, (i + ofset) * 10 + 1, ocl[i].op_list[op_b])
53 | graph.add_edge(input_a, (i + ofset) * 10)
54 | graph.add_edge(input_b, (i + ofset) * 10 + 1)
55 | add_node(graph, (i + ofset), 'Add')
56 | graph.add_edge((i + ofset) * 10, (i + ofset))
57 | graph.add_edge((i + ofset) * 10 + 1, (i + ofset))
58 | c = graph.add_subgraph([(i + ofset) * 10, (i + ofset) * 10 + 1, (i + ofset)],
59 | name='cluster_block:' + str(i), label='Block ' + str(i))
60 | # c.attr(label='block:'+str(i))
61 |
62 | elif isinstance(op, RnnInputNodeConfig):
63 | op_type = op.non_linear_list[oc]
64 | add_node(graph, (i + ofset), op_type)
65 | graph.add_edge(0, (i + ofset))
66 | graph.add_edge(1, (i + ofset))
67 | input_list.append(0)
68 | input_list.append(1)
69 |
70 | elif isinstance(op, RnnNodeConfig):
71 | op_type = op.non_linear_list[oc[-1]]
72 | add_node(graph, (i + ofset), op_type)
73 | graph.add_edge(oc[0], (i + ofset))
74 | input_list.append(oc[0])
75 | else:
76 |             raise Exception('unknown node type: ' + str(type(op)))
77 | input_list = np.unique(input_list)
78 | op_inputs = [int(i) for i in np.linspace(ofset, ofset + individual.get_n_op() - 1, individual.get_n_op()) if
79 | i not in input_list]
80 | concat_node = i + 1 + ofset
81 | add_node(graph, concat_node, 'Concat')
82 | for i in op_inputs:
83 | graph.add_edge(i, concat_node)
84 | graph.layout(prog='dot')
85 | if path is not None:
86 | graph.draw(path + '.png')
87 |
88 |
89 |
90 | def draw_cell(ocl, individual):
91 | _draw_individual(ocl, individual, path=None)
92 |
93 |
94 | def draw_network(ss, individual, path):
95 | os.makedirs(os.path.dirname(path), exist_ok=True)
96 | if isinstance(individual, Individual):
97 | _draw_individual(ss.ocl, individual, path)
98 | elif isinstance(individual, MultipleBlockIndividual):
99 | [_draw_individual(ocl, inv, path + str(i)) for i, (inv, ocl) in
100 | enumerate(zip(individual.individual_list, ss.ocl))]
101 |
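102 |
103 | # Usage sketch (assumes pygraphviz is installed; the names below mirror the tests):
104 | #   ss = gnas.get_gnas_rnn_search_space(12)
105 | #   draw_network(ss, ss.generate_individual(), './logs/graph')  # writes ./logs/graph.png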
--------------------------------------------------------------------------------
/models/model_cnn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import gnas
3 | import torch
4 |
5 |
6 | class RepeatBlock(nn.Module):
7 | def __init__(self, n_blocks, n_channels, ss, individual_index=0, first_block=None):
8 | super(RepeatBlock, self).__init__()
9 | if first_block is None: first_block = individual_index
10 | self.block_list = [gnas.modules.CnnSearchModule(n_channels, ss,
11 | individual_index=first_block if i == 0 else individual_index)
12 | for i in
13 | range(n_blocks)]
14 | [self.add_module('block_' + str(i), n) for i, n in enumerate(self.block_list)]
15 |
16 | def forward(self, x, x_prev):
17 | for b in self.block_list:
18 | x_new = b(x, x_prev)
19 | x_prev = x
20 | x = x_new
21 | return x, x_prev
22 |
23 | def set_individual(self, individual):
24 | [b.set_individual(individual) for b in self.block_list]
25 |
26 |
27 | class Net(nn.Module):
28 | def __init__(self, n_blocks, n_channels, n_classes, dropout, ss, aux=False):
29 | n_block_types = len(ss.ocl)
30 | normal_block_index = 0
31 | reduce_block_index = 0
32 | first_block_index = 0
33 | if n_block_types >= 2:
34 | normal_block_index = 1
35 | first_block_index = 1
36 | if n_block_types == 3: first_block_index = 2
37 | super(Net, self).__init__()
38 | self.conv1 = nn.Conv2d(3, n_channels, 3, stride=1, padding=1, bias=False)
39 | self.bn1 = nn.BatchNorm2d(n_channels)
40 |
41 | self.block_1 = RepeatBlock(n_blocks, n_channels, ss,
42 | individual_index=normal_block_index, first_block=first_block_index)
43 |
44 | self.avg = nn.AvgPool2d(2)
45 |         self.conv2 = nn.Conv2d(n_channels, 2 * n_channels, 1, stride=1, padding=0, bias=False)
46 |         self.bn2 = nn.BatchNorm2d(2 * n_channels)
47 |
48 |         self.conv2_prev = nn.Conv2d(n_channels, 2 * n_channels, 1, stride=1, padding=0, bias=False)
49 | self.bn2_prev = nn.BatchNorm2d(2 * n_channels)
50 | # self
51 |
52 | self.block_2_reduce = gnas.modules.CnnSearchModule(2 * n_channels, ss, individual_index=reduce_block_index)
53 | self.block_2 = RepeatBlock(n_blocks, 2 * n_channels, ss,
54 | individual_index=normal_block_index)
55 |
56 |         self.conv3 = nn.Conv2d(2 * n_channels, 4 * n_channels, 1, stride=1, padding=0, bias=False)
57 |         self.bn3 = nn.BatchNorm2d(4 * n_channels)
58 |
59 |         self.conv3_prev = nn.Conv2d(2 * n_channels, 4 * n_channels, 1, stride=1, padding=0, bias=False)
60 | self.bn3_prev = nn.BatchNorm2d(4 * n_channels)
61 |
62 | self.block_3_reduce = gnas.modules.CnnSearchModule(4 * n_channels, ss, individual_index=reduce_block_index)
63 | self.block_3 = RepeatBlock(n_blocks, 4 * n_channels, ss,
64 | individual_index=normal_block_index)
65 |
66 | self.relu = nn.ReLU()
67 | self.dp = nn.Dropout(p=dropout)
68 | self.fc1 = nn.Sequential(nn.ReLU(),
69 | nn.Linear(4 * n_channels, n_classes))
70 | self.aux = aux
71 | if aux:
72 | self.fc2 = nn.Sequential(nn.ReLU(),
73 | nn.Linear(2 * n_channels, n_classes))
74 | self.reset_param()
75 |
76 | def reset_param(self):
77 | for p in self.parameters():
78 | if len(p.shape) == 4:
79 | nn.init.kaiming_normal_(p)
80 |
81 | def forward(self, x):
82 | x_prev = self.bn1(self.conv1(x))
83 |
84 | x, x_prev = self.block_1(x_prev, x_prev)
85 |
86 | # reduce dim
87 | x = self.bn2(self.conv2(self.avg(x)))
88 | x_prev = self.bn2_prev(self.conv2_prev(self.avg(x_prev)))
89 |
90 | x, x_prev = self.block_2(self.block_2_reduce(x, x_prev), x)
91 |
92 |
93 | if self.aux: x2 = torch.mean(torch.mean(x, dim=-1), dim=-1)
94 |
95 |
96 | x = self.bn3(self.conv3(self.avg(x)))
97 | x_prev = self.bn3_prev(self.conv3_prev(self.avg(x_prev)))
98 |
99 | x, x_prev = self.block_3(self.block_3_reduce(x, x_prev), x)
100 |
101 |
102 | # Global pooling
103 | x = torch.mean(torch.mean(x, dim=-1), dim=-1)
104 | if self.aux:
105 | return [self.fc1(self.dp(x)), self.fc2(self.dp(x2))]
106 | else:
107 | return [self.fc1(self.dp(x))]
108 |
109 | def set_individual(self, individual):
110 | self.block_1.set_individual(individual)
111 | self.block_2.set_individual(individual)
112 | self.block_2_reduce.set_individual(individual)
113 | self.block_3_reduce.set_individual(individual)
114 | self.block_3.set_individual(individual)
115 |
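116 |
117 | # Shape-check sketch (illustrative, run manually; the small sizes here are assumptions,
118 | # not the training configuration):
119 | if __name__ == '__main__':
120 |     from modules.drop_module import DropModuleControl
121 |     dp_control = DropModuleControl(1)
122 |     ss = gnas.get_gnas_cnn_search_space(4, dp_control, gnas.SearchSpaceType.CNNSingleCell)
123 |     net = Net(n_blocks=1, n_channels=16, n_classes=10, dropout=0.0, ss=ss)
124 |     net.set_individual(ss.generate_individual())
125 |     print([o.shape for o in net(torch.randn(2, 3, 32, 32))])  # -> [torch.Size([2, 10])]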
--------------------------------------------------------------------------------
/gnas/modules/node_module.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from gnas.modules.module_generator import generate_non_linear, generate_op
3 | from modules.weight_drop import WeightDrop
4 | from modules.drop_module import DropModule
5 |
6 |
7 | class RnnInputNodeModule(nn.Module):
8 | def __init__(self, node_config, config_dict):
9 | super(RnnInputNodeModule, self).__init__()
10 |         if node_config.get_n_inputs() != 2: raise Exception('RnnInputNodeModule expects exactly 2 inputs (x and h)')
11 | self.nc = node_config
12 | dropout = 0.5 # TODO:change to input config
13 | self.in_channels = config_dict.get('in_channels')
14 | self.n_channels = config_dict.get('n_channels')
15 | self.nl_module = generate_non_linear(self.nc.non_linear_list)
16 |
17 | self.x_linear_list = [nn.Linear(self.in_channels, self.n_channels) for _ in range(2)]
18 | [self.add_module('c_linear' + str(i), m) for i, m in enumerate(self.x_linear_list)]
19 | self.h_linear_list = [WeightDrop(nn.Linear(self.n_channels, self.n_channels), ['weight'], dropout, True) for _
20 | in
21 | range(2)]
22 | [self.add_module('h_linear' + str(i), m) for i, m in enumerate(self.h_linear_list)]
23 | self.sigmoid = nn.Sigmoid()
24 |
25 | self.non_linear = None
26 | self.node_config = None
27 |
28 | def forward(self, inputs):
29 | c = self.sigmoid(self.x_linear_list[0](inputs[0]) + self.h_linear_list[0](inputs[1]))
30 | output = c * self.non_linear(self.x_linear_list[1](inputs[0]) + self.h_linear_list[1](inputs[1])) + (1 - c) * \
31 | inputs[1]
32 | return output
33 |
34 | def set_current_node_config(self, current_config):
35 | nl_index = current_config
36 | self.cc = current_config
37 | self.non_linear = self.nl_module[nl_index]
38 |
39 |
40 | class RnnNodeModule(nn.Module):
41 | def __init__(self, node_config, config_dict):
42 | super(RnnNodeModule, self).__init__()
43 | self.nc = node_config
44 |         if node_config.get_n_inputs() < 1: raise Exception('RnnNodeModule expects at least one input')
45 |
46 | self.n_channels = config_dict.get('n_channels')
47 | self.nl_module = generate_non_linear(self.nc.non_linear_list)
48 | # self.bn = nn.BatchNorm1d(self.n_channels)
49 |
50 | self.x_linear_list = [nn.Linear(self.n_channels, self.n_channels) for _ in range(node_config.get_n_inputs())]
51 | [self.add_module('c_linear' + str(i), m) for i, m in enumerate(self.x_linear_list)]
52 | self.h_linear_list = [nn.Linear(self.n_channels, self.n_channels) for _ in range(node_config.get_n_inputs())]
53 | [self.add_module('h_linear' + str(i), m) for i, m in enumerate(self.h_linear_list)]
54 | self.sigmoid = nn.Sigmoid()
55 | # self.bn = nn.BatchNorm1d(self.n_channels)
56 | self.non_linear = None
57 | self.node_config = None
58 |
59 | def forward(self, inputs):
60 | x = inputs[self.select_index]
61 | c = self.sigmoid(self.x_linear(x))
62 | return c * self.non_linear(self.h_linear(x)) + (1 - c) * x
63 |
64 | def set_current_node_config(self, current_config):
65 | self.select_index, op_index, nl_index = current_config
66 | self.cc = current_config
67 | self.non_linear = self.nl_module[nl_index]
68 | self.x_linear = self.x_linear_list[op_index]
69 | self.h_linear = self.h_linear_list[op_index]
70 |
71 |
72 | class ConvNodeModule(nn.Module):
73 | def __init__(self, node_config, config_dict):
74 | super(ConvNodeModule, self).__init__()
75 | self.nc = node_config
76 |
77 | self.n_channels = config_dict.get('n_channels')
78 | self.conv_module = []
79 | for j in range(node_config.get_n_inputs()):
80 | op_list = [DropModule(op, node_config.drop_path_control) for op in
81 | generate_op(self.nc.op_list, self.n_channels, self.n_channels)]
82 | self.conv_module.append(op_list)
83 | [self.add_module('conv_op_' + str(i) + '_in_' + str(j), m) for i, m in enumerate(self.conv_module[-1])]
84 |
85 | self.non_linear_a = None
86 | self.non_linear_b = None
87 | self.input_a = None
88 | self.input_b = None
89 | self.cc = None
90 | self.op_a = None
91 | self.op_b = None
92 |
93 | def forward(self, inputs):
94 | net_a = inputs[self.input_a]
95 | net_b = inputs[self.input_b]
96 | return self.op_a(net_a) + self.op_b(net_b)
97 |
98 | def set_current_node_config(self, current_config):
99 | input_a, input_b, input_index_a, input_index_b, op_a, op_b = current_config
100 | self.select_index = [input_a, input_b]
101 | self.cc = current_config
102 | self.input_a = input_a
103 | self.input_b = input_b
104 | self.op_a = self.conv_module[input_index_a][op_a]
105 | self.op_b = self.conv_module[input_index_b][op_b]
106 |         # Freeze every candidate op first,
107 |         for p in self.parameters():
108 |             p.requires_grad = False
109 |         # then re-enable gradients only for the two ops selected by the current config.
110 | for p in self.op_b.parameters():
111 | p.requires_grad = True
112 | for p in self.op_a.parameters():
113 | p.requires_grad = True
114 |
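115 |
116 | # Configuration sketch (illustrative values, not taken from a real search):
117 | # a CNN node config is the 6-tuple (input_a, input_b, input_index_a, input_index_b, op_a, op_b).
118 | # For example, (0, 1, 0, 1, 2, 0) routes input 0 through candidate op 2 and input 1
119 | # through candidate op 0; forward() then sums the two branches.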
--------------------------------------------------------------------------------
/tests/test_search_space.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import os
4 | import inspect
5 |
6 | from gnas.search_space.operation_space import RnnInputNodeConfig, RnnNodeConfig
7 | from gnas.search_space.search_space import SearchSpace
8 | from gnas.search_space.individual import Individual
9 | from gnas.search_space.cross_over import individual_uniform_crossover
10 | from gnas.common.graph_draw import draw_network
11 | from gnas.search_space.mutation import individual_flip_mutation
12 | import gnas
13 |
14 |
15 | class TestSearchSpace(unittest.TestCase):
16 | @staticmethod
17 | def generate_block():
18 | nll = ['Tanh', 'ReLU', 'ReLU6', 'Sigmoid']
19 | node_config_list = [RnnInputNodeConfig(0, [], nll)]
20 | for i in range(12):
21 | node_config_list.append(RnnNodeConfig(i + 1, [0], nll))
22 | return node_config_list
23 |
24 | @staticmethod
25 | def generate_ss():
26 |
27 | ss = SearchSpace(TestSearchSpace.generate_block())
28 | return ss
29 |
30 | @staticmethod
31 | def generate_ss_multiple_blocks():
32 | block_list = [TestSearchSpace.generate_block(), TestSearchSpace.generate_block()]
33 | ss = SearchSpace(block_list, single_block=False)
34 | return ss
35 |
36 | def test_basic(self):
37 | ss = self.generate_ss()
38 | individual = ss.generate_individual()
39 | self._test_individual(individual, ss.get_n_nodes())
40 |
41 | def test_individual(self):
42 | ss = self.generate_ss()
43 | individual_a = ss.generate_individual()
44 | individual_b = ss.generate_individual()
45 | individual_a_tag = individual_a.copy()
46 | dict2test = dict()
47 | self.assertFalse(individual_a == individual_b)
48 | self.assertTrue(individual_a_tag == individual_a)
49 | dict2test.update({individual_a: 40})
50 | dict2test.update({individual_b: 80})
51 | dict2test.update({individual_a_tag: 90})
52 | self.assertTrue(dict2test.get(individual_a) == 90)
53 | self.assertTrue(dict2test.get(individual_a_tag) == 90)
54 | self.assertTrue(dict2test.get(individual_b) == 80)
55 | self.assertTrue(len(dict2test) == 2)
56 | # res_dict
57 |
58 | def test_basic_multiple(self):
59 | ss = self.generate_ss_multiple_blocks()
60 | individual = ss.generate_individual()
61 | self._test_individual(individual, ss.get_n_nodes())
62 |
63 | def test_mutation(self):
64 | ss = self.generate_ss()
65 | for i in range(100):
66 | individual_a = ss.generate_individual()
67 | individual_c = individual_flip_mutation(individual_a, 1 / 10)
68 | ce = 0
69 | te = 0
70 | for a, c in zip(individual_a.iv, individual_c.iv):
71 | for ia, ic in zip(a, c):
72 | te += 1
73 | ce += ia != ic
74 |
75 | self._test_individual(individual_c, ss.get_n_nodes())
76 | self.assertTrue(ce != te)
77 |
78 | def test_cross_over(self):
79 | ss = self.generate_ss()
80 | ca = 0
81 | cb = 0
82 | cc = 0
83 | for i in range(100):
84 | individual_a = ss.generate_individual()
85 | individual_b = ss.generate_individual()
86 | individual_c, individual_d = individual_uniform_crossover(individual_a, individual_b, 1)
87 |
88 | for a, b, c in zip(individual_a.iv, individual_b.iv, individual_c.iv):
89 | for ia, ib, ic in zip(a, b, c):
90 | cc += 1
91 | ca += ia == ic
92 | cb += ib == ic
93 | self.assertTrue(ia == ic or ib == ic)
94 |
95 | self._test_individual(individual_c, ss.get_n_nodes())
96 | self.assertTrue(cc != cb)
97 | self.assertTrue(cc != ca)
98 |
99 | def test_plot_individual_rnn(self):
100 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
101 |
102 | ss = gnas.get_gnas_rnn_search_space(12)
103 | ind = ss.generate_individual()
104 | draw_network(ss, ind, os.path.join(current_path, 'graph'))
105 |
106 | def test_plot_individual(self):
107 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
108 |
109 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNSingleCell)
110 | ind = ss.generate_individual()
111 | draw_network(ss, ind, os.path.join(current_path, 'graph'))
112 |
113 | def test_plot_individual_dual(self):
114 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
115 |
116 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNDualCell)
117 | ind = ss.generate_individual()
118 | draw_network(ss, ind, os.path.join(current_path, 'graph'))
119 |
120 | def test_plot_individual_triple(self):
121 | current_path = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
122 |
123 | ss = gnas.get_gnas_cnn_search_space(5, 1, gnas.SearchSpaceType.CNNTripleCell)
124 | ind = ss.generate_individual()
125 | draw_network(ss, ind, os.path.join(current_path, 'graph'))
126 |
127 | def _test_individual(self, individual, n_nodes):
128 | individual_flip_mutation(individual, 0.2)
129 | if isinstance(individual, Individual):
130 | self.assertTrue(len(individual.iv) == n_nodes)
131 | for c in individual.iv:
132 | if len(c) == 2:
133 | self.assertFalse(np.any(c > 1))
134 | else:
135 | self.assertFalse(np.any(c[1:] > 1))
136 | individual.generate_node_config()
137 | else:
138 | [self.assertTrue(len(ind.iv) == n_nodes[i]) for i, ind in enumerate(individual.individual_list)]
139 | individual.generate_node_config(0)
140 |
141 |
142 | if __name__ == '__main__':
143 | unittest.main()
144 |
--------------------------------------------------------------------------------
/models/model_rnn.py:
--------------------------------------------------------------------------------
1 | import gnas
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 |
7 |
8 | class EmbeddingDropout(torch.nn.Embedding):
9 | """Class for dropping out embeddings by zero'ing out parameters in the
10 | embedding matrix.
11 |
12 | This is equivalent to dropping out particular words, e.g., in the sentence
13 | 'the quick brown fox jumps over the lazy dog', dropping out 'the' would
14 | lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the
15 | embedding vector space).
16 |
17 | See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
18 | Networks', (Gal and Ghahramani, 2016).
19 | """
20 |
21 | def __init__(self,
22 | num_embeddings,
23 | embedding_dim,
24 | max_norm=None,
25 | norm_type=2,
26 | scale_grad_by_freq=False,
27 | sparse=False,
28 | dropout=0.1,
29 | scale=None):
30 | """Embedding constructor.
31 |
32 | Args:
33 | dropout: Dropout probability.
34 | scale: Used to scale parameters of embedding weight matrix that are
35 | not dropped out. Note that this is _in addition_ to the
36 | `1/(1 - dropout)` scaling.
37 |
38 | See `torch.nn.Embedding` for remaining arguments.
39 | """
40 | torch.nn.Embedding.__init__(self,
41 | num_embeddings=num_embeddings,
42 | embedding_dim=embedding_dim,
43 | max_norm=max_norm,
44 | norm_type=norm_type,
45 | scale_grad_by_freq=scale_grad_by_freq,
46 | sparse=sparse)
47 | self.dropout = dropout
48 | assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 '
49 | 'and < 1.0')
50 | self.scale = scale
51 |
52 | def forward(self, inputs): # pylint:disable=arguments-differ
53 | """Embeds `inputs` with the dropped out embedding weight matrix."""
54 | if self.training:
55 | dropout = self.dropout
56 | else:
57 | dropout = 0
58 |
59 | if dropout:
60 | mask = self.weight.data.new(self.weight.size(0), 1)
61 | mask.bernoulli_(1 - dropout)
62 | mask = mask.expand_as(self.weight)
63 | mask = mask / (1 - dropout)
64 | masked_weight = self.weight * Variable(mask)
65 | else:
66 | masked_weight = self.weight
67 | if self.scale and self.scale != 1:
68 | masked_weight = masked_weight * self.scale
69 |
70 | return F.embedding(inputs,
71 | masked_weight,
72 | max_norm=self.max_norm,
73 | norm_type=self.norm_type,
74 | scale_grad_by_freq=self.scale_grad_by_freq,
75 | sparse=self.sparse)
76 |
77 |
78 | class LockedDropout(nn.Module):
79 | # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
80 | def __init__(self, dropout):
81 | super().__init__()
82 | self.dropout = dropout
83 |
84 | def forward(self, x):
85 |         # The input has shape [T, N, F]: T time steps, batch size N and F features.
86 |         # The dropout mask is sampled once per forward pass and shared across the T axis,
87 |         # so the same features are dropped at every time step ("locked" dropout).
88 | if not self.training or not self.dropout:
89 | return x
90 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - self.dropout)
91 | mask = Variable(m, requires_grad=False) / (1 - self.dropout)
92 | mask = mask.expand_as(x)
93 | return mask * x
94 |
95 |
96 | class RNNModel(nn.Module):
97 | """Container module with an encoder, a recurrent module, and a decoder."""
98 |
99 | def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False, ss=None):
100 | super(RNNModel, self).__init__()
101 | self.drop_input = LockedDropout(0.65) # TODO:change to input config
102 | self.drop_end = LockedDropout(0.4) # TODO:change to input config
103 | self.encoder = EmbeddingDropout(ntoken, ninp, dropout=0.1) # TODO:change to input config
104 | self.ss = ss
105 | self.rnn = gnas.modules.RnnSearchModule(in_channels=ninp, n_channels=nhid,
106 | working_device='cuda',
107 | ss=self.ss)
108 | self.decoder = nn.Linear(nhid, ntoken)
109 |
110 | # Optionally tie weights as in:
111 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
112 | # https://arxiv.org/abs/1608.05859
113 | # and
114 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
115 | # https://arxiv.org/abs/1611.01462
116 | if tie_weights:
117 | if nhid != ninp:
118 | raise ValueError('When using the tied flag, nhid must be equal to emsize')
119 | self.decoder.weight = self.encoder.weight
120 |
121 | self.init_weights()
122 |
123 | self.nhid = nhid
124 | self.nlayers = nlayers
125 |
126 | def set_individual(self, individual):
127 | self.rnn.set_individual(individual)
128 |
129 | def init_weights(self):
130 | initrange = 0.1
131 | self.encoder.weight.data.uniform_(-initrange, initrange)
132 | self.decoder.bias.data.zero_()
133 | self.decoder.weight.data.uniform_(-initrange, initrange)
134 |
135 | def forward(self, input, hidden):
136 | emb = self.drop_input(self.encoder(input))
137 | output, hidden = self.rnn(emb, hidden)
138 | output = self.drop_end(output)
139 | decoded = self.decoder(output.contiguous().view(output.size(0) * output.size(1), output.size(2)))
140 | return decoded.contiguous().view(output.size(0), output.size(1), decoded.size(1)), hidden
141 |
142 | def init_hidden(self, bsz):
143 | return self.rnn.init_state(bsz)
144 |
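145 |
146 | # LockedDropout sketch (toy shapes chosen for illustration only): because the mask is
147 | # shared across time, every time step drops the same features.
148 | if __name__ == '__main__':
149 |     ld = LockedDropout(0.5)
150 |     ld.train()
151 |     x = torch.ones(3, 2, 4)  # [T, N, F]
152 |     y = ld(x)
153 |     print((y[0] == y[1]).all().item())  # True: the mask is locked across T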
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import torchvision.transforms as transforms
5 | from modules.cut_out import Cutout
6 |
7 |
8 | def get_dataset(config):
9 | dataset_name = config.get('dataset_name')
10 | data_path = config.get('data_path')
11 | if dataset_name == 'CIFAR10':
12 | return get_cifar(config, os.path.join(data_path, 'CIFAR10'))
13 | elif dataset_name == 'CIFAR100':
14 | return get_cifar(config, os.path.join(data_path, 'CIFAR100'), dataset_name='CIFAR100')
15 | elif dataset_name == 'PTB':
16 | corpus = Corpus(os.path.join(data_path, 'ptb'))
17 | batch_size_train = config.get('batch_size')
18 | batch_size_val = config.get('batch_size_val')
19 | device = config.get('working_device')
20 | # train_data, val_data, test_data = corpus.batchify(config.get('batch_size'), config.get('working_device'))
21 | return corpus.single_batchify(corpus.train, batch_size_train, device), corpus.single_batchify(corpus.valid,
22 | batch_size_val,
23 | device), len(
24 | corpus.dictionary)
25 | else:
26 |         raise Exception('unknown dataset type: ' + str(dataset_name))
27 |
28 |
29 | def get_cifar(config, data_path, dataset_name='CIFAR10'):
30 | train_transform = transforms.Compose([])
31 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
32 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
33 | train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
34 | train_transform.transforms.append(transforms.RandomHorizontalFlip())
35 | train_transform.transforms.append(transforms.ToTensor())
36 | train_transform.transforms.append(normalize)
37 | if config.get('cutout'):
38 | train_transform.transforms.append(Cutout(n_holes=config.get('n_holes'), length=config.get('length')))
39 |
40 | transform = transforms.Compose([
41 | transforms.ToTensor(),
42 | normalize])
43 | trainloader, testloader, n_class = None, None, None
44 | if dataset_name == 'CIFAR10':
45 | trainset = torchvision.datasets.CIFAR10(root=data_path, train=True,
46 | download=True, transform=train_transform)
47 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.get('batch_size'),
48 | shuffle=True, num_workers=4)
49 |
50 | testset = torchvision.datasets.CIFAR10(root=data_path, train=False,
51 | download=True, transform=transform)
52 | testloader = torch.utils.data.DataLoader(testset, batch_size=config.get('batch_size_val'),
53 | shuffle=False, num_workers=4)
54 | n_class = 10
55 | elif dataset_name == 'CIFAR100':
56 | trainset = torchvision.datasets.CIFAR100(root=data_path, train=True,
57 | download=True, transform=train_transform)
58 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.get('batch_size'),
59 | shuffle=True, num_workers=4)
60 |
61 | testset = torchvision.datasets.CIFAR100(root=data_path, train=False,
62 | download=True, transform=transform)
63 | testloader = torch.utils.data.DataLoader(testset, batch_size=config.get('batch_size_val'),
64 | shuffle=False, num_workers=4)
65 | n_class = 100
66 | else:
67 |         raise Exception('unknown dataset: ' + dataset_name)
68 |
69 | return trainloader, testloader, n_class
70 |
71 |
72 | class BatchIterator(object):
73 | def __init__(self, data):
74 | pass
75 |
76 |
77 | class Dictionary(object):
78 | def __init__(self):
79 | self.word2idx = {}
80 | self.idx2word = []
81 |
82 | def add_word(self, word):
83 | if word not in self.word2idx:
84 | self.idx2word.append(word)
85 | self.word2idx[word] = len(self.idx2word) - 1
86 | return self.word2idx[word]
87 |
88 | def __len__(self):
89 | return len(self.idx2word)
90 |
91 |
92 | class Corpus(object):
93 | def __init__(self, path):
94 | # Starting from sequential dataset, batchify arranges the dataset into columns.
95 | # For instance, with the alphabet as the sequence and batch size 4, we'd get
96 | # ┌ a g m s ┐
97 | # │ b h n t │
98 | # │ c i o u │
99 | # │ d j p v │
100 | # │ e k q w │
101 | # └ f l r x ┘.
102 | # These columns are treated as independent by the model, which means that the
103 | # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient
104 | # batch processing.
105 | self.dictionary = Dictionary()
106 | self.train = self.tokenize(os.path.join(path, 'train.txt'))
107 | self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
108 | self.test = self.tokenize(os.path.join(path, 'test.txt'))
109 |
110 | @staticmethod
111 | def single_batchify(data, bsz, input_device):
112 | # Work out how cleanly we can divide the dataset into bsz parts.
113 | nbatch = data.size(0) // bsz
114 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
115 | data = data.narrow(0, 0, nbatch * bsz)
116 | # Evenly divide the dataset across the bsz batches.
117 | data = data.view(bsz, -1).t().contiguous()
118 | return data.to(input_device)
119 |
120 | def batchify(self, bsz, device):
121 | return self.single_batchify(self.train, bsz, device), self.single_batchify(self.valid, bsz,
122 | device), self.single_batchify(
123 | self.test, bsz, device)
124 |
125 | def tokenize(self, path):
126 | """Tokenizes a text file."""
127 | assert os.path.exists(path)
128 | # Add words to the dictionary
129 | with open(path, 'r', encoding="utf8") as f:
130 | tokens = 0
131 | for line in f:
132 |                 words = line.split() + ['<eos>']
133 | tokens += len(words)
134 | for word in words:
135 | self.dictionary.add_word(word)
136 |
137 | # Tokenize file content
138 | with open(path, 'r', encoding="utf8") as f:
139 | ids = torch.LongTensor(tokens)
140 | token = 0
141 | for line in f:
142 |                 words = line.split() + ['<eos>']
143 | for word in words:
144 | ids[token] = self.dictionary.word2idx[word]
145 | token += 1
146 |
147 | return ids
148 |
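149 |
150 | # Worked example (a toy 24-token stream; illustrative only): batchify with bsz=4 trims
151 | # nothing here, reshapes to 4 rows, then transposes so each column is an independent
152 | # stream, exactly as in the alphabet diagram above.
153 | if __name__ == '__main__':
154 |     toy = torch.arange(24)
155 |     print(Corpus.single_batchify(toy, 4, 'cpu').shape)  # torch.Size([6, 4])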
--------------------------------------------------------------------------------
/rnn_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import time
4 | import math
5 |
6 |
7 | def get_batch(source, i, bptt):
8 | # get_batch subdivides the source dataset into chunks of length args.bptt.
9 | # If source is equal to the example output of the batchify function, with
10 | # a bptt-limit of 2, we'd get the following two Variables for i = 0:
11 | # ┌ a g m s ┐ ┌ b h n t ┐
12 | # └ b h n t ┘ └ c i o u ┘
13 |     # Note that despite the name of the function, the subdivision of the dataset is not
14 | # done along the batch dimension (i.e. dimension 1), since that was handled
15 | # by the batchify function. The chunks are along dimension 0, corresponding
16 | # to the seq_len dimension in the LSTM.
17 | seq_len = min(bptt, len(source) - 1 - i)
18 | data = source[i:i + seq_len]
19 | target = source[i + 1:i + 1 + seq_len].view(-1)
20 | return data, target
21 |
22 |
23 | def rnn_genetic_evaluate(ga, input_model, input_criterion, data_source, ntokens, batch_size, bptt):
24 | input_model.eval() # Turn on evaluation mode which disables dropout.
25 | hidden = input_model.init_hidden(batch_size)
26 | with torch.no_grad():
27 | for ind in ga.get_current_generation():
28 | input_model.set_individual(ind)
29 | total_loss = 0
30 | for i in range(0, data_source.size(0) - 1, bptt):
31 | data, targets = get_batch(data_source, i, bptt)
32 | output, hidden = input_model(data, hidden)
33 | output_flat = output.view(-1, ntokens)
34 | total_loss += len(data) * input_criterion(output_flat, targets).item()
35 | hidden = repackage_hidden(hidden)
36 | ga.update_current_individual_fitness(ind, total_loss / (len(data_source) - 1))
37 | return ga.update_population()
38 |
39 |
40 | def rnn_evaluate(input_model, input_criterion, data_source, ntokens, batch_size, bptt):
41 | input_model.eval() # Turn on evaluation mode which disables dropout.
42 | hidden = input_model.init_hidden(batch_size)
43 | with torch.no_grad():
44 | total_loss = 0
45 | for i in range(0, data_source.size(0) - 1, bptt):
46 | data, targets = get_batch(data_source, i, bptt)
47 | output, hidden = input_model(data, hidden)
48 | output_flat = output.view(-1, ntokens)
49 | total_loss += len(data) * input_criterion(output_flat, targets).item()
50 | hidden = repackage_hidden(hidden)
51 | return total_loss / (len(data_source) - 1)
52 |
53 |
54 | def train_genetic_rnn(ga, train_data, input_model, input_optimizer, input_criterion, ntokens, batch_size, bptt,
55 | grad_clip,
56 | log_interval, final):
57 | # Turn on training mode which enables dropout.
58 | input_model.train()
59 | total_loss = 0.
60 | cur_loss = 0
61 | start_time = time.time()
62 | hidden = input_model.init_hidden(batch_size)
63 | for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
64 | data, targets = get_batch(train_data, i, bptt)
65 | # Starting each batch, we detach the hidden state from how it was previously produced.
66 | # If we didn't, the model would try backpropagating all the way to start of the dataset.
67 | hidden = repackage_hidden(hidden)
68 |         input_optimizer.zero_grad()  # clear old gradients before the next backpropagation
69 |         if not final: input_model.set_individual(ga.sample_child())  # sample a new child architecture for this batch
70 |
71 | output, hidden = input_model(data, hidden)
72 | loss = input_criterion(output.view(-1, ntokens), targets)
73 |
74 | loss.backward()
75 |
76 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
77 | torch.nn.utils.clip_grad_norm_(input_model.parameters(), grad_clip)
78 | input_optimizer.step()
79 |
80 | total_loss += loss.item()
81 |
82 | if batch % log_interval == 0 and batch > 0:
83 | cur_loss += total_loss
84 | elapsed = time.time() - start_time
85 | print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | '
86 | 'loss {:5.2f} | ppl {:8.2f}'.format(batch, len(train_data) // bptt, elapsed * 1000 / log_interval,
87 | total_loss / log_interval, math.exp(total_loss / log_interval)))
88 | total_loss = 0
89 | start_time = time.time()
90 | return (bptt * cur_loss) / len(train_data)
91 |
92 |
93 | def repackage_hidden(h):
94 | """Wraps hidden states in new Tensors, to detach them from their history."""
95 | if isinstance(h, torch.Tensor):
96 | return h.detach()
97 | else:
98 | return tuple(repackage_hidden(v) for v in h)
99 |
100 |
101 | class Dictionary(object):
102 | def __init__(self):
103 | self.word2idx = {}
104 | self.idx2word = []
105 |
106 | def add_word(self, word):
107 | if word not in self.word2idx:
108 | self.idx2word.append(word)
109 | self.word2idx[word] = len(self.idx2word) - 1
110 | return self.word2idx[word]
111 |
112 | def __len__(self):
113 | return len(self.idx2word)
114 |
115 |
116 | class Corpus(object):
117 | def __init__(self, path):
118 | self.dictionary = Dictionary()
119 | self.train = self.tokenize(os.path.join(path, 'train.txt'))
120 | self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
121 | self.test = self.tokenize(os.path.join(path, 'test.txt'))
122 |
123 | @staticmethod
124 | def single_batchify(data, bsz, input_device):
125 | # Work out how cleanly we can divide the dataset into bsz parts.
126 | nbatch = data.size(0) // bsz
127 | # Trim off any extra elements that wouldn't cleanly fit (remainders).
128 | data = data.narrow(0, 0, nbatch * bsz)
129 | # Evenly divide the dataset across the bsz batches.
130 | data = data.view(bsz, -1).t().contiguous()
131 | return data.to(input_device)
132 |
133 | def batchify(self, bsz, device):
134 | return self.single_batchify(self.train, bsz, device), self.single_batchify(self.valid, bsz,
135 | device), self.single_batchify(
136 | self.test, bsz, device)
137 |
138 | def tokenize(self, path):
139 | """Tokenizes a text file."""
140 | assert os.path.exists(path)
141 | # Add words to the dictionary
142 | with open(path, 'r', encoding="utf8") as f:
143 | tokens = 0
144 | for line in f:
145 |                 words = line.split() + ['<eos>']
146 | tokens += len(words)
147 | for word in words:
148 | self.dictionary.add_word(word)
149 |
150 | # Tokenize file content
151 | with open(path, 'r', encoding="utf8") as f:
152 | ids = torch.LongTensor(tokens)
153 | token = 0
154 | for line in f:
155 |                 words = line.split() + ['<eos>']
156 | for word in words:
157 | ids[token] = self.dictionary.word2idx[word]
158 | token += 1
159 |
160 | return ids
161 |
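162 |
163 | # Worked example (toy stream, illustrative only): with bptt=2 and i=0, get_batch
164 | # returns two consecutive steps as data and the next-step tokens, flattened, as target.
165 | if __name__ == '__main__':
166 |     source = torch.arange(12).view(6, 2)  # [T=6, N=2], as produced by batchify
167 |     data, target = get_batch(source, 0, 2)
168 |     print(data.shape, target.shape)  # torch.Size([2, 2]) torch.Size([4])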
--------------------------------------------------------------------------------
/gnas/genetic_algorithm/genetic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from random import choices
3 | from gnas.search_space.search_space import SearchSpace
4 | from gnas.search_space.cross_over import individual_uniform_crossover, individual_block_crossover
5 | from gnas.search_space.mutation import individual_flip_mutation
6 | from gnas.genetic_algorithm.ga_results import GenetricResult
7 | from gnas.genetic_algorithm.population_dict import PopulationDict
8 |
9 |
10 | def genetic_algorithm_searcher(search_space: SearchSpace, generation_size=20, population_size=300, keep_size=0,
11 | min_objective=True, mutation_p=None, p_cross_over=None, cross_over_type='Bit'):
12 | if mutation_p is None: mutation_p = 1 / search_space.n_elements
13 | if p_cross_over is None: p_cross_over = 1
14 |     print('p mutation: ' + str(mutation_p) + ' (default would be 1/n_elements = ' + str(1 / search_space.n_elements) + ')')
15 |
16 | def population_initializer(p_size):
17 | return search_space.generate_population(p_size)
18 |
19 | def mutation_function(x):
20 | return individual_flip_mutation(x, mutation_p)
21 |
22 | if cross_over_type == 'Bit':
23 | print("Bit base cross over")
24 |
25 | def cross_over_function(x0, x1):
26 | return individual_uniform_crossover(x0, x1, p_cross_over)
27 | elif cross_over_type == 'Block':
28 | print("Block base cross over")
29 |
30 | def cross_over_function(x0, x1):
31 | return individual_block_crossover(x0, x1, p_cross_over)
32 | else:
33 |         raise Exception('unknown cross over type: ' + str(cross_over_type))
34 |
35 | def selection_function(p):
36 | couples = choices(population=list(range(len(p))), weights=p,
37 | k=generation_size)
38 | return np.reshape(np.asarray(couples), [-1, 2])
39 |
40 | return GeneticAlgorithms(population_initializer, mutation_function, cross_over_function, selection_function,
41 | min_objective=min_objective, generation_size=generation_size,
42 | population_size=population_size, keep_size=keep_size)
43 |
44 |
45 | class GeneticAlgorithms(object):
46 | def __init__(self, population_initializer, mutation_function, cross_over_function, selection_function,
47 | population_size=300, generation_size=20, keep_size=20, min_objective=False):
48 | ####################################################################
49 | # Functions
50 | ####################################################################
51 | self.population_initializer = population_initializer
52 | self.mutation_function = mutation_function
53 | self.cross_over_function = cross_over_function
54 | self.selection_function = selection_function
55 | ####################################################################
56 | # parameters
57 | ####################################################################
58 | self.population_size = population_size
59 | self.generation_size = generation_size
60 | self.keep_size = keep_size
61 | self.min_objective = min_objective
62 | ####################################################################
63 | # status
64 | ####################################################################
65 | self.max_dict = PopulationDict()
66 | self.ga_result = GenetricResult()
67 | self.current_dict = dict()
68 |
69 | self.generation = self._create_random_generation()
70 |
71 | self.i = 0
72 | self.best_individual = None
73 |
74 | def _create_random_generation(self):
75 | return self.population_initializer(self.generation_size)
76 |
77 | def _create_new_generation(self, population, population_fitness):
78 | p = population_fitness / np.nansum(population_fitness)
79 | if self.min_objective: p = 1 - p
80 | couples = self.selection_function(p) # selection
81 | child = [cc for c in couples for cc in
82 | self.cross_over_function(population[c[0]], population[c[1]])] # cross-over
83 | new_generation = np.asarray([self.mutation_function(c) for c in child]) # mutation
84 |
85 | p_array = np.asarray([p.code for p in new_generation])
86 |         b = np.ascontiguousarray(p_array).view(np.dtype((np.void, p_array.dtype.itemsize * p_array.shape[1])))  # view each code row as one opaque byte record
87 |         _, idx = np.unique(b, return_index=True)  # indices of the unique individuals
88 | if len(idx) == self.generation_size:
89 | generation = new_generation
90 | else:
91 | n = self.generation_size - len(idx)
92 | p_new = self.population_initializer(n)
93 | generation = np.asarray([*[new_generation[i] for i in idx], *p_new])
94 | return generation
95 |
96 | def update_population(self):
97 | self.i += 1
98 |
99 | generation_fitness = np.asarray(list(self.current_dict.values()))
100 | generation = list(self.current_dict.keys())
101 | self.ga_result.add_generation_result(generation_fitness, generation)
102 |
103 | f_mean = np.mean(generation_fitness)
104 | f_var = np.var(generation_fitness)
105 | f_max = np.max(generation_fitness)
106 | f_min = np.min(generation_fitness)
107 | total_dict = self.max_dict.copy()
108 | total_dict.update(self.current_dict)
109 | # last_dict = None
110 | # if self.keep_size > 0:
111 | # last_dict = total_dict.filter_last_n(self.keep_size)
112 | # if self.population_size - self.keep_size > 0:
113 |
114 | best_max_dict = total_dict.filter_top_n(self.population_size,min_max=not self.min_objective)
115 | n_diff = self.max_dict.get_n_diff(best_max_dict)
116 | self.max_dict = best_max_dict
117 | #
118 | # if self.keep_size > 0:
119 | # best_max_dict = best_max_dict.merge(last_dict)
120 | # else:
121 | # best_max_dict = last_dict
122 |
123 |
124 |
125 |
126 | self.current_dict = dict()
127 | population_fitness = np.asarray(list(self.max_dict.values())).flatten()
128 | population = np.asarray(list(self.max_dict.keys())).flatten()
129 | self.best_individual = population[np.argmax(population_fitness)]
130 | fp_mean = np.mean(population_fitness)
131 | fp_var = np.var(population_fitness)
132 | fp_max = np.max(population_fitness)
133 | fp_min = np.min(population_fitness)
134 | self.ga_result.add_population_result(population_fitness, population)
135 | self.generation = self._create_new_generation(population, population_fitness)
136 |
137 |
138 | print(
139 | "Update generation | mean fitness: {:5.2f} | var fitness {:5.2f} | max fitness: {:5.2f} | min fitness {:5.2f} |population size {:d}|".format(
140 | f_mean, f_var, f_max, f_min, len(population)))
141 | print(
142 | "population results | mean fitness: {:5.2f} | var fitness {:5.2f} | max fitness: {:5.2f} | min fitness {:5.2f} |".format(
143 | fp_mean, fp_var, fp_max, fp_min))
144 | return f_mean, f_var, f_max, f_min, n_diff
145 |
146 | def get_current_generation(self):
147 | return self.generation
148 |
149 | def update_current_individual_fitness(self, individual, individual_fitness):
150 | self.current_dict.update({individual: individual_fitness})
151 |
152 | def sample_child(self):
153 |         if len(list(self.max_dict.keys())) == 0:  # if no population exists yet, generate a random individual
154 |             return self.population_initializer(1)[0]
155 |         else:
156 |             couples = choices(list(self.max_dict.keys()), k=2)  # randomly select two individuals from the population
157 |             child = self.cross_over_function(couples[0], couples[1])  # perform cross-over
158 |             return self.mutation_function(child[0])  # mutate and return the first child
159 |
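160 |
161 | # Usage sketch (an illustrative search step with dummy fitness values; the RNN search
162 | # space and sizes below are assumptions, not the project defaults):
163 | if __name__ == '__main__':
164 |     import gnas
165 |     ss = gnas.get_gnas_rnn_search_space(4)
166 |     ga = genetic_algorithm_searcher(ss, generation_size=10, population_size=20, min_objective=False)
167 |     for ind in ga.get_current_generation():
168 |         ga.update_current_individual_fitness(ind, float(np.random.rand()))  # dummy fitness
169 |     print(ga.update_population())  # (f_mean, f_var, f_max, f_min, n_diff)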
--------------------------------------------------------------------------------
/plot_result.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 | from common import load_final, make_log_dir, get_model_type, ModelType
6 | from config import get_config, load_config, save_config
7 | import gnas
8 | from modules.drop_module import DropModuleControl
9 | from gnas.common.graph_draw import draw_cell, draw_network
10 | import matplotlib.image as mpimg
11 |
12 | # Population size comparison
13 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00",
14 | '/data/projects/gnas_results/population_size/2019_01_31_15_45_42',
15 | '/data/projects/gnas_results/population_size/2019_02_01_03_26_01',
16 | '/data/projects/gnas_results/population_size/2019_02_01_04_23_06',
17 | '/data/projects/gnas_results/population_size/2019_02_01_04_44_45',
18 | '/data/projects/gnas_results/population_size/2019_02_01_15_39_37',
19 | '/data/projects/gnas_results/population_size/2019_02_03_17_25_25',
20 | # '/data/projects/gnas_results/population_size/2019_02_03_17_25_27',
21 | '/data/projects/gnas_results/population_size/2019_02_03_17_25_28']
22 |
23 | # p mutation
24 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_23_21_20_33",
25 | "/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00",
26 | "/data/projects/gnas_results/p_mutation/2019_01_25_08_46_39",
27 | "/data/projects/gnas_results/p_mutation/2019_01_26_13_18_17"]
28 |
29 | # # LR Compare
30 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00",
31 | "/data/projects/gnas_results/lr_compare/2019_02_04_19_17_59",
32 | "/data/projects/gnas_results/lr_compare/2019_02_04_19_18_00"]
33 |
34 | # # Bit Vs Block
35 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00",
36 | "/data/projects/GNAS/logs/2019_02_11_06_15_10"]
37 | #
38 | # # Plot CIFAR10 - Search Result
39 | file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00"]
40 | # # Plot CIFAR100 - Search Result
41 | file_list = ["/data/projects/GNAS/logs/2019_02_17_20_25_42"]
42 |
43 | # CIFAR10 Final
44 | file_list = ['/data/projects/gnas_results/new_log/2019_02_09_16_43_02']
45 | # file_list=['/data/projects/gnas_results/new_log/2019_02_07_18_34_45',
46 | # '/data/projects/gnas_results/new_log/2019_02_09_02_23_52',
47 | # '/data/projects/gnas_results/new_log/2019_02_09_16_43_02',
48 | # '/data/projects/gnas_results/new_log/2019_02_14_18_15_48']
49 |
50 |
51 | plot_arc = False
52 | # file_list = ["/data/projects/gnas_results/p_mutation/2019_01_24_19_06_00", ]
53 | if plot_arc:
54 | ind_file = os.path.join(file_list[0], 'best_individual.pickle')
55 | config_file = os.path.join(file_list[0], 'config.json')
56 | ind = pickle.load(open(ind_file, "rb"))
57 |
58 | config = get_config(ModelType.CNN)
59 | print("Loading config file:" + config_file)
60 | config.update(load_config(config_file))
61 |
62 | dp_control = DropModuleControl(config.get('drop_path_keep_prob'))
63 | n_cell_type = gnas.SearchSpaceType(config.get('n_block_type') - 1)
64 | ss = gnas.get_gnas_cnn_search_space(config.get('n_nodes'), dp_control, n_cell_type)
65 | draw_network(ss, ind, './')
66 | title_list = ['Reduce Cell', ' Normal Cell', ' Input Cell']
67 | for i in range(len(ss.ocl)):
68 | plt.subplot(1, len(ss.ocl), i + 1)
69 | img = mpimg.imread(os.path.join('./', str(i) + '.png'))
70 | plt.imshow(img)
71 | plt.axis('off')
72 | plt.title(title_list[i])
73 | plt.show()
74 | # draw_cell(ss.ocl[0], ind.individual_list[0])
75 | # plt.show()
76 | # print("a")
77 |
78 | if len(file_list) == 1:
79 | data = pickle.load(open(os.path.join(file_list[0], 'ga_result.pickle'), "rb"))
80 | config = load_config(os.path.join(file_list[0], 'config.json'))
81 | if data.result_dict.get('Fitness') is None:
82 | plt.plot(np.asarray(data.result_dict.get('Training Accuracy')), label='Training Accuracy')
83 | plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), label='Validation Accuracy')
84 | plt.xlabel('Epoch')
85 | plt.legend()
86 | plt.ylabel('Accuracy[%]')
87 | plt.grid()
88 | plt.show()
89 | else:
90 | fitness = np.stack(data.result_dict.get('Fitness'))
91 | fitness_p = np.stack(data.result_dict.get('Fitness-Population'))
92 | fitness_p = fitness_p[0:-1:2, :]
93 |
94 | epochs = np.linspace(0, fitness_p.shape[0] - 1, fitness_p.shape[0])
95 | plt.plot(epochs, np.mean(fitness_p, axis=1), '*--',
96 | label='Population mean accuracy')
97 | plt.plot(epochs, np.max(fitness_p, axis=1), label='Max accuracy')
98 | plt.plot(np.asarray(data.result_dict.get('Best')), '--', label='Best')
99 | plt.grid()
100 | plt.legend()
101 | plt.xlabel('Epoch')
102 | plt.ylabel('Accuracy')
103 | plt.show()
104 |
105 | plt.errorbar(epochs, np.mean(fitness_p, axis=1), np.std(fitness_p, axis=1), fmt='*--',
106 | label='Population mean accuracy')
107 | plt.plot(epochs, np.min(fitness_p, axis=1), label='Min accuracy')
108 | plt.plot(epochs, np.max(fitness_p, axis=1), label='Max accuracy')
109 | plt.grid()
110 | plt.legend()
111 | plt.title('Population accuracy on the validation set')
112 | plt.xlabel('Epoch')
113 | plt.ylabel('Accuracy')
114 | plt.show()
115 |
116 | plt.plot(np.asarray(data.result_dict.get('Training Accuracy')), label='Training')
117 | plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), '--', label='Validation')
118 | plt.plot(np.asarray(data.result_dict.get('Best')), '*-', label='Best')
119 | plt.title('Training vs Validation Accuracy')
120 | plt.xlabel('Epoch')
121 | plt.legend()
122 | plt.ylabel('Accuracy[%]')
123 | plt.grid()
124 | plt.show()
125 |
126 | plt.plot(epochs, data.result_dict.get('N'))
127 | plt.title('Number of new individuals in Population')
128 | plt.xlabel('Epoch')
129 | plt.ylabel('N')
130 | plt.grid()
131 | plt.show()
132 |
133 | plt.plot(epochs, data.result_dict.get('Training Loss'))
134 | plt.title('Training Loss')
135 | plt.xlabel('Epoch')
136 | plt.ylabel('Loss')
137 | plt.grid()
138 | plt.show()
139 |
140 | else:
141 | ################
142 | # Build legend
143 | ################
144 | config_list = []
145 | param_list = []
146 | for f in file_list:
147 | config_list.append(load_config(os.path.join(f, 'config.json')))
148 | for k in config_list[-1].keys():
149 | param_list.append(k)
150 | param_list = np.unique(param_list)
151 | str_list = ['' for c in config_list]
152 | res_dict = dict()
153 | for p in param_list:
154 | if len(np.unique([c.get(p) for c in config_list if c.get(p) is not None])) > 1:
155 | for i, c in enumerate(config_list):
156 | str_list[i] = str_list[i] + ' ' + p + '=' + str(c.get(p))
157 | if res_dict.get(p) is None:
158 | res_dict.update({p: [c.get(p)]})
159 | else:
160 | res_dict.get(p).append(c.get(p))
161 | elif len(np.unique([c.get(p) for c in config_list if c.get(p) is not None])) == 1:
162 | if len([c.get(p) for c in config_list if c.get(p) is None]) != 0:
163 | for i, c in enumerate(config_list):
164 | str_list[i] = str_list[i] + ' ' + p + '=' + str(c.get(p))
165 | if len(res_dict.keys()) == 1:
166 | param_array = np.asarray(res_dict.get(list(res_dict.keys())[0]))
167 | res_list = []
168 | for i, f in enumerate(file_list):
169 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb"))
170 | res_list.append(np.max(np.asarray(data.result_dict.get('Best'))))
171 | index = np.argsort(param_array)
172 | res_list = np.asarray(res_list)[index]
173 | param_array = param_array[index]
174 | plt.plot(param_array, res_list)
175 | plt.grid()
176 | plt.xlabel(list(res_dict.keys())[0].replace('_', ' '))
177 | plt.ylabel('Accuracy[%]')
178 | plt.show()
179 | print("a")
180 | #########################
181 | # Plot Validation
182 | #########################
183 | plt.subplot(2, 2, 1)
184 | for i, f in enumerate(file_list):
185 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb"))
186 | plt.plot(np.asarray(data.result_dict.get('Best')), label=str_list[i])
187 | # plt.title()
188 | plt.legend()
189 | plt.grid()
190 | plt.subplot(2, 2, 2)
191 | for i, f in enumerate(file_list):
192 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb"))
193 | config = load_config(os.path.join(f, 'config.json'))
194 | plt.plot(np.asarray(data.result_dict.get('Training Accuracy')), label=str_list[i])
195 | # plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), '*--', label='Validation ' + str_list[i])
196 | plt.legend()
197 | plt.grid()
198 | plt.subplot(2, 2, 3)
199 | for i, f in enumerate(file_list):
200 | data = pickle.load(open(os.path.join(f, 'ga_result.pickle'), "rb"))
201 | config = load_config(os.path.join(f, 'config.json'))
202 | plt.plot(np.asarray(data.result_dict.get('Training Loss')), label=str_list[i])
203 | # plt.plot(np.asarray(data.result_dict.get('Validation Accuracy')), '*--', label='Validation ' + str_list[i])
204 | plt.legend()
205 | plt.grid()
206 | plt.show()
207 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch.nn as nn
3 |
4 | import torch
5 | import torch.optim as optim
6 | import os
7 | import pickle
8 | import argparse
9 |
10 | import gnas
11 | from models import model_cnn, model_rnn
12 | from cnn_utils import evaluate_single, evaluate_individual_list
13 | from rnn_utils import train_genetic_rnn, rnn_genetic_evaluate, rnn_evaluate
14 | from data import get_dataset
15 | from common import load_final, make_log_dir, get_model_type, ModelType
16 | from config import get_config, load_config, save_config
17 | from modules.drop_module import DropModuleControl
18 | from modules.cosine_annealing import CosineAnnealingLR
19 |
20 | #######################################
21 | # Constants
22 | #######################################
23 | log_interval = 200
24 | #######################################
25 | # User input
26 | #######################################
27 | parser = argparse.ArgumentParser(description='PyTorch GNAS')
28 | parser.add_argument('--dataset_name', type=str, choices=['CIFAR10', 'CIFAR100', 'PTB'], help='dataset to train on',
29 | default='CIFAR10')
30 | parser.add_argument('--config_file', type=str, help='location of the config file')
31 | parser.add_argument('--search_dir', type=str, help='the log dir of the search')
32 | parser.add_argument('--final', action='store_true', help='train the best individual found in --search_dir')
33 | parser.add_argument('--data_path', type=str, default='./dataset/', help='location of the dataset')
34 | args = parser.parse_args()
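| # Example invocations (illustrative; adjust paths to your setup):
| #   python main.py --dataset_name CIFAR10 --config_file configs/config_cnn_search_cifar10.json
| #   python main.py --dataset_name CIFAR10 --config_file configs/config_cnn_final_cifar10.json --search_dir <search_log_dir> --final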
35 | #######################################
36 | # Search Working Device
37 | #######################################
38 | working_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39 | print(working_device)
40 | #######################################
42 | # Select model type
42 | #######################################
43 | model_type = get_model_type(dataset_name=args.dataset_name)
44 | print("Selected mode type:" + str(model_type))
45 | #######################################
46 | # Parameters
47 | #######################################
48 | config = get_config(model_type)
49 | if args.config_file is not None:
50 | print("Loading config file:" + args.config_file)
51 | config.update(load_config(args.config_file))
52 | config.update({'data_path': args.data_path, 'dataset_name': args.dataset_name, 'working_device': str(working_device)})
53 | print(config)
54 | ######################################
55 | # Read dataset and set augmentation
56 | ######################################
57 | trainloader, testloader, n_param = get_dataset(config)
58 | ######################################
59 | # Config model and search space
60 | ######################################
61 | if model_type == ModelType.CNN:
62 | min_objective = False
63 | n_cell_type = gnas.SearchSpaceType(config.get('n_block_type') - 1)
64 | dp_control = DropModuleControl(config.get('drop_path_keep_prob'))
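| # drop-path regularization starts disabled; it is switched on at
| # 'drop_path_start_epoch' inside the training loop below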
65 | ss = gnas.get_gnas_cnn_search_space(config.get('n_nodes'), dp_control, n_cell_type)
66 |
67 | net = model_cnn.Net(config.get('n_blocks'), config.get('n_channels'), n_param,
68 | config.get('dropout'),
69 | ss, aux=config.get('aux_loss')).to(working_device)
70 | ######################################
71 | # Build Optimizer
72 | #####################################
73 | optimizer = optim.SGD(net.parameters(), lr=config.get('learning_rate'), momentum=config.get('momentum'),
74 | nesterov=True,
75 | weight_decay=config.get('weight_decay'))
76 | elif model_type == ModelType.RNN:
77 | min_objective = True
78 | ntokens = n_param
79 | ss = gnas.get_gnas_rnn_search_space(config.get('n_nodes'))
80 | net = model_rnn.RNNModel(ntokens, config.get('n_channels'), config.get('n_channels'), config.get('n_blocks'),
81 | config.get('dropout'),
82 | tie_weights=True,
83 | ss=ss).to(
84 | working_device)
85 | ######################################
86 | # Build Optimizer
87 | #####################################
88 | optimizer = optim.SGD(net.parameters(), lr=config.get('learning_rate'),
89 | weight_decay=config.get('weight_decay'))
90 | ######################################
91 | # Build genetic_algorithm_searcher
92 | #####################################
93 | ga = gnas.genetic_algorithm_searcher(ss, generation_size=config.get('generation_size'),
94 | population_size=config.get('population_size'),
95 | keep_size=config.get('keep_size'), mutation_p=config.get('mutation_p'),
96 | p_cross_over=config.get('p_cross_over'),
97 | cross_over_type=config.get('cross_over_type'),
98 | min_objective=min_objective)
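| # The searcher evolves a population of architectures via cross-over and
| # mutation; min_objective is False for CNNs (maximize accuracy) and True
| # for RNNs (minimize validation loss).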
99 | ######################################
100 | # Loss function
101 | ######################################
102 | criterion = nn.CrossEntropyLoss()
103 | ######################################
104 | # Select Learning schedule
105 | #####################################
106 | if config.get('LRType') == 'CosineAnnealingLR':
107 | scheduler = CosineAnnealingLR(optimizer, 10, 2, config.get('lr_min'))
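| # custom scheduler from modules/cosine_annealing.py; the positional arguments
| # are assumed to be the restart period (10), its multiplier (2), and the minimum LR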
108 | elif config.get('LRType') == 'MultiStepLR':
109 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
110 | [int(config.get('n_epochs') / 2), int(3 * config.get('n_epochs') / 4)])
111 | elif config.get('LRType') == 'ExponentialLR':
112 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=config.get('gamma'))
113 | else:
114 | raise ValueError('unknown LRType: ' + config.get('LRType'))
115 |
116 | ##################################################
117 | # Generate log dir and Save Params
118 | ##################################################
119 | log_dir = make_log_dir(config)
120 | save_config(log_dir, config)
121 | #######################################
122 | # Load Individual
123 | #######################################
124 | if args.final: ind = load_final(net, args.search_dir)
125 | ##################################################
126 | # Start Epochs
127 | ##################################################
128 | ra = gnas.ResultAppender()
129 | if model_type == ModelType.CNN:
130 | best = 0
131 | print("Starting Traing with CNN Model")
132 | for epoch in range(config.get('n_epochs')): # loop over the dataset multiple times
133 | # print(epoch)
134 | running_loss = 0.0
135 | correct = 0
136 | total = 0
137 |
138 | scheduler.step()
139 | s = time.time()
140 | net = net.train()
141 | if epoch == config.get('drop_path_start_epoch'):
142 | dp_control.enable()
143 | ############################################
144 | # Loop over batches and update weights
145 | ############################################
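| # Weight sharing: in search mode every batch trains a different child sampled
| # from the current generation, so all candidate architectures share one set of weights.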
146 | for i, (inputs, labels) in enumerate(trainloader, 0):  # loop over batches
147 | # in search mode, train each batch on a freshly sampled
148 | # child architecture from the current population
149 | if not args.final:
150 | net.set_individual(ga.sample_child())
151 |
152 | inputs = inputs.to(working_device)
153 | labels = labels.to(working_device)
154 |
155 | optimizer.zero_grad() # zero the parameter gradients
156 | outputs = net(inputs) # forward
157 |
158 | _, predicted = torch.max(outputs[0], 1)
159 | total += labels.size(0)
160 | correct += (predicted == labels).sum().item()
161 |
162 | loss = criterion(outputs[0], labels)
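| # optionally add the scaled auxiliary-head loss to the main cross-entropy loss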
163 | if config.get('aux_loss'): loss += config.get('aux_scale') * criterion(outputs[1], labels)
164 | loss.backward() # backward
165 |
166 | optimizer.step() # optimize
167 |
168 | # print statistics
169 | running_loss += loss.item()
170 | ############################################
171 | # Update GA population
172 | ############################################
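| # Two evaluation modes: with 'full_dataset' every individual in the current
| # generation is scored on the full test loader once per epoch; otherwise several
| # generations are evolved per epoch on validation batches and only the best
| # individual is re-scored on the full loader.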
173 | if args.final:
174 | f_max = evaluate_single(ind, net, testloader, working_device)
175 | n_diff = 0
176 | else:
177 | if config.get('full_dataset'):
178 | for ind in ga.get_current_generation():
179 | acc = evaluate_single(ind, net, testloader, working_device)
180 | ga.update_current_individual_fitness(ind, acc)
181 | _, _, f_max, _, n_diff = ga.update_population()
182 | best_individual = ga.best_individual
183 | else:
184 |
185 | f_max = 0
186 | n_diff = 0
187 | for _ in range(config.get('generation_per_epoch')):
188 | evaluate_individual_list(ga.get_current_generation(), ga, net, testloader,
189 | working_device) # evaluate next generation on the validation set
190 | _, _, v_max, _, n_d = ga.update_population() # replacement
191 | n_diff += n_d
192 | if v_max > f_max:
193 | f_max = v_max
194 | best_individual = ga.best_individual
195 | f_max = evaluate_single(best_individual, net, testloader, working_device)  # evaluate best
196 | if f_max > best:
197 | print("Update Best")
198 | best = f_max
199 | torch.save(net.state_dict(), os.path.join(log_dir, 'best_model.pt'))
200 | if not args.final:
201 | gnas.draw_network(ss, ga.best_individual, os.path.join(log_dir, 'best_graph_' + str(epoch) + '_'))
202 | pickle.dump(ga.best_individual, open(os.path.join(log_dir, 'best_individual.pickle'), "wb"))
203 | print('|Epoch: {:2d}|Time: {:2.3f}|Loss: {:2.3f}|'
204 | 'Accuracy: {:2.3f}%|Validation Accuracy: {:2.3f}%|'
205 | 'LR: {:2.3f}|N Change: {:2d}|'.format(
206 | epoch,
207 | (time.time() - s) / 60,
208 | running_loss / (i + 1),
209 | 100 * correct / total, f_max,
210 | scheduler.get_lr()[-1],
211 | n_diff))
212 | ra.add_epoch_result('N', n_diff)
213 | ra.add_epoch_result('Best', best)
214 | ra.add_epoch_result('Validation Accuracy', f_max)
215 | ra.add_epoch_result('LR', scheduler.get_lr()[-1])
216 | ra.add_epoch_result('Training Loss', running_loss / (i + 1))
217 | ra.add_epoch_result('Training Accuracy', 100 * correct / total)
218 | if not args.final:
219 | ra.add_result('Fitness', ga.ga_result.fitness_list)
220 | ra.add_result('Fitness-Population', ga.ga_result.fitness_full_list)
221 | ra.save_result(log_dir)
222 | elif model_type == ModelType.RNN:
223 | best = 1000
224 | for epoch in range(1, config.get('n_epochs') + 1):
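| # hold the learning rate constant for the first 15 epochs, then step the schedule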
225 | if epoch > 15:
226 | scheduler.step()
227 | epoch_start_time = time.time()
228 | eval_batch_size = config.get('batch_size_val')
229 | train_loss = train_genetic_rnn(ga, trainloader, net, optimizer, criterion, ntokens, config.get('batch_size'),
230 | config.get('bptt'), config.get('clip'),
231 | log_interval, args.final)
232 | if args.final:
233 | min_loss = rnn_evaluate(net, criterion, testloader, ntokens, config.get('batch_size_val'),
234 | config.get('bptt'))
235 | else:
236 | val_loss, loss_var, max_loss, min_loss, n_diff = rnn_genetic_evaluate(ga, net, criterion, testloader,
237 | ntokens,
238 | config.get('batch_size_val'),
239 | config.get('bptt'))
240 |
241 | print('-' * 89)
242 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | lr {:02.2f} | '
243 | ''.format(epoch, (time.time() - epoch_start_time),
244 | min_loss, scheduler.get_lr()[-1]))
245 | print('-' * 89)
246 | # Save the model if the validation loss is the best we've seen so far.
247 | if min_loss < best:
248 | print("Update Best")
249 | torch.save(net.state_dict(), os.path.join(log_dir, 'best_model.pt'))
250 | if not args.final:
251 | gnas.draw_network(ss, ga.best_individual, os.path.join(log_dir, 'best_graph_' + str(epoch) + '_'))
252 | pickle.dump(ga.best_individual, open(os.path.join(log_dir, 'best_individual.pickle'), "wb"))
253 |
254 | best = min_loss
255 |
256 | ra.add_epoch_result('Loss', train_loss)
257 | ra.add_epoch_result('LR', scheduler.get_lr()[-1])
258 | ra.add_epoch_result('Best', best)
259 | if not args.final: ra.add_result('Fitness', ga.ga_result.fitness_list)
260 | ra.save_result(log_dir)
261 | print('Finished Training')
262 |
--------------------------------------------------------------------------------