├── assets ├── fix11.png ├── teaser.png ├── 50dot6percent.png ├── 56dot8percent.png └── train-overall.png ├── requirements.txt ├── model ├── __init__.py ├── graphs │ ├── generated │ │ ├── .gitignore │ │ ├── ws-4-075-3.txt │ │ ├── ws-4-075-4.txt │ │ └── ws-4-075-5.txt │ ├── ba.py │ ├── er.py │ └── ws.py ├── sep_conv.py ├── node.py ├── model.py └── dag_layer.py ├── utils ├── __init__.py ├── graph_reader.py ├── writer.py ├── evaluation.py ├── hparams.py └── train.py ├── dataset ├── __init__.py └── dataloader.py ├── config ├── test.yaml └── default.yaml ├── test.py ├── validation.py ├── .gitignore ├── trainer.py └── README.md /assets/fix11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungwonpark/RandWireNN/HEAD/assets/fix11.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungwonpark/RandWireNN/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/50dot6percent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungwonpark/RandWireNN/HEAD/assets/50dot6percent.png -------------------------------------------------------------------------------- /assets/56dot8percent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungwonpark/RandWireNN/HEAD/assets/56dot8percent.png -------------------------------------------------------------------------------- /assets/train-overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungwonpark/RandWireNN/HEAD/assets/train-overall.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | adabound 2 | numpy 3 | pyyaml 4 | tensorboardX 5 | tensorflow 6 | torch 7 | torchvision 8 | tqdm 9 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 8 08:22:37 2019 4 | 5 | @author: Michael 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 8 08:22:54 2019 4 | 5 | @author: Michael 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 8 08:22:05 2019 4 | 5 | @author: Michael 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /model/graphs/generated/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except these files 4 | !.gitignore 5 | !ws-4-075-3.txt 6 | !ws-4-075-4.txt 7 | !ws-4-075-5.txt 8 | -------------------------------------------------------------------------------- /config/test.yaml: 
-------------------------------------------------------------------------------- 1 | data: 2 | type: '' # MNIST, CIFAR10, ImageNet 3 | train: 4 | optimizer: 'adam' 5 | adam: 0.001 6 | adabound: 7 | initial: 0.001 8 | final: 0.05 9 | --- 10 | model: 11 | channel: 78 # 'small regime' 12 | classes: 1000 # use 10 for MNIST, CIFAR10 13 | graph0: 'ws-4-075-3.txt' 14 | graph1: 'ws-4-075-4.txt' 15 | graph2: 'ws-4-075-5.txt' 16 | --- 17 | log: 18 | chkpt_dir: 'chkpt' 19 | log_dir: 'logs' 20 | -------------------------------------------------------------------------------- /utils/graph_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def read_graph(txtfile): 5 | txtpath = os.path.join('model', 'graphs', 'generated', txtfile) 6 | with open(txtpath, 'r') as f: 7 | num_nodes = int(f.readline().strip()) 8 | num_edges = int(f.readline().strip()) 9 | edges = list() 10 | for _ in range(num_edges): 11 | s, e = map(int, f.readline().strip().split()) 12 | edges.append((s, e)) 13 | 14 | temp = dict() 15 | temp['num_nodes'] = num_nodes 16 | temp['edges'] = edges 17 | return temp 18 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | 3 | 4 | class MyWriter(SummaryWriter): 5 | def __init__(self, logdir): 6 | super(MyWriter, self).__init__(logdir) 7 | 8 | def log_training(self, train_loss, step): 9 | self.add_scalar('loss/train_loss', train_loss, step) 10 | 11 | def log_evaluation(self, test_loss, accuracy, step): 12 | self.add_scalar('loss/test_avg_loss', test_loss, step) 13 | self.add_scalar('eval/Top1_accuracy', accuracy, step) 14 | 15 | def write_graph(self, model, dummy_input): 16 | self.add_graph(model, dummy_input) 17 | 18 | -------------------------------------------------------------------------------- /model/sep_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SeparableConv2d(nn.Module): 6 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, 7 | dilation=1, bias=False): 8 | super(SeparableConv2d,self).__init__() 9 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, 10 | dilation, groups=in_channels, bias=bias) 11 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 12 | 13 | def forward(self,x): 14 | x = self.conv1(x) 15 | x = self.pointwise(x) 16 | return x -------------------------------------------------------------------------------- /model/graphs/generated/ws-4-075-3.txt: -------------------------------------------------------------------------------- 1 | 32 2 | 64 3 | 0 1 4 | 0 2 5 | 0 30 6 | 0 31 7 | 1 2 8 | 1 3 9 | 1 31 10 | 2 3 11 | 2 4 12 | 3 4 13 | 3 5 14 | 4 5 15 | 4 6 16 | 5 6 17 | 5 7 18 | 6 7 19 | 6 8 20 | 7 8 21 | 7 9 22 | 8 9 23 | 8 10 24 | 9 10 25 | 9 11 26 | 10 11 27 | 10 12 28 | 11 12 29 | 11 13 30 | 12 13 31 | 12 14 32 | 13 14 33 | 13 15 34 | 14 15 35 | 14 16 36 | 15 16 37 | 15 17 38 | 16 17 39 | 16 18 40 | 17 18 41 | 17 19 42 | 18 19 43 | 18 20 44 | 19 20 45 | 19 21 46 | 20 21 47 | 20 22 48 | 21 22 49 | 21 23 50 | 22 23 51 | 22 24 52 | 23 24 53 | 23 25 54 | 24 25 55 | 24 26 56 | 25 26 57 | 25 27 58 | 26 27 59 | 26 28 60 | 27 28 61 | 27 29 62 | 28 29 63 | 28 30 64 | 29 30 65 | 29 31 66 | 30 31 67 | 
-------------------------------------------------------------------------------- /model/graphs/generated/ws-4-075-4.txt: -------------------------------------------------------------------------------- 1 | 32 2 | 64 3 | 0 1 4 | 0 2 5 | 0 30 6 | 0 31 7 | 1 2 8 | 1 3 9 | 1 31 10 | 2 3 11 | 2 4 12 | 3 4 13 | 3 5 14 | 4 5 15 | 4 6 16 | 5 6 17 | 5 7 18 | 6 7 19 | 6 8 20 | 7 8 21 | 7 9 22 | 8 9 23 | 8 10 24 | 9 10 25 | 9 11 26 | 10 11 27 | 10 12 28 | 11 12 29 | 11 13 30 | 12 13 31 | 12 14 32 | 13 14 33 | 13 15 34 | 14 15 35 | 14 16 36 | 15 16 37 | 15 17 38 | 16 17 39 | 16 18 40 | 17 18 41 | 17 19 42 | 18 19 43 | 18 20 44 | 19 20 45 | 19 21 46 | 20 21 47 | 20 22 48 | 21 22 49 | 21 23 50 | 22 23 51 | 22 24 52 | 23 24 53 | 23 25 54 | 24 25 55 | 24 26 56 | 25 26 57 | 25 27 58 | 26 27 59 | 26 28 60 | 27 28 61 | 27 29 62 | 28 29 63 | 28 30 64 | 29 30 65 | 29 31 66 | 30 31 67 | -------------------------------------------------------------------------------- /model/graphs/generated/ws-4-075-5.txt: -------------------------------------------------------------------------------- 1 | 32 2 | 64 3 | 0 1 4 | 0 2 5 | 0 30 6 | 0 31 7 | 1 2 8 | 1 3 9 | 1 31 10 | 2 3 11 | 2 4 12 | 3 4 13 | 3 5 14 | 4 5 15 | 4 6 16 | 5 6 17 | 5 7 18 | 6 7 19 | 6 8 20 | 7 8 21 | 7 9 22 | 8 9 23 | 8 10 24 | 9 10 25 | 9 11 26 | 10 11 27 | 10 12 28 | 11 12 29 | 11 13 30 | 12 13 31 | 12 14 32 | 13 14 33 | 13 15 34 | 14 15 35 | 14 16 36 | 15 16 37 | 15 17 38 | 16 17 39 | 16 18 40 | 17 18 41 | 17 19 42 | 18 19 43 | 18 20 44 | 19 20 45 | 19 21 46 | 20 21 47 | 20 22 48 | 21 22 49 | 21 23 50 | 22 23 51 | 22 24 52 | 23 24 53 | 23 25 54 | 24 25 55 | 24 26 56 | 25 26 57 | 25 27 58 | 26 27 59 | 26 28 60 | 27 28 61 | 27 29 62 | 28 29 63 | 28 30 64 | 29 30 65 | 29 31 66 | 30 31 67 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def validate(model, valset, writer=None, epoch=None): 8 | model.eval() 9 | test_loss = 0 10 | correct = 0 11 | with torch.no_grad(): 12 | for data, target in tqdm.tqdm(valset): 13 | data, target = data.cuda(), target.cuda() 14 | output = model(data) 15 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum over the batch; dividing by the dataset size below then gives the true per-sample average 16 | pred = output.argmax(dim=1, keepdim=True) 17 | correct += pred.eq(target.view_as(pred)).sum().item() 18 | 19 | test_loss /= len(valset.dataset) 20 | accuracy = correct / len(valset.dataset) 21 | 22 | if writer is not None: 23 | writer.log_evaluation(test_loss, accuracy, epoch) 24 | 25 | model.train() 26 | return test_loss, accuracy 27 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train: '' # if using MNIST, can download directly from cloud with train: 'download', test: 'download' 3 | val: '' 4 | batch_size: 256 # learning rate should be adjusted along with batch_size. see arXiv:1706.02677 5 | num_workers: 16 6 | type: '' # MNIST, CIFAR10, ImageNet 7 | --- 8 | train: 9 | optimizer: 'adam' 10 | epoch: 250 # 'small regime'. see table 2 of the paper. 
11 | adam: 0.001 12 | adabound: 13 | initial: 0.001 14 | final: 0.05 15 | sgd: 16 | lr: 0.1 17 | momentum: 0.9 18 | weight_decay: 0.00005 19 | decay: 20 | step: 150000 21 | gamma: 0.1 22 | --- 23 | model: 24 | channel: 78 # 'small regime' 25 | classes: 1000 # use 10 for MNIST, CIFAR10 26 | input_maps: 3 # 3 for ImageNet/CIFAR10, 1 for MNIST 27 | graph0: '' # example: 'ws-4-075-3.txt' 28 | graph1: '' 29 | graph2: '' 30 | --- 31 | log: 32 | chkpt_dir: 'chkpt' 33 | log_dir: 'logs' 34 | -------------------------------------------------------------------------------- /model/node.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .sep_conv import SeparableConv2d 6 | 7 | 8 | class NodeOp(nn.Module): 9 | def __init__(self, in_degree, in_channel, out_channel, stride): 10 | super(NodeOp, self).__init__() 11 | self.single = (in_degree == 1) 12 | if not self.single: 13 | self.agg_weight = nn.Parameter(torch.zeros(in_degree, requires_grad=True)) 14 | self.conv = SeparableConv2d(in_channel, out_channel, kernel_size=3, padding=1, stride=stride) 15 | self.bn = nn.BatchNorm2d(out_channel) 16 | 17 | def forward(self, y): 18 | # y: [B, C, N, M, in_degree] 19 | if self.single: 20 | y = y.squeeze(-1) 21 | else: 22 | y = torch.matmul(y, torch.sigmoid(self.agg_weight)) # [B, C, N, M] 23 | y = F.relu(y) # [B, C, N, M] 24 | y = self.conv(y) # [B, C_out, N, M] 25 | y = self.bn(y) # [B, C_out, N, M] 26 | return y 27 | 28 | 29 | # if __name__ == '__main__': 30 | # x = torch.randn(7, 3, 224, 224, 5) 31 | # node = NodeOp(5, 3, 4, 1) # in_degree, in_channel, out_channel, stride 32 | # y = node(x) 33 | # print(y.shape) # [7, 4, 224, 224] 34 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.hparams import HParam 4 | from utils.graph_reader import read_graph 5 | from model.model import RandWire 6 | 7 | 8 | if __name__ == '__main__': 9 | hp = HParam('config/test.yaml') 10 | graphs = [ 11 | read_graph(hp.model.graph0), 12 | read_graph(hp.model.graph1), 13 | read_graph(hp.model.graph2), 14 | ] 15 | 16 | print('Building Network...') 17 | model = RandWire(hp, graphs) 18 | 19 | x = torch.randn(16, 3, 224, 224) # RGB-channel 224x224 image with batch_size=16 20 | print('Input shape:') 21 | print(x.shape) 22 | y = model(x) 23 | print('Output shape:') 24 | print(y.shape) # [16, 1000] 25 | -------------------------------------------------------------------------------- /model/graphs/ba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import numpy as np 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser('Barabasi-Albert graph generator') 9 | parser.add_argument('-n', '--n_nodes', type=int, default=32, 10 | help="number of nodes for random graph") 11 | parser.add_argument('-m', '--m_nodes', type=int, required=True, 12 | help="initial number of nodes for random graph") 13 | parser.add_argument('-o', '--out_txt', type=str, required=True, 14 | help="name of output txt file") 15 | args = parser.parse_args() 16 | n, m = args.n_nodes, args.m_nodes 17 | 18 | assert 1 <= m < n, "m must satisfy 1 <= m < n." 19 | 20 | edges = list() 21 | deg = np.zeros(n) 22 | 23 | for i in range(m, n): 24 | if i == m: 25 | for j in range(i): 26 | edges.append((j, i)) 27 | deg[j] += 1 28 | deg[i] += 1 29 | continue 30 | 31 | connection = np.random.choice(range(n), size=m, replace=False, 32 | p=deg/np.sum(deg)) 33 | for cnt in connection: 34 | edges.append((cnt, i)) 35 | deg[cnt] += 1 36 | deg[i] += 1 37 | 38 | edges.sort() 39 | 40 | os.makedirs('generated', exist_ok=True) 41 | with open(os.path.join('generated', args.out_txt), 'w') as f: 42 | f.write(str(n) + '\n') 43 | f.write(str(len(edges)) + '\n') 44 | for edge in edges: 45 | f.write('%d %d\n' % (edge[0], edge[1])) 46 | -------------------------------------------------------------------------------- /model/graphs/er.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import numpy as np 5 | 6 | 7 | def shuffle(n, edges): 8 | mapping = np.random.permutation(range(n)) 9 | shuffled = list() 10 | for edge in edges: 11 | s, e = edge 12 | shuffled.append(sorted((mapping[s], mapping[e]))) 13 | return sorted(shuffled) 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser('Erdos-Renyi graph generator') 17 | parser.add_argument('-n', '--n_nodes', type=int, default=32, 18 | help="number of nodes for random graph") 19 | parser.add_argument('-p', '--prob', type=float, required=True, 20 | help="probability of node connection for ER") 21 | parser.add_argument('-o', '--out_txt', type=str, required=True, 22 | help="name of output txt file") 23 | args = parser.parse_args() 24 | n, p = args.n_nodes, args.prob 25 | 26 | if p < math.log(n) / n: 27 | print("Warning: p is too small for the given n.") 28 | print("This may cause the generated graph to be disconnected.") 29 | 30 | edges = list() 31 | rand = np.random.uniform(0.0, 1.0, size=(n, n)) 32 | 33 | for i in range(n): 34 | for j in range(i+1, n): 35 | if rand[i][j] < p: 36 | edges.append((i, j)) 37 | 38 | edges = shuffle(n, edges) 39 | 40 | os.makedirs('generated', exist_ok=True) 41 | with open(os.path.join('generated', args.out_txt), 'w') as f: 42 | f.write(str(n) + '\n') 43 | f.write(str(len(edges)) + '\n') 44 | for edge in edges: 45 | f.write('%d %d\n' % (edge[0], edge[1])) 46 | -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import argparse 7 | import torchvision.transforms as transforms 8 | from torchvision import datasets 9 | 10 | from model.model import RandWire 11 | from utils.hparams import HParam 12 | from utils.graph_reader import read_graph 13 | from utils.evaluation import validate 14 | from dataset.dataloader import create_dataloader, MNIST_dataloader, CIFAR10_dataloader 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-c', '--config', type=str, required=True, 20 | help="yaml file for configuration") 21 | parser.add_argument('-p', '--checkpoint_path', type=str, required=True, 22 | help="path of checkpoint pt file") 23 | args = parser.parse_args() 24 | 
25 | hp = HParam(args.config) 26 | graphs = [ 27 | read_graph(hp.model.graph0), 28 | read_graph(hp.model.graph1), 29 | read_graph(hp.model.graph2), 30 | ] 31 | print('Loading model from checkpoint...') 32 | model = RandWire(hp, graphs).cuda() 33 | checkpoint = torch.load(args.checkpoint_path) 34 | model.load_state_dict(checkpoint['model']) 35 | step = checkpoint['step'] 36 | 37 | dataset = hp.data.type 38 | switcher = { 39 | 'MNIST': MNIST_dataloader, 40 | 'CIFAR10':CIFAR10_dataloader, 41 | 'ImageNet':create_dataloader, 42 | } 43 | assert dataset in switcher.keys(), 'Dataset type currently not supported' 44 | dl_func = switcher[dataset] 45 | valset = dl_func(hp, args, False) 46 | 47 | print('Validating...') 48 | test_avg_loss, accuracy = validate(model, valset) 49 | 50 | print('Result on step %d:' % step) 51 | print('Average test loss: %.4f' % test_avg_loss) 52 | print('Accuracy: %.3f' % accuracy) 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | chkpt/ 3 | 4 | # config 5 | config/* 6 | !config/default.yaml 7 | !config/test.yaml 8 | 9 | # data 10 | cifar-100-python/ 11 | 12 | # generated graph text files 13 | # model/graphs/generated # specified in sub gitignore 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | -------------------------------------------------------------------------------- /utils/hparams.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/HarryVolek/PyTorch_Speaker_Verification 2 | 3 | import os 4 | import yaml 5 | 6 | 7 | def load_hparam_str(hp_str): 8 | os.makedirs('config', exist_ok=True) 9 | path = os.path.join('config', 'temp-restore.yaml') 10 | with open(path, 'w') as f: 11 | f.write(hp_str) 12 | return HParam(path) 13 | 14 | 15 | def load_hparam(filename): 16 | with open(filename, 'r') as stream: 17 | docs = yaml.load_all(stream, Loader=yaml.Loader) 18 | hparam_dict = dict() 19 | for doc in docs: 20 | for k, v in doc.items(): 21 | hparam_dict[k] = v 22 | return hparam_dict 23 | 24 | 25 | def merge_dict(user, default): 26 | if isinstance(user, dict) and isinstance(default, dict): 27 | for k, v in default.items(): 28 | if k not in user: 29 | user[k] = v 30 | else: 31 | user[k] = merge_dict(user[k], v) 32 | return user 33 | 34 | 35 | class Dotdict(dict): 36 | """ 37 | a dictionary that supports dot notation 38 | as well as dictionary access notation 39 | usage: d = Dotdict() or d = Dotdict({'val1':'first'}) 40 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 41 | get attributes: d.val2 or d['val2'] 42 | """ 43 | __getattr__ = dict.__getitem__ 44 | __setattr__ = dict.__setitem__ 45 | __delattr__ = dict.__delitem__ 46 | 47 | def __init__(self, dct=None): 48 | dct = dict() if not dct else dct 49 | for key, value in dct.items(): 50 | if hasattr(value, 'keys'): 51 | value = Dotdict(value) 52 | self[key] = value 53 | 54 | 55 | class HParam(Dotdict): 56 | 57 | def __init__(self, file): 58 | super(Dotdict, self).__init__() 59 | hp_dict = load_hparam(file) 60 | hp_dotdict = Dotdict(hp_dict) 61 | for k, v in hp_dotdict.items(): 62 | setattr(self, k, v) 63 | 64 | __getattr__ = Dotdict.__getitem__ 65 | __setattr__ = Dotdict.__setitem__ 66 | __delattr__ = Dotdict.__delitem__ 67 | -------------------------------------------------------------------------------- /model/graphs/ws.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import numpy as np 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser('Watts-Strogatz graph generator') 9 | parser.add_argument('-n', '--n_nodes', type=int, default=32, 10 | help="number of nodes for random graph") 11 | parser.add_argument('-k', '--k_neighbors', type=int, required=True, 12 | help="number of neighboring nodes to connect for WS") 13 | parser.add_argument('-p', '--prob', type=float, required=True, 14 | help="probability of rewiring for WS") 15 | parser.add_argument('-o', '--out_txt', type=str, required=True, 16 | help="name of output txt file") 17 | args = parser.parse_args() 18 | n, k, p = args.n_nodes, args.k_neighbors, args.prob 19 | 20 | assert k % 2 == 0, "k must be even." 21 | assert 0 < k < n, "k must be larger than 0 and smaller than n." 22 | 23 | adj = [[False]*n for _ in range(n)] # adjacency matrix 24 | for i in range(n): 25 | adj[i][i] = True 26 | 27 | # initial connection 28 | for i in range(n): 29 | for j in range(i-k//2, i+k//2+1): 30 | real_j = j % n 31 | if real_j == i: 32 | continue 33 | adj[real_j][i] = adj[i][real_j] = True 34 | 35 | rand = np.random.uniform(0.0, 1.0, size=(n, k//2)) 36 | for i in range(n): 37 | for j in range(1, k//2+1): # 'j' here is 'i' of paper's notation 38 | current = (i+j) % n 39 | if rand[i][j-1] < p: # rewire 40 | unoccupied = [x for x in range(n) if not adj[i][x]] 41 | rewired = np.random.choice(unoccupied) 42 | adj[i][current] = adj[current][i] = False 43 | adj[i][rewired] = adj[rewired][i] = True 44 | 45 | edges = list() 46 | for i in range(n): 47 | for j in range(i+1, n): 48 | if adj[i][j]: 49 | edges.append((i, j)) 50 | 51 | edges.sort() 52 | 53 | os.makedirs('generated', exist_ok=True) 54 | with open(os.path.join('generated', args.out_txt), 'w') as f: 55 | f.write(str(n) + '\n') 56 | f.write(str(len(edges)) + '\n') 57 | for edge in edges: 58 | f.write('%d %d\n' % (edge[0], edge[1])) 59 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .node import NodeOp 6 | from .dag_layer import DAGLayer 7 | from .sep_conv import SeparableConv2d 8 | 9 | 10 | class RandWire(nn.Module): 11 | def __init__(self, hp, graphs): 12 | super(RandWire, self).__init__() 13 | self.chn = hp.model.channel 14 | self.cls = hp.model.classes 15 | self.im = hp.model.input_maps 16 | # didn't use nn.Sequential, to make debugging easier 17 | # self.conv1 = SeparableConv2d(1, self.chn//2, kernel_size=3, padding=1, stride=2) 18 | self.conv1 = nn.Conv2d(self.im, self.chn//2, kernel_size=3, padding=1, stride=2) 19 | self.bn1 = nn.BatchNorm2d(self.chn//2) 20 | # self.conv2 = SeparableConv2d(self.chn//2, self.chn, kernel_size=3, padding=1, stride=2) 21 | self.conv2 = nn.Conv2d(self.chn//2, self.chn, kernel_size=3, padding=1, stride=2) 22 | self.bn2 = nn.BatchNorm2d(self.chn) 23 | self.dagly3 = DAGLayer(self.chn, self.chn, graphs[0]['num_nodes'], graphs[0]['edges']) 24 | self.dagly4 = DAGLayer(self.chn, 2*self.chn, graphs[1]['num_nodes'], graphs[1]['edges']) 25 | self.dagly5 = DAGLayer(2*self.chn, 4*self.chn, graphs[2]['num_nodes'], graphs[2]['edges']) 26 | # self.convlast = SeparableConv2d(4*self.chn, 1280, kernel_size=1) 27 | self.convlast = nn.Conv2d(4*self.chn, 1280, kernel_size=1) 28 | self.bnlast = nn.BatchNorm2d(1280) 29 | self.fc = nn.Linear(1280, self.cls) 30 | 31 | def forward(self, y): 32 | # y: [B, im, 224, 224] 33 | # conv1 34 | y = self.conv1(y) # [B, chn//2, 112, 112] 35 | y = self.bn1(y) # [B, chn//2, 112, 112] 36 | 37 | # conv2 38 | y = F.relu(y) # [B, chn//2, 112, 112] 39 | y = 
self.conv2(y) # [B, chn, 56, 56] 40 | y = self.bn2(y) # [B, chn, 56, 56] 41 | 42 | # conv3, conv4, conv5 43 | y = self.dagly3(y) # [B, chn, 28, 28] 44 | y = self.dagly4(y) # [B, 2*chn, 14, 14] 45 | y = self.dagly5(y) # [B, 4*chn, 7, 7] 46 | 47 | # classifier 48 | y = F.relu(y) # [B, 4*chn, 7, 7] 49 | y = self.convlast(y) # [B, 1280, 7, 7] 50 | y = self.bnlast(y) # [B, 1280, 7, 7] 51 | y = F.adaptive_avg_pool2d(y, (1, 1)) # [B, 1280, 1, 1] 52 | y = y.view(y.size(0), -1) # [B, 1280] 53 | y = self.fc(y) # [B, cls] 54 | y = F.log_softmax(y, dim=1) # [B, cls] 55 | return y 56 | -------------------------------------------------------------------------------- /model/dag_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import deque 4 | 5 | from .node import NodeOp 6 | 7 | 8 | class DAGLayer(nn.Module): 9 | def __init__(self, in_channel, out_channel, num_nodes, edges): 10 | super(DAGLayer, self).__init__() 11 | self.num_nodes = num_nodes 12 | self.edges = edges 13 | 14 | self.adjlist = [[] for _ in range(num_nodes)] # adjacency list 15 | self.rev_adjlist = [[] for _ in range(num_nodes)] # reversed adjlist 16 | self.in_degree = [0 for _ in range(num_nodes)] 17 | self.out_degree = [0 for _ in range(num_nodes)] 18 | 19 | for s, e in edges: 20 | self.in_degree[e] += 1 21 | self.out_degree[s] += 1 22 | self.adjlist[s].append(e) 23 | self.rev_adjlist[e].append(s) 24 | 25 | self.input_nodes = [x for x in range(num_nodes) 26 | if self.in_degree[x] == 0] 27 | self.output_nodes = [x for x in range(num_nodes) 28 | if self.out_degree[x] == 0] 29 | assert len(self.input_nodes) > 0, '%d' % len(self.input_nodes) 30 | assert len(self.output_nodes) > 0, '%d' % len(self.output_nodes) 31 | 32 | for node in self.input_nodes: 33 | assert len(self.rev_adjlist[node]) == 0 34 | self.rev_adjlist[node].append(-1) 35 | 36 | self.nodes = nn.ModuleList([ 37 | NodeOp(in_degree=max(1, self.in_degree[x]), 38 | in_channel=in_channel if x in self.input_nodes else out_channel, 39 | out_channel=out_channel, 40 | stride=2 if x in self.input_nodes else 1) 41 | for x in range(num_nodes)]) 42 | 43 | def forward(self, y): 44 | # y: [B, C, N, M] 45 | outputs = [None for _ in range(self.num_nodes)] + [y] 46 | queue = deque(self.input_nodes) 47 | in_degree = self.in_degree.copy() 48 | 49 | while queue: 50 | now = queue.popleft() 51 | input_list = [outputs[x] for x in self.rev_adjlist[now]] 52 | feed = torch.stack(input_list, dim=-1) 53 | outputs[now] = self.nodes[now](feed) 54 | for v in self.adjlist[now]: 55 | in_degree[v] -= 1 56 | if in_degree[v] == 0: 57 | queue.append(v) 58 | 59 | out_list = [outputs[x] for x in self.output_nodes] 60 | return torch.mean(torch.stack(out_list), dim=0) # [B, C, N, M] 61 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import logging 5 | import argparse 6 | 7 | from utils.train import train 8 | from utils.hparams import HParam 9 | from utils.writer import MyWriter 10 | from utils.graph_reader import read_graph 11 | from dataset.dataloader import create_dataloader, MNIST_dataloader, CIFAR10_dataloader 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-c', '--config', type=str, required=True, 17 | help="yaml file for configuration") 18 | parser.add_argument('-p', 
'--checkpoint_path', type=str, default=None, required=False, 19 | help="path of checkpoint pt file") 20 | parser.add_argument('-m', '--model', type=str, required=True, 21 | help="name of the model. used for logging/saving checkpoints") 22 | args = parser.parse_args() 23 | 24 | hp = HParam(args.config) 25 | with open(args.config, 'r') as f: 26 | hp_str = ''.join(f.readlines()) 27 | 28 | pt_path = os.path.join('.', hp.log.chkpt_dir) 29 | out_dir = os.path.join(pt_path, args.model) 30 | os.makedirs(out_dir, exist_ok=True) 31 | 32 | log_dir = os.path.join('.', hp.log.log_dir) 33 | log_dir = os.path.join(log_dir, args.model) 34 | os.makedirs(log_dir, exist_ok=True) 35 | 36 | if args.checkpoint_path is not None: 37 | chkpt_path = args.checkpoint_path 38 | else: 39 | chkpt_path = None 40 | 41 | logging.basicConfig( 42 | level=logging.INFO, 43 | format='%(asctime)s - %(levelname)s - %(message)s', 44 | handlers=[ 45 | logging.FileHandler(os.path.join(log_dir, 46 | '%s-%d.log' % (args.model, time.time()))), 47 | logging.StreamHandler() 48 | ] 49 | ) 50 | logger = logging.getLogger() 51 | 52 | if hp.data.train == '' or hp.data.val == '': 53 | logger.error("hp.data.train, hp.data.val cannot be empty") 54 | raise Exception("Please specify directories of train data.") 55 | 56 | if hp.model.graph0 == '' or hp.model.graph1 == '' or hp.model.graph2 == '': 57 | logger.error("hp.model.graph0, graph1, graph2 cannot be empty") 58 | raise Exception("Please specify random DAG architecture.") 59 | 60 | graphs = [ 61 | read_graph(hp.model.graph0), 62 | read_graph(hp.model.graph1), 63 | read_graph(hp.model.graph2), 64 | ] 65 | 66 | writer = MyWriter(log_dir) 67 | 68 | dataset = hp.data.type 69 | switcher = { 70 | 'MNIST': MNIST_dataloader, 71 | 'CIFAR10':CIFAR10_dataloader, 72 | 'ImageNet':create_dataloader, 73 | } 74 | assert dataset in switcher.keys(), 'Dataset type currently not supported' 75 | dl_func = switcher[dataset] 76 | trainset = dl_func(hp, args, True) 77 | valset = dl_func(hp, args, False) 78 | 79 | train(out_dir, chkpt_path, trainset, valset, writer, logger, hp, hp_str, graphs) 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RandWireNN 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/exploring-randomly-wired-neural-networks-for/image-classification-imagenet-image-reco)](https://paperswithcode.com/sota/image-classification-imagenet-image-reco?p=exploring-randomly-wired-neural-networks-for) 4 | 5 | Unofficial PyTorch Implementation of: 6 | [Exploring Randomly Wired Neural Networks for Image Recognition](https://arxiv.org/abs/1904.01569). 
7 | 8 | ![](./assets/teaser.png) 9 | 10 | ## Results 11 | 12 | Validation results on the ImageNet (ILSVRC2012) dataset: 13 | 14 | | Top 1 accuracy (%) | Paper | Here | 15 | | -------------------------- | ----- | ---- | 16 | | RandWire-WS(4, 0.75), C=78 | 74.7 | 69.2 | 17 | 18 | 19 | - (2019.06.26) 69.2%: 250 epochs with the SGD optimizer, lr 0.1, momentum 0.9, weight decay 5e-5, cosine annealing lr schedule (no label smoothing applied, see the loss curve below) 20 | - (2019.04.14) 62.6%: 396k steps with the SGD optimizer, lr 0.1, momentum 0.9, weight decay 5e-5, lr decay of about 0.1 at 300k 21 | - (2019.04.12) 62.6%: 416k steps with the AdaBound optimizer, initial lr 0.001 (decayed by about 0.1 at 300k), final lr 0.1, no weight decay 22 | - (2019.04) [JiaminRen's implementation](https://github.com/JiaminRen/RandWireNN) reached an accuracy almost matching the paper's, using a training strategy identical to the paper's. 23 | - (2019.04.10) 63.0%: 450k steps with the Adam optimizer, initial lr 0.001, lr decay of about 0.1 every 150k steps 24 | - (2019.04.07) 56.8%: Training took about 16 hours on an AWS p3.2xlarge (NVIDIA V100). 120k steps were done in total, and the Adam optimizer with `lr=0.001, batch_size=128` was used with no learning rate decay. 25 | ![](./assets/fix11.png) 26 | 27 | ## Dependencies 28 | 29 | This code was tested on Python 3.6 with PyTorch 1.0.1. The other packages can be installed with: 30 | ```bash 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ## Generate random DAG 35 | 36 | ```bash 37 | cd model/graphs 38 | python er.py -p 0.2 -o er-02.txt # Erdos-Renyi 39 | python ba.py -m 7 -o ba-7.txt # Barabasi-Albert 40 | python ws.py -k 4 -p 0.75 -o ws-4-075.txt # Watts-Strogatz 41 | # number of nodes: -n option 42 | ``` 43 | 44 | Each of the commands above produces a txt file of the form: 45 | ``` 46 | (number of nodes) 47 | (number of edges) 48 | (one line per edge: two node indices) 49 | ``` 50 | 51 | ## Train RandWireNN 52 | 53 | 1. Download the ImageNet dataset. The train/val folders should each contain 1,000 directories, one per category, each holding the images for that category. For arranging the validation images, this script can be useful: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh 54 | 1. Edit `config.yaml` 55 | ```bash 56 | cd config 57 | cp default.yaml config.yaml 58 | vim config.yaml # specify data directory, graph txt files 59 | ``` 60 | 1. Train 61 | 62 | *Note.* Validation performed during training won't use the entire test set, since that would consume too much time (about 3 min). 63 | ``` 64 | python trainer.py -c [config yaml] -m [name] 65 | ``` 66 | 1. View tensorboardX 67 | ``` 68 | tensorboard --logdir ./logs 69 | ``` 70 | 71 | ## Validation 72 | 73 | Run full validation: 74 | 75 | ```bash 76 | python validation.py -c [config path] -p [checkpoint path] 77 | ``` 78 | 79 | This will show the accuracy and average test loss of the trained model. 
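80 | 81 | ## Inference example 82 | 83 | The repository has no dedicated inference script, so the snippet below is only a minimal sketch of how a trained checkpoint could be used to classify a single image; the config, checkpoint and image paths are placeholders. 84 | 85 | ```python 86 | import torch 87 | from PIL import Image 88 | import torchvision.transforms as transforms 89 | 90 | from model.model import RandWire 91 | from utils.hparams import HParam 92 | from utils.graph_reader import read_graph 93 | 94 | hp = HParam('config/config.yaml') 95 | graphs = [read_graph(hp.model.graph0), read_graph(hp.model.graph1), read_graph(hp.model.graph2)] 96 | model = RandWire(hp, graphs).cuda() 97 | model.load_state_dict(torch.load('chkpt/mymodel/chkpt_249.pt')['model'])  # trainer.py saves the weights under the 'model' key 98 | model.eval() 99 | 100 | # same preprocessing as the validation split in dataset/dataloader.py 101 | preprocess = transforms.Compose([ 102 |     transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), 103 |     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 104 | x = preprocess(Image.open('image.jpg').convert('RGB')).unsqueeze(0).cuda() 105 | with torch.no_grad(): 106 |     pred = model(x).argmax(dim=1)  # the model outputs log-probabilities 107 | print(pred.item()) 108 | ``` 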
109 | 110 | 111 | ## Author 112 | 113 | Seungwon Park / [@seungwonpark](http://swpark.me) 114 | 115 | ## License 116 | 117 | Apache License 2.0 118 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import adabound 8 | import itertools 9 | import traceback 10 | 11 | from utils.hparams import load_hparam_str 12 | from utils.evaluation import validate 13 | from model.model import RandWire 14 | 15 | 16 | def train(out_dir, chkpt_path, trainset, valset, writer, logger, hp, hp_str, graphs): 17 | model = RandWire(hp, graphs).cuda() 18 | 19 | if hp.train.optimizer == 'adam': 20 | optimizer = torch.optim.Adam(model.parameters(), 21 | lr=hp.train.adam) 22 | elif hp.train.optimizer == 'adabound': 23 | optimizer = adabound.AdaBound(model.parameters(), 24 | lr=hp.train.adabound.initial, 25 | final_lr=hp.train.adabound.final) 26 | elif hp.train.optimizer == 'sgd': 27 | optimizer = torch.optim.SGD(model.parameters(), 28 | lr=hp.train.sgd.lr, 29 | momentum=hp.train.sgd.momentum, 30 | weight_decay=hp.train.sgd.weight_decay) 31 | else: 32 | raise Exception("Optimizer not supported: %s" % hp.train.optimizer) 33 | 34 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 35 | optimizer, hp.train.epoch) 36 | 37 | init_epoch = -1 38 | step = 0 39 | 40 | if chkpt_path is not None: 41 | logger.info("Resuming from checkpoint: %s" % chkpt_path) 42 | checkpoint = torch.load(chkpt_path) 43 | model.load_state_dict(checkpoint['model']) 44 | optimizer.load_state_dict(checkpoint['optimizer']) 45 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 46 | step = checkpoint['step'] 47 | init_epoch = checkpoint['epoch'] 48 | 49 | if hp_str != checkpoint['hp_str']: 50 | logger.warning("New hparams are different from checkpoint.") 51 | logger.warning("Will use new hparams.") 52 | # hp = load_hparam_str(hp_str) 53 | else: 54 | logger.info("Starting new training run") 55 | logger.info("Writing graph to tensorboardX...") 56 | writer.write_graph(model, torch.randn(7, hp.model.input_maps, 224, 224).cuda()) 57 | logger.info("Finished.") 58 | 59 | try: 60 | model.train() 61 | for epoch in itertools.count(init_epoch+1): 62 | loader = tqdm.tqdm(trainset, desc='Train data loader') 63 | for data, target in loader: 64 | data, target = data.cuda(), target.cuda() 65 | optimizer.zero_grad() 66 | output = model(data) 67 | loss = F.nll_loss(output, target) 68 | loss.backward() 69 | optimizer.step() 70 | 71 | loss = loss.item() 72 | if loss > 1e8 or math.isnan(loss): 73 | logger.error("Loss exploded to %.02f at step %d!" 
% (loss, step)) 74 | raise Exception("Loss exploded") 75 | 76 | writer.log_training(loss, step) 77 | loader.set_description('Loss %.02f at step %d' % (loss, step)) 78 | step += 1 79 | 80 | save_path = os.path.join(out_dir, 'chkpt_%03d.pt' % epoch) 81 | torch.save({ 82 | 'model': model.state_dict(), 83 | 'optimizer': optimizer.state_dict(), 84 | 'lr_scheduler': lr_scheduler.state_dict(), 85 | 'step': step, 86 | 'epoch': epoch, 87 | 'hp_str': hp_str, 88 | }, save_path) 89 | logger.info("Saved checkpoint to: %s" % save_path) 90 | 91 | validate(model, valset, writer, epoch) 92 | lr_scheduler.step() 93 | 94 | except Exception as e: 95 | logger.info("Exiting due to exception: %s" % e) 96 | traceback.print_exc() 97 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | # I followed ImageNet data loading convention shown in: 2 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 3 | 4 | import os 5 | import pickle 6 | import numpy as np 7 | import torch 8 | import torchvision.transforms as transforms 9 | from torchvision import datasets 10 | 11 | 12 | def create_dataloader(hp, args, train): 13 | normalize = transforms.Normalize( 14 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 15 | 16 | if train: 17 | return torch.utils.data.DataLoader( 18 | datasets.ImageFolder(hp.data.train, transforms.Compose([ 19 | transforms.RandomResizedCrop(224), 20 | transforms.RandomHorizontalFlip(), 21 | transforms.ToTensor(), 22 | normalize, 23 | ])), 24 | batch_size=hp.data.batch_size, 25 | num_workers=hp.data.num_workers, 26 | shuffle=True, pin_memory=True, drop_last=True) 27 | else: 28 | return torch.utils.data.DataLoader( 29 | datasets.ImageFolder(hp.data.val, transforms.Compose([ 30 | transforms.Resize(256), 31 | transforms.CenterCrop(224), 32 | transforms.ToTensor(), 33 | normalize, 34 | ])), 35 | batch_size=hp.data.batch_size, 36 | num_workers=hp.data.num_workers, 37 | shuffle=False, pin_memory=True, drop_last=False) 38 | 39 | # MNIST data loading 40 | 41 | def MNIST_dataloader(hp, args, train): 42 | ''' 43 | :hp: HParam 44 | hyperparameters; a truthy hp.data.train (e.g. 'download') also acts as the torchvision download flag 45 | :train: bool 46 | whether to return the train or the test dataloader 47 | ''' 48 | root = './data' 49 | if not os.path.exists(root): 50 | os.mkdir(root) 51 | 52 | transf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) 53 | if train: 54 | train_set = datasets.MNIST(root=root, train=True, transform=transf, download=hp.data.train) 55 | trainloader = torch.utils.data.DataLoader( 56 | dataset=train_set, 57 | batch_size=hp.data.batch_size, 58 | shuffle=True) 59 | return trainloader 60 | else: 61 | test_set = datasets.MNIST(root=root, train=False, transform=transf, download=True) 62 | testloader = torch.utils.data.DataLoader( 63 | dataset=test_set, 64 | batch_size=129, 65 | shuffle=False) 66 | return testloader 67 | 68 | # CIFAR10 data loading 69 | 70 | def CIFAR10_dataloader(hp, args, train, path='.'): 71 | ''' 72 | :train: bool 73 | whether to return the train or the test dataloader 74 | :path: str 75 | path to raw CIFAR10, as found in https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 76 | ''' 77 | ds = [] 78 | dlabels = [] 79 | for i in range(1, 6): 80 | batch = unpickle(os.path.join(path, 'data_batch_%d' % i)) 81 | for j in range(10000): 82 | ds.append(np.reshape(batch[b'data'][j], (3, 32, 32))) 83 | dlabels.append(batch[b'labels'][j]) 84 | 85 | test_batch = unpickle(os.path.join(path, 'test_batch')) 86 | test_ds = [] 87 | test_dlabels = [] 88 | for j in range(10000): 89 | test_ds.append(np.reshape(test_batch[b'data'][j], (3, 32, 32))) 90 | test_dlabels.append(test_batch[b'labels'][j]) 91 | 92 | # note: pixel values are kept as raw 0-255 floats here, without normalization 93 | train_set = torch.utils.data.TensorDataset(torch.Tensor(ds), torch.LongTensor(dlabels)) 94 | test_set = torch.utils.data.TensorDataset(torch.Tensor(test_ds), torch.LongTensor(test_dlabels)) 95 | 96 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=hp.data.batch_size) 97 | testloader = torch.utils.data.DataLoader(test_set, batch_size=hp.data.batch_size) 98 | 99 | if train: 100 | return trainloader 101 | else: 102 | return testloader 103 | 104 | def unpickle(file): 105 | with open(file, 'rb') as fo: 106 | data = pickle.load(fo, encoding='bytes') 107 | return data 108 | --------------------------------------------------------------------------------
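A quick way to see how the pieces above fit together (this snippet is an illustrative sketch, not a repository file): read one of the bundled graph files with utils/graph_reader.py and push a dummy batch through a single DAGLayer. It assumes it is run from the repository root, since read_graph resolves paths relative to model/graphs/generated. ```python import torch  from utils.graph_reader import read_graph from model.dag_layer import DAGLayer  graph = read_graph('ws-4-075-3.txt')  # {'num_nodes': 32, 'edges': [(0, 1), (0, 2), ...]} layer = DAGLayer(in_channel=8, out_channel=16,                  num_nodes=graph['num_nodes'], edges=graph['edges'])  x = torch.randn(2, 8, 32, 32)  # [B, C, N, M] y = layer(x)                   # input nodes apply stride 2, halving the spatial size print(y.shape)                 # torch.Size([2, 16, 16, 16]) ```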