├── .gitignore
├── README.md
├── archs
│   └── examples
│       └── sbms_cluster.yaml
├── data
│   ├── COLLAB.py
│   ├── QM9.py
│   ├── SBMs.py
│   ├── SBMs
│   │   ├── generate_SBM_CLUSTER.ipynb
│   │   └── generate_SBM_PATTERN.ipynb
│   ├── TSP.py
│   ├── TSP
│   │   ├── generate_TSP.py
│   │   └── prepare_TSP.ipynb
│   ├── __init__.py
│   ├── cora.py
│   ├── molecules.py
│   ├── molecules
│   │   └── prepare_molecules.ipynb
│   ├── superpixels.py
│   └── superpixels
│       ├── prepare_superpixels_CIFAR.ipynb
│       └── prepare_superpixels_MNIST.ipynb
├── environment_gpu.yml
├── example_geno.yaml
├── models
│   ├── architect.py
│   ├── cell_search.py
│   ├── cell_train.py
│   ├── mixed.py
│   ├── model_search.py
│   ├── model_train.py
│   ├── networks.py
│   └── operations.py
├── scripts
│   ├── search_molecules_zinc.sh
│   ├── search_sbms_cluster.sh
│   ├── search_sbms_pattern.sh
│   ├── search_superpixels_cifar10.sh
│   ├── search_superpixels_mnist.sh
│   ├── train_molecules_zinc.sh
│   ├── train_sbms_cluster.sh
│   ├── train_sbms_pattern.sh
│   ├── train_superpixels_cifar10.sh
│   └── train_superpixels_mnist.sh
├── search.py
├── train.py
└── utils
    ├── record_utils.py
    └── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | /dataset
132 | .vscode
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 | # pytype static type analyzer
138 | .pytype/
139 |
140 | # Cython debug symbols
141 | cython_debug/
142 |
143 | debug.ipynb
144 | *.png
145 | *.txt
146 | *.gz
147 | *.pt
148 | *.pkl
149 | *.json
150 | *.pickle
151 | *.index
152 | *.zip
153 | *.pdf
154 | pics
155 | runs
156 | archs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## 🔥News🔥
3 | Hi, if you like this work, you may also be interested in our newer work.
4 |
5 | Welcome to our latest work, [Automatic Relation-aware Graph Network Proliferation](https://github.com/phython96/ARGNP).
6 |
7 | **This work has been accepted by CVPR 2022 and selected for an ORAL presentation.**
8 |
9 | In the latest work, we achieve state-of-the-art results on the SBM_CLUSTER (77.35% OA), ZINC_100k (0.136 MAE), CIFAR10 (73.90% OA), and TSP (0.855 F1-score) datasets, among others.
10 |
11 | ### What's new?
12 |
13 | 1. **We devise a novel dual relation-aware graph search space that comprises both node and relation learning operations.**
14 | As a result, ARGNP can leverage the edge attributes present in some datasets, such as ZINC.
15 | This significantly improves the graph representational capability.
16 | Interestingly, we also observe performance improvements even on datasets with no available edge attributes.
17 |
18 | 2. **We design a network proliferation search paradigm (NPSP) to progressively determine the GNN architectures by iteratively performing network division and differentiation.**
19 | The network proliferation search paradigm decomposes the training of the global supernet into the sequential optimization of local supernets, which alleviates the interference among child graph neural architectures. It reduces the space-time complexity from quadratic to linear and frees the search entirely from the cell-sharing trick.
20 |
21 | 3. **Our framework is suitable for solving node-level, edge-level, and graph-level tasks. The code is easy to use.**
22 |
23 | ---
24 |
25 | # Rethinking Graph Neural Architecture Search from Message-passing
26 |
27 |
28 |
29 |
30 |
31 | ## Getting Started
32 |
33 | ### 0. Prerequisites
34 |
35 | + Linux
36 | + NVIDIA GPU + CUDA cuDNN
37 |
38 | ### 1. Setup Python Environment
39 |
40 | ```sh
41 | # Clone the GitHub repo
42 | conda install git
43 | git clone https://github.com/phython96/GNAS-MP.git
44 | cd GNAS-MP
45 |
46 | # Install the Python environment
47 | conda env create -f environment_gpu.yml
48 | conda activate gnasmp
49 | ```
50 |
51 | ### 2. Download datasets
52 |
53 | The datasets are provided by the [benchmarking-gnns](https://github.com/graphdeeplearning/benchmarking-gnns) project; click [here](https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/docs/02_download_datasets.md) to download all the required datasets.
54 |
55 | ### 3. 
Search Architectures 56 | 57 | ```sh 58 | sh scripts/search_molecules_zinc.sh [gpu_id] 59 | ``` 60 | 61 | ### 4. Train & Test 62 | 63 | ``` 64 | sh scripts/train_molecules_zinc.sh [gpu_id] '[path_to_genotypes]/example.yaml' 65 | ``` 66 | 67 | ## Reference 68 | ```latex 69 | @inproceedings{cai2021rethinking, 70 | title={Rethinking Graph Neural Architecture Search from Message-passing}, 71 | author={Cai, Shaofei and Li, Liang and Deng, Jincan and Zhang, Beichen and Zha, Zheng-Jun and Su, Li and Huang, Qingming}, 72 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 73 | pages={6657--6666}, 74 | year={2021} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /archs/examples/sbms_cluster.yaml: -------------------------------------------------------------------------------- 1 | Genotype: 2 | - id: 0 3 | topology: 4 | - dst: 1 5 | ops: V_Dense 6 | src: 0 7 | - dst: 2 8 | ops: V_Sparse 9 | src: 1 10 | - dst: 3 11 | ops: V_Sum 12 | src: 1 13 | - dst: 4 14 | ops: V_Sum 15 | src: 2 16 | - dst: 5 17 | ops: V_Dense 18 | src: 3 19 | - dst: 6 20 | ops: V_Sparse 21 | src: 3 22 | - id: 1 23 | topology: 24 | - dst: 1 25 | ops: V_Sparse 26 | src: 0 27 | - dst: 2 28 | ops: V_I 29 | src: 0 30 | - dst: 3 31 | ops: V_Sum 32 | src: 1 33 | - dst: 4 34 | ops: V_Sum 35 | src: 2 36 | - dst: 5 37 | ops: V_Sparse 38 | src: 3 39 | - dst: 6 40 | ops: V_Dense 41 | src: 4 42 | - id: 2 43 | topology: 44 | - dst: 1 45 | ops: V_Sparse 46 | src: 0 47 | - dst: 2 48 | ops: V_I 49 | src: 0 50 | - dst: 3 51 | ops: V_Sum 52 | src: 1 53 | - dst: 4 54 | ops: V_Sum 55 | src: 2 56 | - dst: 5 57 | ops: V_Dense 58 | src: 4 59 | - dst: 6 60 | ops: V_Sparse 61 | src: 3 62 | - id: 3 63 | topology: 64 | - dst: 1 65 | ops: V_Dense 66 | src: 0 67 | - dst: 2 68 | ops: V_Dense 69 | src: 0 70 | - dst: 3 71 | ops: V_Sum 72 | src: 1 73 | - dst: 4 74 | ops: V_Sum 75 | src: 2 76 | - dst: 5 77 | ops: V_Sparse 78 | src: 3 79 | - dst: 6 80 | ops: V_Dense 81 | src: 4 82 | -------------------------------------------------------------------------------- /data/COLLAB.py: -------------------------------------------------------------------------------- 1 | import time 2 | import dgl 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from ogb.linkproppred import DglLinkPropPredDataset, Evaluator 7 | 8 | from scipy import sparse as sp 9 | import numpy as np 10 | 11 | 12 | def positional_encoding(g, pos_enc_dim): 13 | """ 14 | Graph positional encoding v/ Laplacian eigenvectors 15 | """ 16 | 17 | # Laplacian 18 | A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) 19 | N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) 20 | L = sp.eye(g.number_of_nodes()) - N * A * N 21 | 22 | # # Eigenvectors with numpy 23 | # EigVal, EigVec = np.linalg.eig(L.toarray()) 24 | # idx = EigVal.argsort() # increasing order 25 | # EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) 26 | # g.ndata['pos_enc'] = torch.from_numpy(np.abs(EigVec[:,1:pos_enc_dim+1])).float() 27 | 28 | # Eigenvectors with scipy 29 | #EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR') 30 | EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR', tol=1e-2) 31 | EigVec = EigVec[:, EigVal.argsort()] # increasing order 32 | g.ndata['pos_enc'] = torch.from_numpy(np.real(EigVec[:,1:pos_enc_dim+1])).float() 33 | 34 | return g 35 | 36 | 37 | class COLLABDataset(Dataset): 38 | def __init__(self, name): 39 | start = time.time() 40 | print("[I] 
Loading dataset %s..." % (name)) 41 | self.name = name 42 | self.dataset = DglLinkPropPredDataset(name='ogbl-collab') 43 | 44 | self.graph = self.dataset[0] # single DGL graph 45 | 46 | # Create edge feat by concatenating weight and year 47 | self.graph.edata['feat'] = torch.cat( 48 | [self.graph.edata['weight'], self.graph.edata['year']], 49 | dim=1 50 | ) 51 | 52 | self.split_edge = self.dataset.get_edge_split() 53 | self.train_edges = self.split_edge['train']['edge'] # positive train edges 54 | self.val_edges = self.split_edge['valid']['edge'] # positive val edges 55 | self.val_edges_neg = self.split_edge['valid']['edge_neg'] # negative val edges 56 | self.test_edges = self.split_edge['test']['edge'] # positive test edges 57 | self.test_edges_neg = self.split_edge['test']['edge_neg'] # negative test edges 58 | 59 | self.evaluator = Evaluator(name='ogbl-collab') 60 | 61 | print("[I] Finished loading.") 62 | print("[I] Data load time: {:.4f}s".format(time.time()-start)) 63 | 64 | def _add_positional_encodings(self, pos_enc_dim): 65 | 66 | # Graph positional encoding v/ Laplacian eigenvectors 67 | self.graph = positional_encoding(self.graph, pos_enc_dim) -------------------------------------------------------------------------------- /data/QM9.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import dgl.data 4 | 5 | class QM9DGL(torch.utils.data.Dataset): 6 | 7 | def __init__(self, graph_list): 8 | self.graph_list = graph_list 9 | 10 | def __getitem__(self, i): 11 | return self.graph_list[i] 12 | 13 | def __len__(self): 14 | return len(self.graph_list) 15 | 16 | class QM9Dataset(torch.utils.data.Dataset): 17 | 18 | def __init__(self, name, target): 19 | """ 20 | Loading QM9 Dataset 21 | """ 22 | start = time.time() 23 | print("[I] Loading dataset %s..." % (name)) 24 | self.name = name 25 | self.target = target 26 | self.data = dgl.data.QM9EdgeDataset([target]) 27 | graph_list = [] 28 | for i in range(len(self.data)): 29 | graph, label = self.data.__getitem__(i) 30 | #graph.ndata['feat'] = torch.cat([graph.ndata['pos'], graph.ndata['attr']], dim = -1) 31 | graph.ndata['feat'] = graph.ndata['attr'] 32 | graph.edata['feat'] = graph.edata['edge_attr'] 33 | graph_list.append((graph, label)) 34 | 35 | self.train = QM9DGL(graph_list[:110000]) 36 | self.val = QM9DGL(graph_list[110000:120000]) 37 | self.test = QM9DGL(graph_list[120000:]) 38 | 39 | print('train, test, val sizes :',len(self.train),len(self.test),len(self.val)) 40 | print("[I] Finished loading.") 41 | print("[I] Data load time: {:.4f}s".format(time.time()-start)) 42 | 43 | 44 | def collate(self, samples): 45 | # The input samples is a list of pairs (graph, label). 
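        # dgl.batch merges the sampled graphs into one large disconnected graph;
        # the per-graph label tensors are then concatenated along dim 0 below.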
46 | graphs, labels = map(list, zip(*samples)) 47 | batched_graph = dgl.batch(graphs) 48 | labels = torch.cat(labels, dim = 0) 49 | return batched_graph, labels 50 | 51 | 52 | if __name__ == '__main__': 53 | dataset = QM9Dataset('QM9', 'mu') 54 | import ipdb; ipdb.set_trace() -------------------------------------------------------------------------------- /data/SBMs.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import pickle 4 | import numpy as np 5 | 6 | import dgl 7 | import torch 8 | 9 | from scipy import sparse as sp 10 | import numpy as np 11 | 12 | 13 | 14 | class load_SBMsDataSetDGL(torch.utils.data.Dataset): 15 | 16 | def __init__(self, 17 | data_dir, 18 | name, 19 | split): 20 | 21 | self.split = split 22 | self.is_test = split.lower() in ['test', 'val'] 23 | with open(os.path.join(data_dir, name + '_%s.pkl' % self.split), 'rb') as f: 24 | self.dataset = pickle.load(f) 25 | self.node_labels = [] 26 | self.graph_lists = [] 27 | self.n_samples = len(self.dataset) 28 | self._prepare() 29 | 30 | 31 | def _prepare(self): 32 | 33 | print("preparing %d graphs for the %s set..." % (self.n_samples, self.split.upper())) 34 | 35 | for data in self.dataset: 36 | 37 | node_features = data.node_feat 38 | edge_list = (data.W != 0).nonzero() # converting adj matrix to edge_list 39 | 40 | # Create the DGL Graph 41 | g = dgl.DGLGraph() 42 | g.add_nodes(node_features.size(0)) 43 | g.ndata['feat'] = node_features.long() 44 | for src, dst in edge_list: 45 | g.add_edges(src.item(), dst.item()) 46 | 47 | # adding edge features for Residual Gated ConvNet 48 | #edge_feat_dim = g.ndata['feat'].size(1) # dim same as node feature dim 49 | edge_feat_dim = 1 # dim same as node feature dim 50 | g.edata['feat'] = torch.ones(g.number_of_edges(), edge_feat_dim) 51 | 52 | self.graph_lists.append(g) 53 | self.node_labels.append(data.node_label) 54 | 55 | 56 | def __len__(self): 57 | """Return the number of graphs in the dataset.""" 58 | return self.n_samples 59 | 60 | def __getitem__(self, idx): 61 | """ 62 | Get the idx^th sample. 63 | Parameters 64 | --------- 65 | idx : int 66 | The sample index. 67 | Returns 68 | ------- 69 | (dgl.DGLGraph, int) 70 | DGLGraph with node feature stored in `feat` field 71 | And its label. 72 | """ 73 | return self.graph_lists[idx], self.node_labels[idx] 74 | 75 | 76 | class SBMsDatasetDGL(torch.utils.data.Dataset): 77 | 78 | def __init__(self, name): 79 | """ 80 | TODO 81 | """ 82 | start = time.time() 83 | print("[I] Loading data ...") 84 | self.name = name 85 | data_dir = 'data/SBMs' 86 | self.train = load_SBMsDataSetDGL(data_dir, name, split='train') 87 | self.test = load_SBMsDataSetDGL(data_dir, name, split='test') 88 | self.val = load_SBMsDataSetDGL(data_dir, name, split='val') 89 | print("[I] Finished loading.") 90 | print("[I] Data load time: {:.4f}s".format(time.time()-start)) 91 | 92 | 93 | 94 | 95 | def self_loop(g): 96 | """ 97 | Utility function only, to be used only when necessary as per user self_loop flag 98 | : Overwriting the function dgl.transform.add_self_loop() to not miss ndata['feat'] and edata['feat'] 99 | 100 | 101 | This function is called inside a function in SBMsDataset class. 
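    Returns a new DGLGraph with all original non-self edges preserved and exactly
    one self-loop added per node; edata['feat'] is re-initialized to zeros.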
102 | """ 103 | new_g = dgl.DGLGraph() 104 | new_g.add_nodes(g.number_of_nodes()) 105 | new_g.ndata['feat'] = g.ndata['feat'] 106 | 107 | src, dst = g.all_edges(order="eid") 108 | src = dgl.backend.zerocopy_to_numpy(src) 109 | dst = dgl.backend.zerocopy_to_numpy(dst) 110 | non_self_edges_idx = src != dst 111 | nodes = np.arange(g.number_of_nodes()) 112 | new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx]) 113 | new_g.add_edges(nodes, nodes) 114 | 115 | # This new edata is not used since this function gets called only for GCN, GAT 116 | # However, we need this for the generic requirement of ndata and edata 117 | new_g.edata['feat'] = torch.zeros(new_g.number_of_edges()) 118 | return new_g 119 | 120 | 121 | 122 | def positional_encoding(g, pos_enc_dim): 123 | """ 124 | Graph positional encoding v/ Laplacian eigenvectors 125 | """ 126 | 127 | # Laplacian 128 | A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) 129 | N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) 130 | L = sp.eye(g.number_of_nodes()) - N * A * N 131 | 132 | # # Eigenvectors with numpy 133 | # EigVal, EigVec = np.linalg.eig(L.toarray()) 134 | # idx = EigVal.argsort() # increasing order 135 | # EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) 136 | # g.ndata['pos_enc'] = torch.from_numpy(np.abs(EigVec[:,1:pos_enc_dim+1])).float() 137 | 138 | # Eigenvectors with scipy 139 | #EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR') 140 | EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR', tol=1e-2) # for 40 PEs 141 | EigVec = EigVec[:, EigVal.argsort()] # increasing order 142 | g.ndata['pos_enc'] = torch.from_numpy(np.real(EigVec[:,1:pos_enc_dim+1])).float() 143 | 144 | return g 145 | 146 | 147 | 148 | class SBMsDataset(torch.utils.data.Dataset): 149 | 150 | def __init__(self, name): 151 | """ 152 | Loading SBM datasets 153 | """ 154 | start = time.time() 155 | print("[I] Loading dataset %s..." % (name)) 156 | self.name = name 157 | data_dir = 'data/SBMs/' 158 | with open(data_dir+name+'.pkl',"rb") as f: 159 | f = pickle.load(f) 160 | self.train = f[0] 161 | self.val = f[1] 162 | self.test = f[2] 163 | print('train, test, val sizes :',len(self.train),len(self.test),len(self.val)) 164 | print("[I] Finished loading.") 165 | print("[I] Data load time: {:.4f}s".format(time.time()-start)) 166 | 167 | 168 | # form a mini batch from a given list of samples = [(graph, label) pairs] 169 | def collate(self, samples): 170 | # The input samples is a list of pairs (graph, label). 171 | graphs, labels = map(list, zip(*samples)) 172 | labels = torch.cat(labels).long() 173 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 174 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 175 | #snorm_n = torch.cat(tab_snorm_n).sqrt() 176 | #tab_sizes_e = [ graphs[i].number_of_edges() for i in range(len(graphs))] 177 | #tab_snorm_e = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_e ] 178 | #snorm_e = torch.cat(tab_snorm_e).sqrt() 179 | batched_graph = dgl.batch(graphs) 180 | 181 | return batched_graph, labels 182 | 183 | # prepare dense tensors for GNNs which use; such as RingGNN and 3WLGNN 184 | def collate_dense_gnn(self, samples): 185 | # The input samples is a list of pairs (graph, label). 
186 | graphs, labels = map(list, zip(*samples)) 187 | labels = torch.cat(labels).long() 188 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 189 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 190 | #snorm_n = tab_snorm_n[0][0].sqrt() 191 | 192 | #batched_graph = dgl.batch(graphs) 193 | 194 | g = graphs[0] 195 | adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense()) 196 | """ 197 | Adapted from https://github.com/leichen2018/Ring-GNN/ 198 | Assigning node and edge feats:: 199 | we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features R^{d_e}. 200 | Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the adjacency matrix. 201 | The diagonal T[1:1+d_n, i, i], i = 0 to n-1, store the node feature of node i. 202 | The off diagonal T[1+d_n:, i, j] store edge features of edge(i, j). 203 | """ 204 | 205 | zero_adj = torch.zeros_like(adj) 206 | 207 | if self.name == 'SBM_CLUSTER': 208 | self.num_node_type = 7 209 | elif self.name == 'SBM_PATTERN': 210 | self.num_node_type = 3 211 | 212 | # use node feats to prepare adj 213 | adj_node_feat = torch.stack([zero_adj for j in range(self.num_node_type)]) 214 | adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0) 215 | 216 | for node, node_label in enumerate(g.ndata['feat']): 217 | adj_node_feat[node_label.item()+1][node][node] = 1 218 | 219 | x_node_feat = adj_node_feat.unsqueeze(0) 220 | 221 | return x_node_feat, labels 222 | 223 | def _sym_normalize_adj(self, adj): 224 | deg = torch.sum(adj, dim = 0)#.squeeze() 225 | deg_inv = torch.where(deg>0, 1./torch.sqrt(deg), torch.zeros(deg.size())) 226 | deg_inv = torch.diag(deg_inv) 227 | return torch.mm(deg_inv, torch.mm(adj, deg_inv)) 228 | 229 | 230 | def _add_self_loops(self): 231 | 232 | # function for adding self loops 233 | # this function will be called only if self_loop flag is True 234 | 235 | self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists] 236 | self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists] 237 | self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists] 238 | 239 | 240 | def _add_positional_encodings(self, pos_enc_dim): 241 | 242 | # Graph positional encoding v/ Laplacian eigenvectors 243 | self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists] 244 | self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists] 245 | self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists] -------------------------------------------------------------------------------- /data/TSP.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import numpy as np 4 | import itertools 5 | from scipy.spatial.distance import pdist, squareform 6 | 7 | import dgl 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | from scipy import sparse as sp 12 | 13 | class TSP(Dataset): 14 | def __init__(self, data_dir, split="train", num_neighbors=25, max_samples=10000): 15 | self.data_dir = data_dir 16 | self.split = split 17 | self.filename = f'{data_dir}/tsp50-500_{split}.txt' 18 | self.max_samples = max_samples 19 | self.num_neighbors = num_neighbors 20 | self.is_test = split.lower() in ['test', 'val'] 21 | 22 | self.graph_lists = [] 23 | self.edge_labels = [] 24 | self._prepare() 25 | self.n_samples = len(self.edge_labels) 26 | 
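    # Input file format (one TSP instance per line, as written by
    # data/TSP/generate_TSP.py):
    #   "x1 y1 x2 y2 ... xN yN output t1 t2 ... tN t1"
    # i.e. node coordinates, the literal token 'output', then the optimal tour
    # (1-indexed) with the starting node repeated at the end.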
27 | def _prepare(self): 28 | print('preparing all graphs for the %s set...' % self.split.upper()) 29 | 30 | file_data = open(self.filename, "r").readlines()[:self.max_samples] 31 | 32 | for graph_idx, line in enumerate(file_data): 33 | line = line.split(" ") # Split into list 34 | num_nodes = int(line.index('output')//2) 35 | 36 | # Convert node coordinates to required format 37 | nodes_coord = [] 38 | for idx in range(0, 2 * num_nodes, 2): 39 | nodes_coord.append([float(line[idx]), float(line[idx + 1])]) 40 | 41 | # Compute distance matrix 42 | W_val = squareform(pdist(nodes_coord, metric='euclidean')) 43 | # Determine k-nearest neighbors for each node 44 | knns = np.argpartition(W_val, kth=self.num_neighbors, axis=-1)[:, self.num_neighbors::-1] 45 | 46 | # Convert tour nodes to required format 47 | # Don't add final connection for tour/cycle 48 | tour_nodes = [int(node) - 1 for node in line[line.index('output') + 1:-1]][:-1] 49 | 50 | # Compute an edge adjacency matrix representation of tour 51 | edges_target = np.zeros((num_nodes, num_nodes)) 52 | for idx in range(len(tour_nodes) - 1): 53 | i = tour_nodes[idx] 54 | j = tour_nodes[idx + 1] 55 | edges_target[i][j] = 1 56 | edges_target[j][i] = 1 57 | # Add final connection of tour in edge target 58 | edges_target[j][tour_nodes[0]] = 1 59 | edges_target[tour_nodes[0]][j] = 1 60 | 61 | # Construct the DGL graph 62 | g = dgl.DGLGraph() 63 | g.add_nodes(num_nodes) 64 | g.ndata['feat'] = torch.Tensor(nodes_coord) 65 | 66 | edge_feats = [] # edge features i.e. euclidean distances between nodes 67 | edge_labels = [] # edges_targets as a list 68 | # Important!: order of edge_labels must be the same as the order of edges in DGLGraph g 69 | # We ensure this by adding them together 70 | for idx in range(num_nodes): 71 | for n_idx in knns[idx]: 72 | if n_idx != idx: # No self-connection 73 | g.add_edge(idx, n_idx) 74 | edge_feats.append(W_val[idx][n_idx]) 75 | edge_labels.append(int(edges_target[idx][n_idx])) 76 | # dgl.transform.remove_self_loop(g) 77 | 78 | # Sanity check 79 | assert len(edge_feats) == g.number_of_edges() == len(edge_labels) 80 | 81 | # Add edge features 82 | g.edata['feat'] = torch.Tensor(edge_feats).unsqueeze(-1) 83 | 84 | # # Uncomment to add dummy edge features instead (for Residual Gated ConvNet) 85 | # edge_feat_dim = g.ndata['feat'].shape[1] # dim same as node feature dim 86 | # g.edata['feat'] = torch.ones(g.number_of_edges(), edge_feat_dim) 87 | 88 | self.graph_lists.append(g) 89 | self.edge_labels.append(edge_labels) 90 | 91 | def __len__(self): 92 | """Return the number of graphs in the dataset.""" 93 | return self.n_samples 94 | 95 | def __getitem__(self, idx): 96 | """ 97 | Get the idx^th sample. 98 | Parameters 99 | --------- 100 | idx : int 101 | The sample index. 102 | Returns 103 | ------- 104 | (dgl.DGLGraph, list) 105 | DGLGraph with node feature stored in `feat` field 106 | And a list of labels for each edge in the DGLGraph. 
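            The edge labels follow the same ordering as the edges in the DGLGraph (see _prepare).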
107 | """ 108 | return self.graph_lists[idx], self.edge_labels[idx] 109 | 110 | 111 | class TSPDatasetDGL(Dataset): 112 | def __init__(self, name): 113 | self.name = name 114 | self.train = TSP(data_dir='./data/TSP', split='train', num_neighbors=25, max_samples=10000) 115 | self.val = TSP(data_dir='./data/TSP', split='val', num_neighbors=25, max_samples=1000) 116 | self.test = TSP(data_dir='./data/TSP', split='test', num_neighbors=25, max_samples=1000) 117 | 118 | 119 | def positional_encoding(g, pos_enc_dim): 120 | """ 121 | Graph positional encoding v/ Laplacian eigenvectors 122 | """ 123 | 124 | # Laplacian 125 | A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) 126 | N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) 127 | L = sp.eye(g.number_of_nodes()) - N * A * N 128 | 129 | # Eigenvectors with numpy 130 | EigVal, EigVec = np.linalg.eig(L.toarray()) 131 | idx = EigVal.argsort() # increasing order 132 | EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) 133 | g.ndata['pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float() 134 | 135 | # # Eigenvectors with scipy 136 | # EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR') 137 | # EigVec = EigVec[:, EigVal.argsort()] # increasing order 138 | # g.ndata['pos_enc'] = torch.from_numpy(np.abs(EigVec[:,1:pos_enc_dim+1])).float() 139 | 140 | return g 141 | 142 | 143 | class TSPDataset(Dataset): 144 | def __init__(self, name): 145 | start = time.time() 146 | print("[I] Loading dataset %s..." % (name)) 147 | self.name = name 148 | data_dir = 'data/TSP/' 149 | with open(data_dir+name+'.pkl',"rb") as f: 150 | f = pickle.load(f) 151 | self.train = f[0] 152 | self.test = f[1] 153 | self.val = f[2] 154 | print('train, test, val sizes :',len(self.train),len(self.test),len(self.val)) 155 | print("[I] Finished loading.") 156 | print("[I] Data load time: {:.4f}s".format(time.time()-start)) 157 | 158 | # form a mini batch from a given list of samples = [(graph, label) pairs] 159 | def collate(self, samples): 160 | # The input samples is a list of pairs (graph, label). 161 | graphs, labels = map(list, zip(*samples)) 162 | # Edge classification labels need to be flattened to 1D lists 163 | labels = torch.LongTensor(np.array(list(itertools.chain(*labels)))) 164 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 165 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 166 | #snorm_n = torch.cat(tab_snorm_n).sqrt() 167 | #tab_sizes_e = [ graphs[i].number_of_edges() for i in range(len(graphs))] 168 | #tab_snorm_e = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_e ] 169 | #snorm_e = torch.cat(tab_snorm_e).sqrt() 170 | batched_graph = dgl.batch(graphs) 171 | 172 | return batched_graph, labels 173 | 174 | 175 | # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN 176 | def collate_dense_gnn(self, samples, edge_feat): 177 | # The input samples is a list of pairs (graph, label). 
178 | graphs, labels = map(list, zip(*samples)) 179 | # Edge classification labels need to be flattened to 1D lists 180 | labels = torch.LongTensor(np.array(list(itertools.chain(*labels)))) 181 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 182 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 183 | #snorm_n = tab_snorm_n[0][0].sqrt() 184 | 185 | #batched_graph = dgl.batch(graphs) 186 | 187 | g = graphs[0] 188 | adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense()) 189 | """ 190 | Adapted from https://github.com/leichen2018/Ring-GNN/ 191 | Assigning node and edge feats:: 192 | we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features R^{d_e}. 193 | Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the adjacency matrix. 194 | The diagonal T[1:1+d_n, i, i], i = 0 to n-1, store the node feature of node i. 195 | The off diagonal T[1+d_n:, i, j] store edge features of edge(i, j). 196 | """ 197 | 198 | zero_adj = torch.zeros_like(adj) 199 | 200 | in_node_dim = g.ndata['feat'].shape[1] 201 | in_edge_dim = g.edata['feat'].shape[1] 202 | 203 | if edge_feat: 204 | # use edge feats also to prepare adj 205 | adj_with_edge_feat = torch.stack([zero_adj for j in range(in_node_dim + in_edge_dim)]) 206 | adj_with_edge_feat = torch.cat([adj.unsqueeze(0), adj_with_edge_feat], dim=0) 207 | 208 | us, vs = g.edges() 209 | for idx, edge_feat in enumerate(g.edata['feat']): 210 | adj_with_edge_feat[1+in_node_dim:, us[idx], vs[idx]] = edge_feat 211 | 212 | for node, node_feat in enumerate(g.ndata['feat']): 213 | adj_with_edge_feat[1:1+in_node_dim, node, node] = node_feat 214 | 215 | x_with_edge_feat = adj_with_edge_feat.unsqueeze(0) 216 | 217 | return None, x_with_edge_feat, labels, g.edges() 218 | else: 219 | # use only node feats to prepare adj 220 | adj_no_edge_feat = torch.stack([zero_adj for j in range(in_node_dim)]) 221 | adj_no_edge_feat = torch.cat([adj.unsqueeze(0), adj_no_edge_feat], dim=0) 222 | 223 | for node, node_feat in enumerate(g.ndata['feat']): 224 | adj_no_edge_feat[1:1+in_node_dim, node, node] = node_feat 225 | 226 | x_no_edge_feat = adj_no_edge_feat.unsqueeze(0) 227 | 228 | return x_no_edge_feat, None, labels, g.edges() 229 | 230 | def _sym_normalize_adj(self, adj): 231 | deg = torch.sum(adj, dim = 0)#.squeeze() 232 | deg_inv = torch.where(deg>0, 1./torch.sqrt(deg), torch.zeros(deg.size())) 233 | deg_inv = torch.diag(deg_inv) 234 | return torch.mm(deg_inv, torch.mm(adj, deg_inv)) 235 | 236 | 237 | def _add_self_loops(self): 238 | """ 239 | No self-loop support since TSP edge classification dataset. 
240 | """ 241 | raise NotImplementedError 242 | 243 | 244 | def _add_positional_encodings(self, pos_enc_dim): 245 | 246 | # Graph positional encoding v/ Laplacian eigenvectors 247 | self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists] 248 | self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists] 249 | self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists] 250 | 251 | if __name__ == '__main__': 252 | tsp = TSPDataset('TSP') 253 | import ipdb; ipdb.set_trace() -------------------------------------------------------------------------------- /data/TSP/generate_TSP.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import pprint as pp 4 | import os 5 | 6 | import numpy as np 7 | from concorde.tsp import TSPSolver # Install from https://github.com/jvkersch/pyconcorde 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--min_nodes", type=int, default=50) 13 | parser.add_argument("--max_nodes", type=int, default=500) 14 | parser.add_argument("--num_samples", type=int, default=10000) 15 | parser.add_argument("--filename", type=str, default=None) 16 | parser.add_argument("--node_dim", type=int, default=2) 17 | parser.add_argument("--seed", type=int, default=1234) 18 | opts = parser.parse_args() 19 | 20 | if opts.filename is None: 21 | opts.filename = f"tsp{opts.min_nodes}-{opts.max_nodes}.txt" 22 | 23 | # Pretty print the run args 24 | pp.pprint(vars(opts)) 25 | 26 | np.random.seed(opts.seed) 27 | 28 | with open(opts.filename, "w") as f: 29 | start_time = time.time() 30 | idx = 0 31 | while idx < opts.num_samples: 32 | num_nodes = np.random.randint(low=opts.min_nodes, high=opts.max_nodes+1) 33 | 34 | nodes_coord = np.random.random([num_nodes, opts.node_dim]) 35 | solver = TSPSolver.from_data(nodes_coord[:, 0], nodes_coord[:, 1], norm="GEO") 36 | solution = solver.solve() 37 | 38 | # Only write instances with valid solutions 39 | if (np.sort(solution.tour) == np.arange(num_nodes)).all(): 40 | f.write( " ".join( str(x)+str(" ")+str(y) for x,y in nodes_coord) ) 41 | f.write( str(" ") + str('output') + str(" ") ) 42 | f.write( str(" ").join( str(node_idx+1) for node_idx in solution.tour) ) 43 | f.write( str(" ") + str(solution.tour[0]+1) + str(" ") ) 44 | f.write( "\n" ) 45 | idx += 1 46 | 47 | end_time = time.time() - start_time 48 | 49 | print(f"Completed generation of {opts.num_samples} samples of TSP{opts.min_nodes}-{opts.max_nodes}.") 50 | print(f"Total time: {end_time/60:.1f}m") 51 | print(f"Average time: {end_time/opts.num_samples:.1f}s") 52 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | from data.molecules import MoleculeDataset 5 | from data.QM9 import QM9Dataset 6 | from data.SBMs import SBMsDataset 7 | from data.TSP import TSPDataset 8 | from data.superpixels import SuperPixDataset 9 | from data.cora import CoraDataset 10 | from models.networks import * 11 | from utils.utils import * 12 | 13 | 14 | class TransInput(nn.Module): 15 | 16 | def __init__(self, trans_fn): 17 | super().__init__() 18 | self.trans = trans_fn 19 | 20 | def forward(self, input): 21 | if self.trans: 22 | input['V'] = self.trans(input['V']) 23 | return input 24 | 25 | 26 | class 
TransOutput(nn.Module): 27 | 28 | def __init__(self, args): 29 | super().__init__() 30 | self.args = args 31 | if args.task == 'node_level': 32 | channel_sequence = (args.node_dim, ) * args.nb_mlp_layer + (args.nb_classes, ) 33 | self.trans = MLP(channel_sequence) 34 | elif args.task == 'link_level': 35 | channel_sequence = (args.node_dim * 2, ) * args.nb_mlp_layer + (args.nb_classes, ) 36 | self.trans = MLP(channel_sequence) 37 | elif args.task == 'graph_level': 38 | channel_sequence = (args.node_dim, ) * args.nb_mlp_layer + (args.nb_classes, ) 39 | self.trans = MLP(channel_sequence) 40 | else: 41 | raise Exception('Unknown task!') 42 | 43 | 44 | def forward(self, input): 45 | G, V = input['G'], input['V'] 46 | if self.args.task == 'node_level': 47 | output = self.trans(V) 48 | elif self.args.task == 'link_level': 49 | def _edge_feat(edges): 50 | e = torch.cat([edges.src['V'], edges.dst['V']], dim=1) 51 | return {'e': e} 52 | G.ndata['V'] = V 53 | G.apply_edges(_edge_feat) 54 | output = self.trans(G.edata['e']) 55 | elif self.args.task == 'graph_level': 56 | G.ndata['V'] = V 57 | readout = dgl.mean_nodes(G, 'V') 58 | output = self.trans(readout) 59 | else: 60 | raise Exception('Unknown task!') 61 | return output 62 | 63 | 64 | def get_trans_input(args): 65 | if args.data in ['ZINC']: 66 | trans_input = nn.Embedding(args.in_dim_V, args.node_dim) 67 | elif args.data in ['TSP']: 68 | trans_input = nn.Linear(args.in_dim_V, args.node_dim) 69 | elif args.data in ['SBM_CLUSTER', 'SBM_PATTERN']: 70 | trans_input = nn.Embedding(args.in_dim_V, args.node_dim) 71 | elif args.data in ['CIFAR10', 'MNIST', 'Cora']: 72 | trans_input = nn.Linear(args.in_dim_V, args.node_dim) 73 | elif args.data in ['QM9']: 74 | trans_input = nn.Linear(args.in_dim_V, args.node_dim) 75 | else: 76 | raise Exception('Unknown dataset!') 77 | return trans_input 78 | 79 | 80 | def get_loss_fn(args): 81 | if args.data in ['ZINC', 'QM9']: 82 | loss_fn = MoleculesCriterion() 83 | elif args.data in ['TSP']: 84 | loss_fn = TSPCriterion() 85 | elif args.data in ['SBM_CLUSTER', 'SBM_PATTERN']: 86 | loss_fn = SBMsCriterion(args.nb_classes) 87 | elif args.data in ['CIFAR10', 'MNIST']: 88 | loss_fn = SuperPixCriterion() 89 | elif args.data in ['Cora']: 90 | loss_fn = CiteCriterion() 91 | else: 92 | raise Exception('Unknown dataset!') 93 | return loss_fn 94 | 95 | 96 | def load_data(args): 97 | if args.data in ['ZINC']: 98 | return MoleculeDataset(args.data) 99 | elif args.data in ['QM9']: 100 | return QM9Dataset(args.data, args.extra) 101 | elif args.data in ['TSP']: 102 | return TSPDataset(args.data) 103 | elif args.data in ['MNIST', 'CIFAR10']: 104 | return SuperPixDataset(args.data) 105 | elif args.data in ['SBM_CLUSTER', 'SBM_PATTERN']: 106 | return SBMsDataset(args.data) 107 | elif args.data in ['Cora']: 108 | return CoraDataset(args.data) 109 | else: 110 | raise Exception('Unknown dataset!') 111 | 112 | 113 | def load_metric(args): 114 | if args.data in ['ZINC', 'QM9']: 115 | return MAE 116 | elif args.data in ['TSP']: 117 | return binary_f1_score 118 | elif args.data in ['MNIST', 'CIFAR10']: 119 | return accuracy_MNIST_CIFAR 120 | elif args.data in ['SBM_CLUSTER', 'SBM_PATTERN']: 121 | return accuracy_SBM 122 | elif args.data in ['Cora']: 123 | return CoraAccuracy 124 | else: 125 | raise Exception('Unknown dataset!') 126 | -------------------------------------------------------------------------------- /data/cora.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 
| import dgl.data 4 | 5 | class CoraDGL(torch.utils.data.Dataset): 6 | 7 | def __init__(self): 8 | self.graph = dgl.data.CoraGraphDataset()[0] 9 | self.graph.edata['feat'] = torch.ones([self.graph.num_edges(), 1]).float() 10 | 11 | def __getitem__(self, i): 12 | assert i == 0 13 | return self.graph, self.graph.ndata['label'] 14 | 15 | def __len__(self): 16 | return 1 17 | 18 | class CoraDataset(torch.utils.data.Dataset): 19 | 20 | def __init__(self, name): 21 | """ 22 | Loading Cora Dataset 23 | """ 24 | start = time.time() 25 | print("[I] Loading dataset %s..." % (name)) 26 | self.name = name 27 | base_graph = CoraDGL() 28 | self.train = base_graph 29 | self.val = base_graph 30 | self.test = base_graph 31 | 32 | print('train, test, val sizes :',len(self.train),len(self.test),len(self.val)) 33 | print("[I] Finished loading.") 34 | print("[I] Data load time: {:.4f}s".format(time.time()-start)) 35 | 36 | 37 | def collate(self, samples): 38 | return samples[0] 39 | 40 | if __name__ == '__main__': 41 | dataset = CoraDataset('Cora') 42 | print(dataset.train.__getitem__(0)) -------------------------------------------------------------------------------- /data/molecules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import torch.utils.data 4 | import time 5 | import os 6 | import numpy as np 7 | 8 | import csv 9 | 10 | import dgl 11 | 12 | from scipy import sparse as sp 13 | import numpy as np 14 | 15 | # *NOTE 16 | # The dataset pickle and index files are in ./zinc_molecules/ dir 17 | # [.pickle and .index; for split 'train', 'val' and 'test'] 18 | 19 | 20 | class MoleculeDGL(torch.utils.data.Dataset): 21 | def __init__(self, data_dir, split, num_graphs=None): 22 | self.data_dir = data_dir 23 | self.split = split 24 | self.num_graphs = num_graphs 25 | 26 | with open(data_dir + "/%s.pickle" % self.split,"rb") as f: 27 | self.data = pickle.load(f) 28 | 29 | if self.num_graphs in [10000, 1000]: 30 | # loading the sampled indices from file ./zinc_molecules/.index 31 | with open(data_dir + "/%s.index" % self.split,"r") as f: 32 | data_idx = [list(map(int, idx)) for idx in csv.reader(f)] 33 | self.data = [ self.data[i] for i in data_idx[0] ] 34 | 35 | assert len(self.data)==num_graphs, "Sample num_graphs again; available idx: train/val/test => 10k/1k/1k" 36 | 37 | """ 38 | data is a list of Molecule dict objects with following attributes 39 | 40 | molecule = data[idx] 41 | ; molecule['num_atom'] : nb of atoms, an integer (N) 42 | ; molecule['atom_type'] : tensor of size N, each element is an atom type, an integer between 0 and num_atom_type 43 | ; molecule['bond_type'] : tensor of size N x N, each element is a bond type, an integer between 0 and num_bond_type 44 | ; molecule['logP_SA_cycle_normalized'] : the chemical property to regress, a float variable 45 | """ 46 | 47 | self.graph_lists = [] 48 | self.graph_labels = [] 49 | self.n_samples = len(self.data) 50 | self._prepare() 51 | 52 | def _prepare(self): 53 | print("preparing %d graphs for the %s set..." 
% (self.num_graphs, self.split.upper())) 54 | 55 | for molecule in self.data: 56 | node_features = molecule['atom_type'].long() 57 | 58 | adj = molecule['bond_type'] 59 | edge_list = (adj != 0).nonzero() # converting adj matrix to edge_list 60 | 61 | edge_idxs_in_adj = edge_list.split(1, dim=1) 62 | edge_features = adj[edge_idxs_in_adj].reshape(-1).long() 63 | 64 | # Create the DGL Graph 65 | g = dgl.DGLGraph() 66 | g.add_nodes(molecule['num_atom']) 67 | g.ndata['feat'] = node_features 68 | 69 | for src, dst in edge_list: 70 | g.add_edges(src.item(), dst.item()) 71 | g.edata['feat'] = edge_features 72 | 73 | self.graph_lists.append(g) 74 | self.graph_labels.append(molecule['logP_SA_cycle_normalized']) 75 | 76 | def __len__(self): 77 | """Return the number of graphs in the dataset.""" 78 | return self.n_samples 79 | 80 | def __getitem__(self, idx): 81 | """ 82 | Get the idx^th sample. 83 | Parameters 84 | --------- 85 | idx : int 86 | The sample index. 87 | Returns 88 | ------- 89 | (dgl.DGLGraph, int) 90 | DGLGraph with node feature stored in `feat` field 91 | And its label. 92 | """ 93 | return self.graph_lists[idx], self.graph_labels[idx] 94 | 95 | 96 | class MoleculeDatasetDGL(torch.utils.data.Dataset): 97 | def __init__(self, name='Zinc'): 98 | t0 = time.time() 99 | self.name = name 100 | 101 | self.num_atom_type = 28 # known meta-info about the zinc dataset; can be calculated as well 102 | self.num_bond_type = 4 # known meta-info about the zinc dataset; can be calculated as well 103 | 104 | data_dir='./data/molecules' 105 | if self.name == 'ZINC-full': 106 | data_dir='./data/molecules/zinc_full' 107 | self.train = MoleculeDGL(data_dir, 'train', num_graphs=220011) 108 | self.val = MoleculeDGL(data_dir, 'val', num_graphs=24445) 109 | self.test = MoleculeDGL(data_dir, 'test', num_graphs=5000) 110 | else: 111 | self.train = MoleculeDGL(data_dir, 'train', num_graphs=10000) 112 | self.val = MoleculeDGL(data_dir, 'val', num_graphs=1000) 113 | self.test = MoleculeDGL(data_dir, 'test', num_graphs=1000) 114 | print("Time taken: {:.4f}s".format(time.time()-t0)) 115 | 116 | 117 | 118 | def self_loop(g): 119 | """ 120 | Utility function only, to be used only when necessary as per user self_loop flag 121 | : Overwriting the function dgl.transform.add_self_loop() to not miss ndata['feat'] and edata['feat'] 122 | 123 | 124 | This function is called inside a function in MoleculeDataset class. 
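    Returns a new DGLGraph with all original non-self edges preserved and exactly
    one self-loop added per node; edata['feat'] is re-initialized to zeros.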
125 | """ 126 | new_g = dgl.DGLGraph() 127 | new_g.add_nodes(g.number_of_nodes()) 128 | new_g.ndata['feat'] = g.ndata['feat'] 129 | 130 | src, dst = g.all_edges(order="eid") 131 | src = dgl.backend.zerocopy_to_numpy(src) 132 | dst = dgl.backend.zerocopy_to_numpy(dst) 133 | non_self_edges_idx = src != dst 134 | nodes = np.arange(g.number_of_nodes()) 135 | new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx]) 136 | new_g.add_edges(nodes, nodes) 137 | 138 | # This new edata is not used since this function gets called only for GCN, GAT 139 | # However, we need this for the generic requirement of ndata and edata 140 | new_g.edata['feat'] = torch.zeros(new_g.number_of_edges()) 141 | return new_g 142 | 143 | 144 | 145 | def positional_encoding(g, pos_enc_dim): 146 | """ 147 | Graph positional encoding v/ Laplacian eigenvectors 148 | """ 149 | 150 | # Laplacian 151 | A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float) 152 | N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float) 153 | L = sp.eye(g.number_of_nodes()) - N * A * N 154 | 155 | # Eigenvectors with numpy 156 | EigVal, EigVec = np.linalg.eig(L.toarray()) 157 | idx = EigVal.argsort() # increasing order 158 | EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx]) 159 | g.ndata['pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float() 160 | 161 | # # Eigenvectors with scipy 162 | # EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR') 163 | # EigVec = EigVec[:, EigVal.argsort()] # increasing order 164 | # g.ndata['pos_enc'] = torch.from_numpy(np.abs(EigVec[:,1:pos_enc_dim+1])).float() 165 | 166 | return g 167 | 168 | 169 | 170 | class MoleculeDataset(torch.utils.data.Dataset): 171 | 172 | def __init__(self, name): 173 | """ 174 | Loading Molecules datasets 175 | """ 176 | start = time.time() 177 | print("[I] Loading dataset %s..." % (name)) 178 | self.name = name 179 | data_dir = 'data/molecules/' 180 | with open(data_dir+name+'.pkl',"rb") as f: 181 | f = pickle.load(f) 182 | self.train = f[0] 183 | self.val = f[1] 184 | self.test = f[2] 185 | self.num_atom_type = f[3] 186 | self.num_bond_type = f[4] 187 | print('train, test, val sizes :',len(self.train),len(self.test),len(self.val)) 188 | print("[I] Finished loading.") 189 | print("[I] Data load time: {:.4f}s".format(time.time()-start)) 190 | 191 | 192 | # form a mini batch from a given list of samples = [(graph, label) pairs] 193 | def collate(self, samples): 194 | # The input samples is a list of pairs (graph, label). 195 | graphs, labels = map(list, zip(*samples)) 196 | labels = torch.tensor(np.array(labels)).unsqueeze(1) 197 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 198 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 199 | #snorm_n = torch.cat(tab_snorm_n).sqrt() 200 | #tab_sizes_e = [ graphs[i].number_of_edges() for i in range(len(graphs))] 201 | #tab_snorm_e = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_e ] 202 | #snorm_e = torch.cat(tab_snorm_e).sqrt() 203 | batched_graph = dgl.batch(graphs) 204 | 205 | return batched_graph, labels 206 | 207 | # prepare dense tensors for GNNs using them; such as RingGNN, 3WLGNN 208 | def collate_dense_gnn(self, samples, edge_feat): 209 | # The input samples is a list of pairs (graph, label). 
210 | graphs, labels = map(list, zip(*samples)) 211 | labels = torch.tensor(np.array(labels)).unsqueeze(1) 212 | #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))] 213 | #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ] 214 | #snorm_n = tab_snorm_n[0][0].sqrt() 215 | 216 | #batched_graph = dgl.batch(graphs) 217 | 218 | g = graphs[0] 219 | adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense()) 220 | """ 221 | Adapted from https://github.com/leichen2018/Ring-GNN/ 222 | Assigning node and edge feats:: 223 | we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features R^{d_e}. 224 | Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the adjacency matrix. 225 | The diagonal T[1:1+d_n, i, i], i = 0 to n-1, store the node feature of node i. 226 | The off diagonal T[1+d_n:, i, j] store edge features of edge(i, j). 227 | """ 228 | 229 | zero_adj = torch.zeros_like(adj) 230 | 231 | if edge_feat: 232 | # use edge feats also to prepare adj 233 | adj_with_edge_feat = torch.stack([zero_adj for j in range(self.num_atom_type + self.num_bond_type)]) 234 | adj_with_edge_feat = torch.cat([adj.unsqueeze(0), adj_with_edge_feat], dim=0) 235 | 236 | us, vs = g.edges() 237 | for idx, edge_label in enumerate(g.edata['feat']): 238 | adj_with_edge_feat[edge_label.item()+1+self.num_atom_type][us[idx]][vs[idx]] = 1 239 | 240 | for node, node_label in enumerate(g.ndata['feat']): 241 | adj_with_edge_feat[node_label.item()+1][node][node] = 1 242 | 243 | x_with_edge_feat = adj_with_edge_feat.unsqueeze(0) 244 | 245 | return None, x_with_edge_feat, labels 246 | 247 | else: 248 | # use only node feats to prepare adj 249 | adj_no_edge_feat = torch.stack([zero_adj for j in range(self.num_atom_type)]) 250 | adj_no_edge_feat = torch.cat([adj.unsqueeze(0), adj_no_edge_feat], dim=0) 251 | 252 | for node, node_label in enumerate(g.ndata['feat']): 253 | adj_no_edge_feat[node_label.item()+1][node][node] = 1 254 | 255 | x_no_edge_feat = adj_no_edge_feat.unsqueeze(0) 256 | 257 | return x_no_edge_feat, None, labels 258 | 259 | def _sym_normalize_adj(self, adj): 260 | deg = torch.sum(adj, dim = 0)#.squeeze() 261 | deg_inv = torch.where(deg>0, 1./torch.sqrt(deg), torch.zeros(deg.size())) 262 | deg_inv = torch.diag(deg_inv) 263 | return torch.mm(deg_inv, torch.mm(adj, deg_inv)) 264 | 265 | def _add_self_loops(self): 266 | 267 | # function for adding self loops 268 | # this function will be called only if self_loop flag is True 269 | 270 | self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists] 271 | self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists] 272 | self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists] 273 | 274 | def _add_positional_encodings(self, pos_enc_dim): 275 | 276 | # Graph positional encoding v/ Laplacian eigenvectors 277 | self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists] 278 | self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists] 279 | self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists] 280 | 281 | -------------------------------------------------------------------------------- /data/molecules/prepare_molecules.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Notebook for 
preparing and saving MOLECULAR graphs" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import torch\n", 18 | "import pickle\n", 19 | "import time\n", 20 | "import os\n", 21 | "%matplotlib inline\n", 22 | "import matplotlib.pyplot as plt\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "# Download ZINC dataset" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "output_type": "stream", 39 | "name": "stdout", 40 | "text": [ 41 | "File already downloaded\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "if not os.path.isfile('molecules.zip'):\n", 47 | " print('downloading..')\n", 48 | " # !curl https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1 -o molecules.zip -J -L -k\n", 49 | " !unzip molecules.zip -d ../\n", 50 | " # !tar -xvf molecules.zip -C ../\n", 51 | "else:\n", 52 | " print('File already downloaded')\n", 53 | " " 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# Convert to DGL format and save with pickle" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "output_type": "stream", 70 | "name": "stdout", 71 | "text": [ 72 | "/data00/caishaofei/workspace/GNAS2\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "import os\n", 78 | "os.chdir('../../') # go to root folder of the project\n", 79 | "print(os.getcwd())\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "import pickle\n", 89 | "\n", 90 | "%load_ext autoreload\n", 91 | "%autoreload 2\n", 92 | "from data.molecules import MoleculeDatasetDGL \n", 93 | "\n", 94 | "#from data import LoadData\n", 95 | "from torch.utils.data import DataLoader\n", 96 | "\n", 97 | "from data.molecules import MoleculeDataset\n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 7, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "output_type": "stream", 107 | "name": "stdout", 108 | "text": [ 109 | "preparing 10000 graphs for the TRAIN set...\n", 110 | "/data00/caishaofei/miniconda3/envs/gnas2/lib/python3.6/site-packages/dgl/base.py:45: DGLWarning: Recommend creating graphs by `dgl.graph(data)` instead of `dgl.DGLGraph(data)`.\n", 111 | " return warnings.warn(message, category=category, stacklevel=1)\n", 112 | "preparing 1000 graphs for the VAL set...\n", 113 | "preparing 1000 graphs for the TEST set...\n", 114 | "Time taken: 378.4136s\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "DATASET_NAME = 'ZINC'\n", 120 | "dataset = MoleculeDatasetDGL(DATASET_NAME) \n" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 8, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "output_type": "display_data", 130 | "data": { 131 | "text/plain": "
", 132 | "image/svg+xml": "\n\n\n\n \n \n \n \n 2021-04-26T12:51:58.618541\n image/svg+xml\n \n \n Matplotlib v3.3.4, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", 133 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEICAYAAACzliQjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUDklEQVR4nO3df7DldX3f8eerq5JEQZZwQ9fdNYvOaouMWeIWyTRaWhNBSLKYZuhuU0VLXZjAjI52UkjbgdrSUiNx4sSuWcIWyCiEBA3bijWrcaSZFuSCG35KWXAJu7Pu3oiIRGcr8O4f53vLcb337r33nD333v08HzNn7ve8v78+H77wul8+3x83VYUkqQ1/a6EbIEkaHUNfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr50GEk+meTfLnQ7pGGI9+nraJdkN/AvquqLC92W6SyFNuro4Jm+mpbkJQvdBmmUDH0d1ZL8IfBq4L8leTbJbyapJBcm+Svgz7vl/jjJN5N8J8kdSd7Qt43rk/yHbvrMJHuSfCjJgST7kry3b9lzkjyU5LtJ9ib5l33zfinJziRPJ/lfSd44XRtH8g9HTTL0dVSrqncBfwX8clW9Arilm/UPgL8LnNV9/zywFvgp4F7gUzNs9m8DrwRWAhcCn0iyvJt3HXBRVR0LnMqLv1ROA7YBFwE/Cfw+sD3JMYe2sao+MnDHpWkY+mrVlVX1N1X1fYCq2lZV362qg8CVwM8keeU06/4A+HBV/aCqbgeeBV7fN++UJMdV1ber6t6uvhn4/aq6q6qer6obgIPAGUeof9KUDH216snJiSTLklyd5LEkzwC7u1knTrPut6rqub7v3wNe0U3/Y+Ac4IkkX0nyc139p4EPdUM7Tyd5GlgNvGo43ZFmx9BXC6a6Ra2/9k+BDcAv0Bu2WdPVM+cdVd1dVRvoDRP9KS8OJz0JXFVVx/d9fqKqbpqhjdLQGfpqwX7gNTPMP5beUMu3gJ8A/uN8dpLkZUl+Pckrq+oHwDPAC93sa4GLk7w5PS9Pcm6SY2fZRmkoDH214D8B/6YbUvm1KebfCDwB7AUeAu4cYF/vAnZ3w0QXA78OUFXjwPuA3wO+DewC3jNVG/vv+JGGzYezJKkhnulLUkMMfUlqiKEvSQ0x9CWpIYv+ZVMnnnhirVmzZqGbIUlLxj333PPXVTU21bxFH/pr1qxhfHx8oZshSUtGkiemm+fwjiQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNWTRP5ErLVZrLvvcvNfdffW5Q2yJNHue6UtSQwx9SWqIoS9JDTH0Jakhhw39JNuSHEjyQF/tj5Ls7D67k+zs6muSfL9v3if71nlTkvuT7Ery8SQ5Ij2SJE1rNnfvXA/8HnDjZKGq/snkdJJrgO/0Lf9YVa2bYjtbgPcBdwG3A2cDn59ziyVJ83bYM/2qugN4aqp53dn6+cBNM20jyQrguKq6s6qK3i+Q8+bcWknSQAYd038LsL+qHu2rnZzka0m+kuQtXW0lsKdvmT1dbUpJNicZTzI+MTExYBMlSZMGDf1N/PBZ/j7g1VV1GvBB4NNJjpvrRqtqa1Wtr6r1Y2NT/plHSdI8zPuJ3CQvAX4VeNNkraoOAge76XuSPAa8DtgLrOpbfVVXkySN0CBn+r8AfL2q/v+wTZKxJMu66dcAa4HHq2of8EySM7rrAO8Gbhtg35KkeZjNLZs3Af8beH2SPUku7GZt5Ecv4L4VuK+7hfNPgIuravIi8G8AfwDsAh7DO3ckaeQOO7xTVZumqb9nitqtwK3TLD8OnDrH9kmShsgnciWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGzOYPo29LciDJA321K5PsTbKz+5zTN+/yJLuSPJLkrL762V1tV5LLht8VSdLhzOZM/3rg7CnqH6uqdd3ndoAkpwAbgTd06/yXJMuSLAM+AbwDOAXY1C0rSRqhlxxugaq6I8maWW5vA3BzVR0EvpFkF3B6N29XVT0OkOTmbtmH5t5kSdJ8DTKmf2mS+7rhn+VdbSXwZN8ye7radHVJ0gjNN/S3AK8F1gH7gGuG1SCAJJuTjCcZn5iYGOamJalp8wr9qtpfVc9X1QvAtbw4hLMXWN236KquNl19uu1vrar1VbV+bGxsPk2UJE1hXqGfZEXf13cCk3f2bAc2JjkmycnAWuCrwN3A2iQnJ3kZvYu92+ffbEnSfBz2Qm6Sm4AzgROT7AGuAM5Msg4o
YDdwEUBVPZjkFnoXaJ8DLqmq57vtXAp8AVgGbKuqB4fdGUnSzGZz986mKcrXzbD8VcBVU9RvB26fU+skSUPlE7mS1BBDX5IaYuhLUkMMfUlqiKEvSQ057N07khaXNZd9bqD1d1997pBaoqXIM31JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDfE1DFpwvlZAGh3P9CWpIYa+JDXE0Jekhhw29JNsS3IgyQN9td9O8vUk9yX5bJLju/qaJN9PsrP7fLJvnTcluT/JriQfT5Ij0iNJ0rRmc6Z/PXD2IbUdwKlV9Ubg/wCX9817rKrWdZ+L++pbgPcBa7vPoduUJB1hhw39qroDeOqQ2p9V1XPd1zuBVTNtI8kK4LiqurOqCrgROG9eLZYkzdswxvT/OfD5vu8nJ/lakq8keUtXWwns6VtmT1ebUpLNScaTjE9MTAyhiZIkGDD0k/xr4DngU11pH/DqqjoN+CDw6STHzXW7VbW1qtZX1fqxsbFBmihJ6jPvh7OSvAf4JeBt3ZANVXUQONhN35PkMeB1wF5+eAhoVVeTJI3QvM70k5wN/CbwK1X1vb76WJJl3fRr6F2wfbyq9gHPJDmju2vn3cBtA7dekjQnhz3TT3ITcCZwYpI9wBX07tY5BtjR3Xl5Z3enzluBDyf5AfACcHFVTV4E/g16dwL9OL1rAP3XASRJI3DY0K+qTVOUr5tm2VuBW6eZNw6cOqfWSZKGyidyJakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIbMKvSTbEtyIMkDfbUTkuxI8mj3c3lXT5KPJ9mV5L4kP9u3zgXd8o8muWD43ZEkzWS2Z/rXA2cfUrsM+FJVrQW+1H0HeAewtvtsBrZA75cEcAXwZuB04IrJXxSSpNGYVehX1R3AU4eUNwA3dNM3AOf11W+snjuB45OsAM4CdlTVU1X1bWAHP/qLRJJ0BA0ypn9SVe3rpr8JnNRNrwSe7FtuT1ebrv4jkmxOMp5kfGJiYoAmSpL6DeVCblUVUMPYVre9rVW1vqrWj42NDWuzktS8QUJ/fzdsQ/fzQFffC6zuW25VV5uuLkkakUFCfzsweQfOBcBtffV3d3fxnAF8pxsG+gLw9iTLuwu4b+9qkqQReclsFkpyE3AmcGKSPfTuwrkauCXJhcATwPnd4rcD5wC7gO8B7wWoqqeS/Hvg7m65D1fVoReHJUlH0KxCv6o2TTPrbVMsW8Al02xnG7Bt1q2TJA2VT+RKUkMMfUlqiKEvSQ0x9CWpIbO6kCtpuNZc9rmFboIa5Zm+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEN+9oyVvkPfY7L763CG2RFr8PNOXpIYY+pLUEENfkhoy79BP8vokO/s+zyT5QJIrk+ztq5/Tt87lSXYleSTJWcPpgiRptuZ9IbeqHgHWASRZBuwFPgu8F/hYVX20f/kkpwAbgTcArwK+mOR1VfX8fNsgSZqbYQ3vvA14rKqemGGZDcDNVXWwqr4B7AJOH9L+JUmzMKzQ3wjc1Pf90iT3JdmWZHlXWwk82bfMnq72I5JsTjKeZHxiYmJITZQkDRz6SV4G/Arwx11pC/BaekM/+4Br5rrNqtpaVeurav3Y2NigTZQkdYZxpv8O4N6q2g9QVfur6vmqegG4lheHcPYCq/vWW9XVJEkjMozQ30Tf0E6SFX3z3gk80E1vBzYmOSbJycBa4KtD2L8kaZYGeg1DkpcDvwhc1Ff+SJJ1QAG7J+dV1YNJbgEeAp4DLvHOHS20QV7hIC1FA4V+Vf0N8JOH1N41w/JXAVcNsk9J0vz5RK4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JashAr1bW0WWQd8vvvvrcIbZE0pHimb4kNcTQl6SGGPqS1BBDX5IaMnDoJ9md5P4kO5OMd7UTkuxI8mj3c3lXT5KPJ9mV5L4kPzvo/iVJszesM/1/WFXrqmp99/0y4EtVtRb4Uvcd4B3A2u6zGdgypP1LkmbhSA3vbABu6KZvAM7rq99YPXcCxydZcYTaIEk6xDBCv4A/S3JPks1d7aSq2tdNfxM4qZteCTzZt+6ervZDkmxOMp5kfGJiYghNlCTBcB7O+vmq2pvkp4AdSb7eP7OqKknNZYNVtRXYCrB+/fo5rStJmt7AZ/pVtbf7eQD4LHA6sH9y2Kb7eaBbfC+wum/1VV1NkjQCA4V+kpcnOXZyGng78ACwHbigW+wC4LZuejvw7u4unjOA7/QNA0mSjrBBh3dOAj6bZHJbn66q/5HkbuCWJBcCTwDnd8vfDpwD7AK+B7x3wP1rkRjkvT2SRmeg0K+qx4GfmaL+LeBtU9QLuGSQfUqS5s8nciWpIYa+JDXE9+lLjVmov5vg32tYHDzTl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kN8dXKkmbNP4u59HmmL0kNmXfoJ1md5MtJHkryYJL3d/Urk+xNsrP7nNO3zuVJdiV5JMlZw+iAJGn2BhneeQ74UFXdm+RY4J4kO7p5H6uqj/YvnOQUYCPwBuBVwBeTvK6qnh+gDZKkOZj3mX5V7auqe7vp7wIPAytnWGUDcHNVHayqbwC7gNPnu39J0twNZUw/yRrgNOCurnRpkvuSbEuyvKutBJ7sW20P0/ySSLI5yXiS8YmJiWE0UZLEEEI/ySuAW4EPVNUzwBbgtcA6YB9wzVy3WVVbq2p9Va0fGxsbtImSpM5AoZ/kpfQC/1NV9RmAqtpfVc9X1QvAtbw4hLMXWN23+qquJkkakUHu3glwHfBwVf1OX31F32LvBB7oprcDG5Mck+RkYC3w1fnuX5I0d4PcvfP3gXcB9yfZ2dV+C9iUZB1QwG7gIoCqejDJLcBD9O78ucQ7dyRptOYd+lX1F0CmmHX7DOtcBVw1331KkgbjE7mS1BBDX5Ia4gvXjjK+EEvSTDzTl6SGeKYv6ag3yP8B77763CG2ZOEZ+pIWPYcth8fhHUlqiKEvSQ0x9CWpIYa+JDXEC7lHgBedJC1WnulLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQ79OXpBkcbW/o9Exfkhoy8jP9JGcDvwssA/6gqq4edRtmw6dqJR2NRhr6SZYBnwB+EdgD3J1ke1U9dCT2Z3BLWkiLcWho1MM7pwO7qurxqvq/wM3AhhG3QZKaNerhnZXAk33f9wBvPnShJJuBzd3XZ5M8MoK2nQj89Qj2M2r2a+k5Wvtmv+Yg/3mg1X96uhmL8u6dqtoKbB3lPpOMV9X6Ue5zFOzX0nO09s1+LQ6jHt7ZC6zu+76qq0mSRmDUoX83sDbJyUl
eBmwEto+4DZLUrJEO71TVc0kuBb5A75bNbVX14CjbMIORDieNkP1aeo7WvtmvRSBVtdBtkCSNiE/kSlJDDH1JakhzoZ9kW5IDSR7oq52QZEeSR7ufyxeyjfM1Td+uTLI3yc7uc85CtnE+kqxO8uUkDyV5MMn7u/qSPm4z9OtoOGY/luSrSf6y69u/6+onJ7krya4kf9Td0LFkzNCv65N8o++YrVvgpk6ruTH9JG8FngVurKpTu9pHgKeq6uoklwHLq+pfLWQ752Oavl0JPFtVH13Itg0iyQpgRVXdm+RY4B7gPOA9LOHjNkO/zmfpH7MAL6+qZ5O8FPgL4P3AB4HPVNXNST4J/GVVbVnIts7FDP26GPjvVfUnC9rAWWjuTL+q7gCeOqS8Abihm76B3n94S840fVvyqmpfVd3bTX8XeJje091L+rjN0K8lr3qe7b6+tPsU8I+AyWBcisdsun4tGc2F/jROqqp93fQ3gZMWsjFHwKVJ7uuGf5bUEMihkqwBTgPu4ig6bof0C46CY5ZkWZKdwAFgB/AY8HRVPdctsocl+Evu0H5V1eQxu6o7Zh9LcszCtXBmhv4hqjfetaR+cx/GFuC1wDpgH3DNgrZmAEleAdwKfKCqnumft5SP2xT9OiqOWVU9X1Xr6D15fzrwdxa2RcNxaL+SnApcTq9/fw84AVi0w4yGfs/+bnx1cpz1wAK3Z2iqan/3L+kLwLX0/uNbcrrx01uBT1XVZ7rykj9uU/XraDlmk6rqaeDLwM8BxyeZfCh0Sb+Gpa9fZ3dDdVVVB4H/yiI+ZoZ+z3bggm76AuC2BWzLUE2GYuedwAPTLbtYdRfPrgMerqrf6Zu1pI/bdP06So7ZWJLju+kfp/c3NB6mF5K/1i22FI/ZVP36et/JR+hdp1i0x6zFu3duAs6k9zrU/cAVwJ8CtwCvBp4Azq+qJXdBdJq+nUlvmKCA3cBFfePgS0KSnwf+J3A/8EJX/i16499L9rjN0K9NLP1j9kZ6F2qX0Tu5vKWqPpzkNfT+jsYJwNeAf9adHS8JM/Trz4ExIMBO4OK+C76LSnOhL0ktc3hHkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SG/D8Wxj4pbmUvowAAAABJRU5ErkJggg==\n" 134 | }, 135 | "metadata": { 136 | "needs_background": "light" 137 | } 138 | }, 139 | { 140 | "output_type": "stream", 141 | "name": "stdout", 142 | "text": [ 143 | "min/max : 9 37\n" 144 | ] 145 | }, 146 | { 147 | "output_type": "display_data", 148 | "data": { 149 | "text/plain": "
", 150 | "image/svg+xml": "\n\n\n\n \n \n \n \n 2021-04-26T12:51:58.887591\n image/svg+xml\n \n \n Matplotlib v3.3.4, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", 151 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAS2ElEQVR4nO3df6xfdX3H8edrFXFRTGHcNR1QCwZdmNnKcocu/lgn2wQ0QxbDrJuiMysksmlcMqtLJltCgk4kW3SYGhBI5NdEJplskzAmcxnOghWLldliiW1KexVRGI6t8N4f33O3L9d723vv+d7e3s99PpJvvuf7Oed8z/uT0/v6nn6+53xPqgpJUlt+YrELkCSNnuEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12aIsn6JLsXuw6pD8NdWgBJ1iapJM9Z7Fq0PBnuktQgw13NSvK+JJ+Z0vaXSf4qyTuSbE/yeJKHklx4iPfZ0y37YJIzu/afSLIpyc4k30tyc5LjutXu7p4fS/JEkl9emF5K0zPc1bIbgXOSHAOQZAVwPnA9sB94A/BC4B3AFUl+ceobJHkpcDHwS1V1DPA6YFc3+w+ANwK/AvwM8H3g492813TPK6vqBVX1b6PunHQwhruaVVUPA/cB53VNrwWerKp7qurzVbWzBr4IfAF49TRv8zRwNHBakqOqaldV7ezmXQT8SVXtrqqngEuANznOriOB4a7WXQ9s6Kbf0r0mydlJ7knyaJLHgHOA46euXFU7gPcwCO79SW5M8jPd7BcBtyZ5rHuP7Qw+DFYtWG+kWTLc1bq/AdYnOZHBEfz1SY4GbgE+AqyqqpXA7UCme4Oqur6qXsUgzAv4UDfrO8DZVbVy6PG8qtrTLSctGsNdTauqCeCfgU8B366q7cBzGQy1TAAHkpwN/MZ06yd5aZLXdh8I/wX8CHimm/0J4NIkL+qWHUtybjdvolvulAXpmHQIhruWg+uBX+ueqarHgT8EbmbwJehbgNtmWPdo4DLgu8AjwE8D7+/m/WW33heSPA7cA7y828aTwKXAv3bDNq8YfbekmcWbdUhSezxyl6QGGe6S1KBDhnuSk5LcleQbSR5I8u6u/bgkdyT5Vvd8bNee7grAHUnun+7CEEnSwprNkfsB4I+q6jTgFcC7kpwGbALurKpTgTu71wBnA6d2j43AlSOvWpJ0UIe8kq6q9gJ7u+nHk2wHTgDOBdZ3i13L4HSz93Xt19Xgm9p7kqxMsrp7n2kdf/zxtXbt2h7dkKTl59577/1uVY1NN29Ol0knWQucDnyZwcUfk4H9CP9/Vd4JDC7umLS7a3tWuCfZyODInjVr1rBly5a5lCJJy16Sh2eaN+svVJO8gMFVfe+pqh8Oz+uO0ud0TmVVba6q8aoaHxub9oNHkjRPswr3JEcxCPZPV9Vnu+Z9SVZ381cz+JU9gD3ASUOrn9i1SZIOk9mcLRPgKmB7VX10aNZtwAXd9AXA54ba39adNfMK4AcHG2+XJI3ebMbcXwm8Ffh6kq1d2wcYXJJ9c5J3Ag8z+J1sGPwA0znADuBJBr+VLUk6jGZztsyXmOHX8oAzp1m+gHf1rEuS1INXqEpSgwx3SWqQ4S5JDTLcJalB3shXOoS1mz4/73V3Xfb6EVYizZ5H7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ2azQ2yr06yP8m2obabkmztHrsm762aZG2SHw3N+8QC1i5JmsFsfvL3GuBjwHWTDVX125PTSS4HfjC0/M6qWjei+iRJ8zCbG2TfnWTtdPOSBDgfeO2I65Ik9dB3zP3VwL6q+tZQ28lJvprki0lePdOKSTYm2ZJky8TERM8yJEnD+ob7BuCGodd7gTVVdTrwXuD6JC+cbsWq2lxV41U1PjY21rMMSdKweYd7kucAvwXcNNlWVU9V1fe66XuBncBL+hYpSZqbPkfuvwZ8s6p2TzYkGUuyops+BTgVeKhfiZKkuZrNqZA3AP8GvDTJ7iTv7Ga9mWcPyQC8Bri/OzXyM8BFVfXoCOuVJM3CbM6W2TBD+9unabsFuKV/WZKkPrxCVZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSg2ZzD9Wrk+xPsm2o7ZIke5Js7R7nDM17f5IdSR5M8rqFKlySNLPZHLlfA5w1TfsVVbWue9wOkOQ0BjfO/rlunb9OsmJUxUqSZueQ4V5VdwOPz
vL9zgVurKqnqurbwA7gjB71SZLmoc+Y+8VJ7u+GbY7t2k4AvjO0zO6u7cck2ZhkS5ItExMTPcqQJE0133C/EngxsA7YC1w+1zeoqs1VNV5V42NjY/MsQ5I0nXmFe1Xtq6qnq+oZ4JP8/9DLHuCkoUVP7NokSYfRvMI9yeqhl+cBk2fS3Aa8OcnRSU4GTgX+vV+JkqS5es6hFkhyA7AeOD7JbuCDwPok64ACdgEXAlTVA0luBr4BHADeVVVPL0jlkqQZHTLcq2rDNM1XHWT5S4FL+xQlSerHK1QlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXokOGe5Ook+5NsG2r7iyTfTHJ/kluTrOza1yb5UZKt3eMTC1i7JGkGszlyvwY4a0rbHcDLqurngf8A3j80b2dVreseF42mTEnSXBwy3KvqbuDRKW1fqKoD3ct7gBMXoDZJ0jyNYsz994C/H3p9cpKvJvliklfPtFKSjUm2JNkyMTExgjIkSZN6hXuSPwEOAJ/umvYCa6rqdOC9wPVJXjjdulW1uarGq2p8bGysTxmSpCnmHe5J3g68AfidqiqAqnqqqr7XTd8L7AReMoI6JUlzMK9wT3IW8MfAb1bVk0PtY0lWdNOnAKcCD42iUEnS7D3nUAskuQFYDxyfZDfwQQZnxxwN3JEE4J7uzJjXAH+e5H+AZ4CLqurRad9YkrRgDhnuVbVhmuarZlj2FuCWvkVJkvrxClVJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ2aVbgnuTrJ/iTbhtqOS3JHkm91z8d27UnyV0l2JLk/yS8uVPGSpOkd8h6qnWuAjwHXDbVtAu6sqsuSbOpevw84Gzi1e7wcuLJ7VgPWbvr8vNfdddnrR1iJpIOZ1ZF7Vd0NPDql+Vzg2m76WuCNQ+3X1cA9wMokq0dQqyRplvqMua+qqr3d9CPAqm76BOA7Q8vt7tqeJcnGJFuSbJmYmOhRhiRpqtkOyxxUVVWSmuM6m4HNAOPj43NaV1oqHMbSYulz5L5vcrile97fte8BThpa7sSuTZJ0mPQJ99uAC7rpC4DPDbW/rTtr5hXAD4aGbyRJh8GshmWS3ACsB45Pshv4IHAZcHOSdwIPA+d3i98OnAPsAJ4E3jHimiVJhzCrcK+qDTPMOnOaZQt4V5+iJEn9eIWqJDXIcJekBhnuktQgw12SGmS4S1KDRnKFqqTR8+pW9WG4Sw3q88EAfji0wGEZSWqQ4S5JDTLcJalBjrlrSXAMWZobj9wlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWrQvM9zT/JS4KahplOAPwVWAr8PTHTtH6iq2+e7HUnS3M073KvqQWAdQJIVwB7gVgY3xL6iqj4yigIlSXM3qitUzwR2VtXDSUb0ltLo9L3CVVpqRjXm/mbghqHXFye5P8nVSY6dboUkG5NsSbJlYmJiukUkSfPUO9yTPBf4TeBvuqYrgRczGLLZC1w+3XpVtbmqxqtqfGxsrG8ZkqQhozhyPxu4r6r2AVTVvqp6uqqeAT4JnDGCbUiS5mAU4b6BoSGZJKuH5p0HbBvBNiRJc9DrC9Ukzwd+HbhwqPnDSdYBBeyaMk+SdBj0Cveq+k/gp6a0vbVXRZKk3rxCVZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNGtVt9qRD8lZ30uHjkbskNchwl6QGGe6S1CDDXZIaZLhLUoN6ny2TZBfwOPA0cKCqxpMcB9wErGVwH9Xzq+r7fbclSZqdUR25/2pVrauq8e71JuDOqjoVuLN7LUk6TBZqWOZc4Npu+lrgjQu0HUnSNEYR7gV8Icm9STZ2bauqam83/QiwagTbkSTN0iiuUH1VVe1J8tPAHUm+OTyzqipJTV2p+yDYCLBmzZoRlCFJmtT7yL2q9nTP+4FbgTOAfUlWA3TP+6dZb3NVjVfV+NjYWN8yJElDeoV7kucnOWZyGvgNYBtwG3BBt9gFwOf6bEeSNDd9h2VWAbcmmXyv66vqH5J8Bbg5yTuBh4Hze25HkjQHvcK9qh4CfmGa9u8BZ/Z5b0nS/HmFqiQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lq0Chus6clZu2mzy92CZIWmEfuktQgw12SGmS4S1KDDHdJatC8wz3JSUnuSvKNJA8keXfXfkmSPUm2do9zRleuJGk2+pwtcwD4o6q6L8kxwL1J7ujmXVFVH+lfXrs8Y0X6cX3+LnZd9voRVrL0zTvcq2ovsLebfjzJduCEURUmafEYskvfSM5zT7IWOB34MvBK4OIkbwO2MDi6//4062wENgKsWbNmFGVIOgL4v9IjQ+8vVJO8ALgFeE9V/RC4EngxsI7Bkf3l061XVZuraryqxsfGxvqWIUka0ivckxzFINg/XVWfBaiqfVX1dFU9A3wSOKN/mZKkuehztkyAq4DtVfXRofbVQ4udB2ybf3mSpPnoM+b+SuCtwNeTbO3aPgBsSLIOKGAXcGGPbUiS5qHP2TJfAjLNrNvnX44kaRS8QlWSGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkho0kp/8laTF5m/QP5vh3oO/Wy3pSOWwjCQ1yCN3SerpSBwS8shdkhrkkbukZa/F78+Wfbi3uFMlqYlwN6Al6dkcc5ekBi1YuCc5K8mDSXYk2bRQ25Ek/bgFCfckK4CPA2cDpzG4afZpC7EtSdKPW6gj9zOAHVX1UFX9N3AjcO4CbUuSNMVCfaF6AvCdode7gZcPL5BkI7Cxe/lEkgd7bO944Ls91l8qlks/Yfn0dbn0E5ZPX+fUz3yo17ZeNNOMRTtbpqo2A5tH8V5JtlTV+Cje60i2XPoJy6evy6WfsHz6eqT0c6GGZfYAJw29PrFrkyQdBgsV7l8BTk1ycpLnAm8GblugbUmSpliQYZmqOpDkYuAfgRXA1VX1wEJsqzOS4Z0lYLn0E5ZPX5dLP2H59PWI6GeqarFrkCSNmFeoSlKDDHdJatCSCvckVyfZn2TbUNtxSe5I8q3u+djFrHFUZujrJUn2JNnaPc5ZzBpHIclJSe5K8o0kDyR5d9fe3H49SF+b2q9Jnpfk35N8revnn3XtJyf5cveTJDd1J1ssaQfp6zVJvj20T9cd9tqW0ph7ktcATwDXVdXLurYPA49W1WXdb9gcW1XvW8w6R2GGvl4CPFFVH1nM2kYpyWpgdVXdl+QY4F7gjcDbaWy/HqSv59PQfk0S4PlV9USSo4AvAe8G3gt8tqpuTPIJ4GtVdeVi1trXQfp6EfB3VfWZxaptSR25V9XdwKNTms8F
ru2mr2Xwx7LkzdDX5lTV3qq6r5t+HNjO4Arn5vbrQfralBp4ont5VPco4LXAZNi1sk9n6uuiW1LhPoNVVbW3m34EWLWYxRwGFye5vxu2WfJDFcOSrAVOB75M4/t1Sl+hsf2aZEWSrcB+4A5gJ/BYVR3oFtlNIx9sU/taVZP79NJun16R5OjDXVcL4f5/ajDGdER8ai6QK4EXA+uAvcDli1rNCCV5AXAL8J6q+uHwvNb26zR9bW6/VtXTVbWOwdXpZwA/u7gVLZypfU3yMuD9DPr8S8BxwGEfUmwh3Pd1Y5mTY5r7F7meBVNV+7p/SM8An2TwR7PkdWOVtwCfrqrPds1N7tfp+trqfgWoqseAu4BfBlYmmbxwsrmfJBnq61ndEFxV1VPAp1iEfdpCuN8GXNBNXwB8bhFrWVCTYdc5D9g207JLRfeF1FXA9qr66NCs5vbrTH1tbb8mGUuyspv+SeDXGXy/cBfwpm6xVvbpdH395tCBSRh8t3DY9+lSO1vmBmA9g5/U3Ad8EPhb4GZgDfAwcH5VLfkvImfo63oG/3UvYBdw4dC49JKU5FXAvwBfB57pmj/AYCy6qf16kL5uoKH9muTnGXxhuoLBAeTNVfXnSU5hcG+H44CvAr/bHdkuWQfp6z8BY0CArcBFQ1+8Hp7allK4S5Jmp4VhGUnSFIa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJatD/Artwr29Huy65AAAAAElFTkSuQmCC\n" 152 | }, 153 | "metadata": { 154 | "needs_background": "light" 155 | } 156 | }, 157 | { 158 | "output_type": "stream", 159 | "name": "stdout", 160 | "text": [ 161 | "min/max : 10 36\n" 162 | ] 163 | }, 164 | { 165 | "output_type": "display_data", 166 | "data": { 167 | "text/plain": "
", 168 | "image/svg+xml": "\n\n\n\n \n \n \n \n 2021-04-26T12:51:59.147674\n image/svg+xml\n \n \n Matplotlib v3.3.4, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n", 169 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAR8klEQVR4nO3de5CddX3H8ffHKGoRC5SVSbk04KDWMjZ0VtQRLfVSEdRIL0jGC6g10pEZHW0t0E5BO0ypFZl22kJDoaBVLiNeaKEqRUd0rJcFIwaBChiGpCFZBQTUoQLf/rFP2uO6S3b3OZvd/e37NXNmn/N7bt9fnuSzv/zOc85JVSFJasvjFroASdLwGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7mpNkU5KX9TzGiUm+vFjqkWbLcJekBhnuakqSjwAHAv+a5MEk703y/CRfSXJfkm8lOXJg+xOT3JHkgSTfS/L6JL8KnAe8oDvGfd22Ryf5TrftliR/NHCcVyXZ0J3jK0meM109u+wPQ8ta/PgBtSbJJuAPquo/kuwH3Ai8EfgM8FLgUuBZwI+BrcBzq+rWJCuBvavqpiQndsc4YuC4W4HjqupLSfYCDqqqG5IcBnwWeDUwBrwBeB/wzKp6aLCeXdF/CRy5q31vAK6uqqur6tGquoaJAD66W/8ocGiSJ1fV1qq66TGO9VPg2UmeWlX3VtUNXfs64B+r6mtV9UhVXQw8BDx/nvok7ZThrtb9CvD73XTJfd0UyxHAyqr6EfA64CRga5KrkjzrMY71u0z8UrgzyReTvGDgHO+ZdI4DgF+epz5JO2W4q0WDc413AR+pqj0HHrtX1VkAVfXZqno5sBK4BTh/imPQbfuNqloDPA34FHD5wDnOnHSOX6iqS6Y7ljTfDHe1aBtwcLf8L8Crk7wiyYokT0pyZJL9k+ybZE2S3ZmYRnmQiWmaHcfYP8luAEl2615s/cWq+ilw/8C25wMnJXleJuye5Jgke0xRj7RLGO5q0V8Cf9ZNj7wOWAOcBowzMcr+Yyb+7j8OeDfw38A9wG8Cf9gd4/PATcDdSb7ftb0R2JTkfiamcl4PUFVjwNuAvwPuBW4DTpyqnsE7bKT55N0yktQgR+6S1CDDXZIaZLhLUoMMd0lq0OMXugCAffbZp1atWrXQZUjSknL99dd/v6pGplq3KMJ91apVjI2NLXQZkrSkJLlzunVOy0hSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMWxTtUpcVs1SlXzXnfTWcdM8RKpJnb6cg9yYVJtifZONB2WZIN3WNTkg1d+6okPxlYd9481i5JmsZMRu4XMfH1YR/e0VBVr9uxnORs4IcD299eVauHVJ8kaQ52Gu5VdV2SVVOtSxLgOOAlQ65LktRD3xdUXwRsq6rvDrQdlOSbSb6Y5EXT7ZhkXZKxJGPj4+M9y5AkDeob7muBSwaebwUOrKrDmPhW+Y8leepUO1bV+qoararRkZEpP45YkjRHcw73JI8Hfge4bEdbVT1UVT/olq8Hbgee0bdISdLs9Bm5vwy4pao272hIMpJkRbd8MHAIcEe/EiVJszWTWyEvAf4TeGaSzUne2q06np+dkgF4MXBjd2vkx4GTquqeIdYrSZqBmdwts3aa9hOnaLsCuKJ/WZKkPvz4AUlqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNWin4Z7kwiTbk2wcaDsjyZYkG7rH0QPrTk1yW5Jbk7xivgqXJE1vJiP3i4Cjpmg/p6pWd4+rAZI8Gzge+LVun39IsmJYxUqSZman4V5V1wH3zPB4a4BLq+qhqvoecBtweI/6JElz0GfO/eQkN3bTNnt1bfsBdw1ss7lr+zlJ1iUZSzI2Pj7eowxJ0mRzDfdzgacDq4GtwNmzPUBVra+q0aoaHRkZmWMZkqSpzCncq2pbVT1SVY8C5/P/Uy9bgAMGNt2/a5Mk7UJzCvckKweeHgvsuJPmSuD4JE9MchBwCPD1fiVKkmbr8TvbIMklwJHAPkk2A6cDRyZZDRSwCXg7QFXdlORy4DvAw8A7quqRealckjStnYZ7Va2dovmCx9j+TODMPkVJkvrxHaqS1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktSgnb5DVRq06pSr5rzvprOOGWIlkh6LI3dJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDdppuCe5MMn2JBsH2v46yS1Jbkzy
ySR7du2rkvwkyYbucd481i5JmsZMRu4XAUdNarsGOLSqngP8F3DqwLrbq2p19zhpOGVKkmZjp+FeVdcB90xq+1xVPdw9/Sqw/zzUJkmao2HMub8F+PeB5wcl+WaSLyZ50XQ7JVmXZCzJ2Pj4+BDKkCTt0Cvck/wp8DDw0a5pK3BgVR0GvBv4WJKnTrVvVa2vqtGqGh0ZGelThiRpkjmHe5ITgVcBr6+qAqiqh6rqB93y9cDtwDOGUKckaRbmFO5JjgLeC7ymqn480D6SZEW3fDBwCHDHMAqVJM3cTr+JKcklwJHAPkk2A6czcXfME4FrkgB8tbsz5sXA+5P8FHgUOKmq7pnywJKkebPTcK+qtVM0XzDNtlcAV/QtSpLUj+9QlaQGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSg2YU7kkuTLI9ycaBtr2TXJPku93Pvbr2JPnbJLcluTHJb8xX8ZKkqc105H4RcNSktlOAa6vqEODa7jnAK4FDusc64Nz+ZUqSZmNG4V5V1wH3TGpeA1zcLV8MvHag/cM14avAnklWDqFWSdIM9Zlz37eqtnbLdwP7dsv7AXcNbLe5a/sZSdYlGUsyNj4+3qMMSdJkQ3lBtaoKqFnus76qRqtqdGRkZBhlSJI6fcJ9247plu7n9q59C3DAwHb7d22SpF2kT7hfCZzQLZ8AfHqg/U3dXTPPB344MH0jSdoFHj+TjZJcAhwJ7JNkM3A6cBZweZK3AncCx3WbXw0cDdwG/Bh485BrliTtxIzCvarWTrPqpVNsW8A7+hQlSerHd6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1aEafCikNw6pTrprzvpvOOmaIlUjtM9y1JPT5xQD+ctDyY7hLi5T/01EfzrlLUoMMd0lqkOEuSQ0y3CWpQXN+QTXJM4HLBpoOBv4c2BN4GzDetZ9WVVfP9TySpNmbc7hX1a3AaoAkK4AtwCeBNwPnVNUHh1GgpNnz1lENa1rmpcDtVXXnkI4nSephWOF+PHDJwPOTk9yY5MIke021Q5J1ScaSjI2Pj0+1iSRpjnq/iSnJbsBrgFO7pnOBvwCq+3k28JbJ+1XVemA9wOjoaPWtQ3osfacppKVmGCP3VwI3VNU2gKraVlWPVNWjwPnA4UM4hyRpFoYR7msZmJJJsnJg3bHAxiGcQ5I0C72mZZLsDrwcePtA8weSrGZiWmbTpHWSpF2gV7hX1Y+AX5rU9sZeFUmSevMdqpLUID/yV5pH3qWjheLIXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSg3p/zV6STcADwCPAw1U1mmRv4DJgFbAJOK6q7u17LknSzAxr5P5bVbW6qka756cA11bVIcC13XNJ0i4yX9Mya4CLu+WLgdfO03kkSVMYRrgX8Lkk1ydZ17XtW1Vbu+W7gX0n75RkXZKxJGPj4+NDKEOStEPvOXfgiKrakuRpwDVJbhlcWVWVpCbvVFXrgfUAo6OjP7dekjR3vUfuVbWl+7kd+CRwOLAtyUqA7uf2vueRJM1cr3BPsnuSPXYsA78NbASuBE7oNjsB+HSf80iSZqfvtMy+wCeT7DjWx6rqM0m+AVye5K3AncBxPc8jSZqFXuFeVXcAvz5F+w+Al/Y5tiRp7nyHqiQ1yHCXpAYZ7pLUIMNdkho0jDcxaQ5WnXLVnPfddNYxQ6xEUoscuUtSgxy5L0F9Rv3gyF9aDhy5S1KDDHdJapDhLkkNcs59Geo7Zy9p8XPkLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDZpzuCc5IMkXknwnyU1J3tm1n5FkS5IN3ePo4ZUrSZqJPh8/8DDwnqq6IckewPVJrunWnVNVH+xfniRpLuYc7lW1FdjaLT+Q5GZgv2EVJkmau6F8cFiSVcBhwNeAFwInJ3kTMMbE6P7eKfZZB6wDOPDAA4dRhqQh8Wsgl77eL6gmeQpwBfCuqrofOBd4OrCaiZH92VPtV1Xrq2q0qkZHRkb6liFJGtAr3JM8gYlg/2hVfQKgqrZV1SNV9ShwPnB4/zIlSbPR526ZABcAN1fVhwbaVw5sdiywce7lSZLmos+c+wuBNwLfTrKhazsNWJtkNVDAJuDtPc4hSZqDPnfLfBnIFKuunns5kqRh8B2qktQgv0NV0qLhLZjD48hdkhrkyF3SUPUZfWt4HLlLUoMMd0lqkOEuSQ0y3CWpQb6g2oMvHElarBy5S1KDDHdJapDhLkkNMtwlqUGGuyQ1qIm7ZfywIUn6WY7cJalBTYzc+/BedUktWvbhLkl9LcapYadlJKlBjtwlLXstTs/OW7gnOQr4G2AF8E9VddZ8nUuSWgzoPuZlWibJCuDvgVcCzwbWJnn2fJxLkvTz5mvO/XDgtqq6o6r+B7gUWDNP55IkTTJf0zL7AXcNPN8MPG9wgyTrgHXd0weT3Dqweh/g+/NU22KyXPoJy6evy6WfsHz6Oq/9zF/12v1XpluxYC+oVtV6YP1U65KMVdXoLi5pl1su/YTl09fl0k9YPn1dqv2cr2mZLcABA8/379okSbvAfIX7N4BDkhyUZDfgeODKeTqXJGmSeZmWqaqHk5wMfJaJWyEvrKqbZnGIKadrGrRc+gnLp6/LpZ+wfPq6JPuZqlroGiRJQ+bHD0hSgwx3SWrQgod7kguTbE+ycaBt7yTXJPlu93OvhaxxGKbp5xlJtiTZ0D2OXsgahyHJAUm+kOQ7SW5K8s6uvcVrOl1fm7quSZ6U5OtJvtX1831d+0FJvpbktiSXdTdPLGmP0deLknxv4JquXuBSd2rB59yTvBh4EPhwVR3atX0AuKeqzkpyCrBXVf3JQtbZ1zT9PAN4sKo+uJC1DVOSlcDKqrohyR7A9cBrgRNp75pO19fjaOi6Jgmwe1U9mOQJwJeBdwLvBj5RVZcmOQ/4VlWdu5C19vUYfT0J+Leq+viCFjgLCz5yr6rrgHsmNa8BLu6WL2biH8ySNk0/m1NVW6vqhm75AeBmJt6x3OI1na6vTakJD3ZPn9A9CngJsCPsWrmm0/V1yVnwcJ/GvlW1tVu+G9h3IYuZZycnubGbtlnyUxWDkqwCDgO+RuPXdFJfobHrmmRFkg3AduAa4Hbgvqp6uNtkM438Ypvc16racU3P7K7pOUmeuHAVzsxiDff/UxPzRkvyN+cMnAs8HVgNbAXOXtBqhijJU4ArgHdV1f2D61q7plP0tbnrWlWPVNVqJt5tfjjwrIWtaP5M7muSQ4FTmejzc4G9gUU/pbhYw31bN5+5Y15z+wLXMy+qalv3F+lR4Hwm/tEsed1c5RXAR6vqE11zk9d0qr62el0Bquo
+4AvAC4A9k+x4I2RzHzEy0Nejuim4qqqHgH9mCVzTxRruVwIndMsnAJ9ewFrmzY6w6xwLbJxu26Wie0HqAuDmqvrQwKrmrul0fW3tuiYZSbJnt/xk4OVMvL7wBeD3us1auaZT9fWWgYFJmHhtYdFf08Vwt8wlwJFMfKzmNuB04FPA5cCBwJ3AcVW1pF+MnKafRzLxX/cCNgFvH5iXXpKSHAF8Cfg28GjXfBoTc9GtXdPp+rqWhq5rkucw8YLpCiYGhJdX1fuTHMzEdzXsDXwTeEM3sl2yHqOvnwdGgAAbgJMGXnhdlBY83CVJw7dYp2UkST0Y7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalB/wt5mHs8bR+hZwAAAABJRU5ErkJggg==\n" 170 | }, 171 | "metadata": { 172 | "needs_background": "light" 173 | } 174 | }, 175 | { 176 | "output_type": "stream", 177 | "name": "stdout", 178 | "text": [ 179 | "min/max : 11 37\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "def plot_histo_graphs(dataset, title):\n", 185 | " # histogram of graph sizes\n", 186 | " graph_sizes = []\n", 187 | " for graph in dataset:\n", 188 | " graph_sizes.append(graph[0].number_of_nodes())\n", 189 | " plt.figure(1)\n", 190 | " plt.hist(graph_sizes, bins=20)\n", 191 | " plt.title(title)\n", 192 | " plt.show()\n", 193 | " graph_sizes = torch.Tensor(graph_sizes)\n", 194 | " print('min/max :',graph_sizes.min().long().item(),graph_sizes.max().long().item())\n", 195 | " \n", 196 | "plot_histo_graphs(dataset.train,'trainset')\n", 197 | "plot_histo_graphs(dataset.val,'valset')\n", 198 | "plot_histo_graphs(dataset.test,'testset')\n" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 9, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "output_type": "stream", 208 | "name": "stdout", 209 | "text": [ 210 | "10000\n1000\n1000\n(Graph(num_nodes=29, num_edges=64,\n ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}\n edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}), tensor([0.8350]))\n(Graph(num_nodes=35, num_edges=78,\n ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}\n edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}), tensor([0.6299]))\n(Graph(num_nodes=16, num_edges=34,\n ndata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}\n edata_schemes={'feat': Scheme(shape=(), dtype=torch.int64)}), tensor([1.9973]))\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "print(len(dataset.train))\n", 216 | "print(len(dataset.val))\n", 217 | "print(len(dataset.test))\n", 218 | "\n", 219 | "print(dataset.train[0])\n", 220 | "print(dataset.val[0])\n", 221 | "print(dataset.test[0])\n" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "num_atom_type = 28\n", 231 | "num_bond_type = 4\n" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 11, 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "output_type": "stream", 241 | "name": "stderr", 242 | "text": [ 243 | "/data00/caishaofei/miniconda3/envs/gnas2/lib/python3.6/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n", 244 | " warnings.warn(\"pickle support for Storage will be removed in 1.5. 
Use `torch.save` instead\", FutureWarning)\n", 245 | "Time (sec): 6.077600479125977\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "start = time.time()\n", 251 | "with open('data/molecules/ZINC.pkl','wb') as f:\n", 252 | " pickle.dump([dataset.train,dataset.val,dataset.test,num_atom_type,num_bond_type],f)\n", 253 | "print('Time (sec):',time.time() - start)\n" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "# Test load function" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 12, 266 | "metadata": {}, 267 | "outputs": [ 268 | { 269 | "output_type": "error", 270 | "ename": "NameError", 271 | "evalue": "name 'LoadData' is not defined", 272 | "traceback": [ 273 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 274 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 275 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mDATASET_NAME\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'ZINC'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLoadData\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDATASET_NAME\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mtrainset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtestset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 276 | "\u001b[0;31mNameError\u001b[0m: name 'LoadData' is not defined" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "DATASET_NAME = 'ZINC'\n", 282 | "dataset = LoadData(DATASET_NAME)\n", 283 | "trainset, valset, testset = dataset.train, dataset.val, dataset.test\n" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 11, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "batch_size = 10\n", 301 | "collate = MoleculeDataset.collate\n", 302 | "print(MoleculeDataset)\n", 303 | "train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)\n" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [] 319 | } 320 | ], 321 | "metadata": { 322 | "kernelspec": { 323 | "name": "python3613jvsc74a57bd0720c3f9f2262024fbbfd811ec9823dbfd16426a3ecbb08b5c838007ee1202dee", 324 | "display_name": "Python 3.6.13 64-bit ('gnas2': conda)" 325 | }, 326 | "language_info": { 327 | "codemirror_mode": { 328 | "name": "ipython", 329 | "version": 3 330 | }, 331 | "file_extension": ".py", 332 | "mimetype": "text/x-python", 333 | "name": "python", 334 | "nbconvert_exporter": "python", 335 | "pygments_lexer": "ipython3", 336 | "version": "3.6.13" 337 | }, 338 | "metadata": { 339 | "interpreter": { 340 | "hash": "720c3f9f2262024fbbfd811ec9823dbfd16426a3ecbb08b5c838007ee1202dee" 
341 | } 342 | } 343 | }, 344 | "nbformat": 4, 345 | "nbformat_minor": 4 346 | } -------------------------------------------------------------------------------- /data/superpixels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.spatial.distance import cdist 4 | import numpy as np 5 | import itertools 6 | 7 | import dgl 8 | import torch 9 | import torch.utils.data 10 | 11 | import time 12 | 13 | import csv 14 | from sklearn.model_selection import StratifiedShuffleSplit 15 | 16 | from scipy import sparse as sp 17 | import numpy as np 18 | 19 | def sigma(dists, kth=8): 20 | # Compute sigma and reshape 21 | try: 22 | # Get k-nearest neighbors for each node 23 | knns = np.partition(dists, kth, axis=-1)[:, kth::-1] 24 | sigma = knns.sum(axis=1).reshape((knns.shape[0], 1))/kth 25 | except ValueError: # handling for graphs with num_nodes less than kth 26 | num_nodes = dists.shape[0] 27 | # this sigma value is irrelevant since not used for final compute_edge_list 28 | sigma = np.array([1]*num_nodes).reshape(num_nodes,1) 29 | 30 | return sigma + 1e-8 # adding epsilon to avoid zero value of sigma 31 | 32 | 33 | def compute_adjacency_matrix_images(coord, feat, use_feat=True, kth=8): 34 | coord = coord.reshape(-1, 2) 35 | # Compute coordinate distance 36 | c_dist = cdist(coord, coord) 37 | 38 | if use_feat: 39 | # Compute feature distance 40 | f_dist = cdist(feat, feat) 41 | # Compute adjacency 42 | A = np.exp(- (c_dist/sigma(c_dist))**2 - (f_dist/sigma(f_dist))**2 ) 43 | else: 44 | A = np.exp(- (c_dist/sigma(c_dist))**2) 45 | 46 | # Convert to symmetric matrix 47 | A = 0.5 * (A + A.T) 48 | A[np.diag_indices_from(A)] = 0 49 | return A 50 | 51 | 52 | def compute_edges_list(A, kth=8+1): 53 | # Get k-similar neighbor indices for each node 54 | 55 | num_nodes = A.shape[0] 56 | new_kth = num_nodes - kth 57 | 58 | if num_nodes > 9: 59 | knns = np.argpartition(A, new_kth-1, axis=-1)[:, new_kth:-1] 60 | knn_values = np.partition(A, new_kth-1, axis=-1)[:, new_kth:-1] # NEW 61 | else: 62 | # handling for graphs with less than kth nodes 63 | # in such cases, the resulting graph will be fully connected 64 | knns = np.tile(np.arange(num_nodes), num_nodes).reshape(num_nodes, num_nodes) 65 | knn_values = A # NEW 66 | 67 | # removing self loop 68 | if num_nodes != 1: 69 | knn_values = A[knns != np.arange(num_nodes)[:,None]].reshape(num_nodes,-1) # NEW 70 | knns = knns[knns != np.arange(num_nodes)[:,None]].reshape(num_nodes,-1) 71 | return knns, knn_values # NEW 72 | 73 | 74 | class SuperPixDGL(torch.utils.data.Dataset): 75 | def __init__(self, 76 | data_dir, 77 | dataset, 78 | split, 79 | use_mean_px=True, 80 | use_coord=True): 81 | 82 | self.split = split 83 | 84 | self.graph_lists = [] 85 | 86 | if dataset == 'MNIST': 87 | self.img_size = 28 88 | with open(os.path.join(data_dir, 'mnist_75sp_%s.pkl' % split), 'rb') as f: 89 | self.labels, self.sp_data = pickle.load(f) 90 | self.graph_labels = torch.LongTensor(self.labels) 91 | elif dataset == 'CIFAR10': 92 | self.img_size = 32 93 | with open(os.path.join(data_dir, 'cifar10_150sp_%s.pkl' % split), 'rb') as f: 94 | self.labels, self.sp_data = pickle.load(f) 95 | self.graph_labels = torch.LongTensor(self.labels) 96 | 97 | self.use_mean_px = use_mean_px 98 | self.use_coord = use_coord 99 | self.n_samples = len(self.labels) 100 | 101 | self._prepare() 102 | 103 | def _prepare(self): 104 | print("preparing %d graphs for the %s set..." 
% (self.n_samples, self.split.upper())) 105 | self.Adj_matrices, self.node_features, self.edges_lists, self.edge_features = [], [], [], [] 106 | for index, sample in enumerate(self.sp_data): 107 | mean_px, coord = sample[:2] 108 | 109 | try: 110 | coord = coord / self.img_size 111 | except AttributeError: 112 | VOC_has_variable_image_sizes = True 113 | 114 | if self.use_mean_px: 115 | A = compute_adjacency_matrix_images(coord, mean_px) # using super-pixel locations + features 116 | else: 117 | A = compute_adjacency_matrix_images(coord, mean_px, False) # using only super-pixel locations 118 | edges_list, edge_values_list = compute_edges_list(A) # NEW 119 | 120 | N_nodes = A.shape[0] 121 | 122 | mean_px = mean_px.reshape(N_nodes, -1) 123 | coord = coord.reshape(N_nodes, 2) 124 | x = np.concatenate((mean_px, coord), axis=1) 125 | 126 | edge_values_list = edge_values_list.reshape(-1) # NEW # TO DOUBLE-CHECK ! 127 | 128 | self.node_features.append(x) 129 | self.edge_features.append(edge_values_list) # NEW 130 | self.Adj_matrices.append(A) 131 | self.edges_lists.append(edges_list) 132 | 133 | for index in range(len(self.sp_data)): 134 | g = dgl.DGLGraph() 135 | g.add_nodes(self.node_features[index].shape[0]) 136 | g.ndata['feat'] = torch.Tensor(self.node_features[index]).half() 137 | 138 | for src, dsts in enumerate(self.edges_lists[index]): 139 | # handling for 1 node where the self loop would be the only edge 140 | # since, VOC Superpixels has few samples (5 samples) with only 1 node 141 | if self.node_features[index].shape[0] == 1: 142 | g.add_edges(src, dsts) 143 | else: 144 | g.add_edges(src, dsts[dsts!=src]) 145 | 146 | # adding edge features for Residual Gated ConvNet 147 | edge_feat_dim = g.ndata['feat'].shape[1] # dim same as node feature dim 148 | #g.edata['feat'] = torch.ones(g.number_of_edges(), edge_feat_dim).half() 149 | g.edata['feat'] = torch.Tensor(self.edge_features[index]).unsqueeze(1).half() # NEW 150 | 151 | self.graph_lists.append(g) 152 | 153 | def __len__(self): 154 | """Return the number of graphs in the dataset.""" 155 | return self.n_samples 156 | 157 | def __getitem__(self, idx): 158 | """ 159 | Get the idx^th sample. 160 | Parameters 161 | --------- 162 | idx : int 163 | The sample index. 164 | Returns 165 | ------- 166 | (dgl.DGLGraph, int) 167 | DGLGraph with node feature stored in `feat` field 168 | And its label. 169 | """ 170 | return self.graph_lists[idx], self.graph_labels[idx] 171 | 172 | 173 | class DGLFormDataset(torch.utils.data.Dataset): 174 | """ 175 | DGLFormDataset wrapping graph list and label list as per pytorch Dataset. 176 | *lists (list): lists of 'graphs' and 'labels' with same len(). 177 | """ 178 | def __init__(self, *lists): 179 | assert all(len(lists[0]) == len(li) for li in lists) 180 | self.lists = lists 181 | self.graph_lists = lists[0] 182 | self.graph_labels = lists[1] 183 | 184 | def __getitem__(self, index): 185 | return tuple(li[index] for li in self.lists) 186 | 187 | def __len__(self): 188 | return len(self.lists[0]) 189 | 190 | 191 | class SuperPixDatasetDGL(torch.utils.data.Dataset): 192 | def __init__(self, name, num_val=5000): 193 | """ 194 | Takes input standard image dataset name (MNIST/CIFAR10) 195 | and returns the superpixels graph. 196 | 197 | This class uses results from the above SuperPix class. 
198 |         It contains the steps for the generation of the Superpixels
199 |         graph from a superpixel .pkl file provided by
200 |         https://github.com/bknyaz/graph_attention_pool
201 | 
202 |         Please refer to the SuperPixDGL class for details.
203 |         """
204 |         t_data = time.time()
205 |         self.name = name
206 |         
207 |         use_mean_px = True # using super-pixel locations + features
208 |         use_mean_px = False # using only super-pixel locations (overrides the line above)
209 |         if use_mean_px:
210 |             print('Adj matrix defined from super-pixel locations + features')
211 |         else:
212 |             print('Adj matrix defined from super-pixel locations (only)')
213 |         use_coord = True
214 |         self.test = SuperPixDGL("./data/superpixels", dataset=self.name, split='test',
215 |                                 use_mean_px=use_mean_px,
216 |                                 use_coord=use_coord)
217 | 
218 |         self.train_ = SuperPixDGL("./data/superpixels", dataset=self.name, split='train',
219 |                                   use_mean_px=use_mean_px,
220 |                                   use_coord=use_coord)
221 | 
222 |         _val_graphs, _val_labels = self.train_[:num_val]
223 |         _train_graphs, _train_labels = self.train_[num_val:]
224 | 
225 |         self.val = DGLFormDataset(_val_graphs, _val_labels)
226 |         self.train = DGLFormDataset(_train_graphs, _train_labels)
227 | 
228 |         print("[I] Data load time: {:.4f}s".format(time.time()-t_data))
229 | 
230 | 
231 | 
232 | def self_loop(g):
233 |     """
234 |     Utility function, used only when the user sets the self_loop flag:
235 |     overwrites dgl.transform.add_self_loop() so that ndata['feat'] and edata['feat'] are not lost.
236 | 
237 | 
238 |     This function is called inside a function in the SuperPixDataset class.
239 |     """
240 |     new_g = dgl.DGLGraph()
241 |     new_g.add_nodes(g.number_of_nodes())
242 |     new_g.ndata['feat'] = g.ndata['feat']
243 | 
244 |     src, dst = g.all_edges(order="eid")
245 |     src = dgl.backend.zerocopy_to_numpy(src)
246 |     dst = dgl.backend.zerocopy_to_numpy(dst)
247 |     non_self_edges_idx = src != dst
248 |     nodes = np.arange(g.number_of_nodes())
249 |     new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
250 |     new_g.add_edges(nodes, nodes)
251 | 
252 |     # This new edata is not used since this function gets called only for GCN, GAT
253 |     # However, we need this for the generic requirement of ndata and edata
254 |     new_g.edata['feat'] = torch.zeros(new_g.number_of_edges())
255 |     return new_g
256 | 
257 | def positional_encoding(g, pos_enc_dim):
258 |     """
259 |     Graph positional encoding via Laplacian eigenvectors
260 |     """
261 | 
262 |     # Laplacian
263 |     A = g.adjacency_matrix_scipy(return_edge_ids=False).astype(float)
264 |     N = sp.diags(dgl.backend.asnumpy(g.in_degrees()).clip(1) ** -0.5, dtype=float)
265 |     L = sp.eye(g.number_of_nodes()) - N * A * N
266 | 
267 |     # Eigenvectors with numpy
268 |     EigVal, EigVec = np.linalg.eig(L.toarray())
269 |     idx = EigVal.argsort() # increasing order
270 |     EigVal, EigVec = EigVal[idx], np.real(EigVec[:,idx])
271 |     g.ndata['pos_enc'] = torch.from_numpy(EigVec[:,1:pos_enc_dim+1]).float()
272 | 
273 |     # # Eigenvectors with scipy
274 |     # EigVal, EigVec = sp.linalg.eigs(L, k=pos_enc_dim+1, which='SR')
275 |     # EigVec = EigVec[:, EigVal.argsort()] # increasing order
276 |     # g.ndata['pos_enc'] = torch.from_numpy(np.abs(EigVec[:,1:pos_enc_dim+1])).float()
277 | 
278 |     return g
279 | 
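# Usage sketch for positional_encoding (a minimal illustration; assumes a tiny graph
# `g` built with the pre-0.5 DGLGraph API used throughout this file):
#
#     g = dgl.DGLGraph()
#     g.add_nodes(3)
#     g.add_edges([0, 1, 1, 2], [1, 0, 2, 1])   # undirected path 0-1-2
#     g = positional_encoding(g, pos_enc_dim=2)
#     g.ndata['pos_enc'].shape                  # torch.Size([3, 2])
#
# Each column of 'pos_enc' is an eigenvector of the symmetric normalized Laplacian
# L = I - D^(-1/2) A D^(-1/2); the trivial eigenvector (eigenvalue 0) is skipped by
# the slice EigVec[:, 1:pos_enc_dim+1] above.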
280 | class SuperPixDataset(torch.utils.data.Dataset):
281 | 
282 |     def __init__(self, name):
283 |         """
284 |             Loading Superpixels datasets
285 |         """
286 |         start = time.time()
287 |         print("[I] Loading dataset %s..." % (name))
288 |         self.name = name
289 |         data_dir = 'data/superpixels/'
290 |         with open(data_dir+name+'.pkl',"rb") as f:
291 |             f = pickle.load(f)
292 |             self.train = f[0]
293 |             self.val = f[1]
294 |             self.test = f[2]
295 |         print('train, test, val sizes :',len(self.train),len(self.test),len(self.val))
296 |         print("[I] Finished loading.")
297 |         print("[I] Data load time: {:.4f}s".format(time.time()-start))
298 | 
299 | 
300 |     # form a mini batch from a given list of samples = [(graph, label) pairs]
301 |     def collate(self, samples):
302 |         # The input samples is a list of pairs (graph, label).
303 |         graphs, labels = map(list, zip(*samples))
304 |         labels = torch.tensor(np.array(labels))
305 |         #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))]
306 |         #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ]
307 |         #snorm_n = torch.cat(tab_snorm_n).sqrt()
308 |         #tab_sizes_e = [ graphs[i].number_of_edges() for i in range(len(graphs))]
309 |         #tab_snorm_e = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_e ]
310 |         #snorm_e = torch.cat(tab_snorm_e).sqrt()
311 |         for idx, graph in enumerate(graphs):
312 |             graphs[idx].ndata['feat'] = graph.ndata['feat'].float()
313 |             graphs[idx].edata['feat'] = graph.edata['feat'].float()
314 |         batched_graph = dgl.batch(graphs)
315 | 
316 |         return batched_graph, labels
317 | 
318 | 
319 |     # prepare dense tensors for GNNs that use them, such as RingGNN and 3WLGNN
320 |     def collate_dense_gnn(self, samples):
321 |         # The input samples is a list of pairs (graph, label).
322 |         graphs, labels = map(list, zip(*samples))
323 |         labels = torch.tensor(np.array(labels))
324 |         #tab_sizes_n = [ graphs[i].number_of_nodes() for i in range(len(graphs))]
325 |         #tab_snorm_n = [ torch.FloatTensor(size,1).fill_(1./float(size)) for size in tab_sizes_n ]
326 |         #snorm_n = tab_snorm_n[0][0].sqrt()
327 | 
328 |         #batched_graph = dgl.batch(graphs)
329 | 
330 |         g = graphs[0]
331 |         adj = self._sym_normalize_adj(g.adjacency_matrix().to_dense())
332 |         """
333 |             Adapted from https://github.com/leichen2018/Ring-GNN/
334 |             Assigning node and edge feats::
335 |             we have the adjacency matrix in R^{n x n}, the node features in R^{d_n} and edge features in R^{d_e}.
336 |             Then we build a zero-initialized tensor, say T, in R^{(1 + d_n + d_e) x n x n}. T[0, :, :] is the adjacency matrix.
337 |             The diagonal T[1:1+d_n, i, i], i = 0 to n-1, stores the node feature of node i.
338 |             The off-diagonal T[1+d_n:, i, j] stores the edge features of edge (i, j).
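            A worked instance (hypothetical sizes): with n = 3 nodes, d_n = 5 node-feature
            dims and d_e = 1 edge-feature dim, T has shape (7, 3, 3); T[0] is the 3x3
            adjacency, T[1:6, i, i] is node i's feature vector, and T[6, i, j] is the
            scalar feature of edge (i, j). Note that the code below packs only the
            adjacency and node features, i.e. a (1 + d_n) x n x n tensor.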
339 | """ 340 | 341 | zero_adj = torch.zeros_like(adj) 342 | 343 | in_dim = g.ndata['feat'].shape[1] 344 | 345 | # use node feats to prepare adj 346 | adj_node_feat = torch.stack([zero_adj for j in range(in_dim)]) 347 | adj_node_feat = torch.cat([adj.unsqueeze(0), adj_node_feat], dim=0) 348 | 349 | for node, node_feat in enumerate(g.ndata['feat']): 350 | adj_node_feat[1:, node, node] = node_feat 351 | 352 | x_node_feat = adj_node_feat.unsqueeze(0) 353 | 354 | return x_node_feat, labels 355 | 356 | def _sym_normalize_adj(self, adj): 357 | deg = torch.sum(adj, dim = 0)#.squeeze() 358 | deg_inv = torch.where(deg>0, 1./torch.sqrt(deg), torch.zeros(deg.size())) 359 | deg_inv = torch.diag(deg_inv) 360 | return torch.mm(deg_inv, torch.mm(adj, deg_inv)) 361 | 362 | 363 | def _add_self_loops(self): 364 | 365 | # function for adding self loops 366 | # this function will be called only if self_loop flag is True 367 | 368 | self.train.graph_lists = [self_loop(g) for g in self.train.graph_lists] 369 | self.val.graph_lists = [self_loop(g) for g in self.val.graph_lists] 370 | self.test.graph_lists = [self_loop(g) for g in self.test.graph_lists] 371 | 372 | self.train = DGLFormDataset(self.train.graph_lists, self.train.graph_labels) 373 | self.val = DGLFormDataset(self.val.graph_lists, self.val.graph_labels) 374 | self.test = DGLFormDataset(self.test.graph_lists, self.test.graph_labels) 375 | 376 | 377 | def _add_positional_encodings(self, pos_enc_dim): 378 | 379 | # Graph positional encoding v/ Laplacian eigenvectors 380 | self.train.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.train.graph_lists] 381 | self.val.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.val.graph_lists] 382 | self.test.graph_lists = [positional_encoding(g, pos_enc_dim) for g in self.test.graph_lists] 383 | 384 | if __name__ == '__main__': 385 | ds = SuperPixDataset('MNIST') 386 | print(ds.train.__getitem__(0)) 387 | ds = SuperPixDataset('CIFAR10') 388 | print(ds.train.__getitem__(0)) 389 | 390 | -------------------------------------------------------------------------------- /environment_gpu.yml: -------------------------------------------------------------------------------- 1 | name: gnasmp 2 | channels: 3 | - defaults 4 | - dglteam 5 | dependencies: 6 | - python=3.6 7 | - cudatoolkit=11.0 8 | - dgl-cuda11.0 9 | - pygraphviz 10 | - pip: 11 | - ipdb 12 | - tqdm 13 | - torch==1.5.0 14 | - scikit-learn 15 | - tensorboardX 16 | - rich 17 | -------------------------------------------------------------------------------- /example_geno.yaml: -------------------------------------------------------------------------------- 1 | Genotype: 2 | - 3 | id: 1 4 | topology: 5 | - 6 | src: 0 7 | dst: 1 8 | ops: 'V_Max' 9 | - 10 | src: 1 11 | dst: 2 12 | ops: 'V_Sum' 13 | - 14 | src: 2 15 | dst: 3 16 | ops: 'V_Sum' 17 | - 18 | src: 3 19 | dst: 4 20 | ops: 'V_Max' 21 | - 22 | id: 2 23 | topology: 24 | - 25 | src: 0 26 | dst: 1 27 | ops: 'V_Max' 28 | - 29 | src: 1 30 | dst: 2 31 | ops: 'V_Sum' 32 | - 33 | src: 2 34 | dst: 3 35 | ops: 'V_Sum' 36 | - 37 | src: 3 38 | dst: 4 39 | ops: 'V_Max' 40 | - 41 | id: 3 42 | topology: 43 | - 44 | src: 0 45 | dst: 1 46 | ops: 'V_Max' 47 | - 48 | src: 1 49 | dst: 2 50 | ops: 'V_Sum' 51 | - 52 | src: 2 53 | dst: 3 54 | ops: 'V_Sum' 55 | - 56 | src: 3 57 | dst: 4 58 | ops: 'V_Max' 59 | - 60 | id: 4 61 | topology: 62 | - 63 | src: 0 64 | dst: 1 65 | ops: 'V_Max' 66 | - 67 | src: 1 68 | dst: 2 69 | ops: 'V_Sum' 70 | - 71 | src: 2 72 | dst: 3 73 | ops: 'V_Sum' 74 | - 75 | src: 3 76 | 
dst: 4 77 | ops: 'V_Max' -------------------------------------------------------------------------------- /models/architect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | def _concat(xs): 8 | return torch.cat([x.view(-1) for x in xs]) 9 | 10 | 11 | class Architect(object): 12 | 13 | def __init__(self, model, args): 14 | self.args = args 15 | self.network_momentum = args.momentum 16 | self.network_weight_decay = args.weight_decay 17 | self.model = model 18 | self.optimizer = torch.optim.Adam( 19 | params = self.model.arch_parameters(), 20 | lr = args.arch_lr, 21 | betas = (0.5, 0.999), 22 | weight_decay = args.arch_weight_decay 23 | ) 24 | 25 | def _compute_unrolled_model(self, input, target, eta, network_optimizer): 26 | loss = self.model._loss(input, target) 27 | theta = _concat(self.model.parameters()).data 28 | try: 29 | moment = _concat(network_optimizer.state[v]['momentum_buffer'] 30 | for v in self.model.parameters()).mul_(self.network_momentum) 31 | except: 32 | moment = torch.zeros_like(theta) 33 | dtheta = _concat(torch.autograd.grad( 34 | outputs = loss, 35 | inputs = self.model.parameters()) 36 | ).data + self.network_weight_decay*theta 37 | unrolled_model = self._construct_model_from_theta(theta.sub(eta, moment+dtheta)) 38 | return unrolled_model 39 | 40 | def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled): 41 | self.optimizer.zero_grad() 42 | if unrolled: 43 | self._backward_step_unrolled( 44 | input_train, target_train, input_valid, target_valid, eta, network_optimizer) 45 | else: 46 | if self.args.search_mode == 'darts_1': 47 | self._backward_step(input_valid, target_valid) 48 | elif self.args.search_mode == 'train': 49 | self._backward_step(input_train, target_train) 50 | self.optimizer.step() 51 | 52 | def _backward_step(self, input_valid, target_valid): 53 | loss = self.model._loss(input_valid, target_valid) 54 | loss.backward() 55 | 56 | def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer): 57 | unrolled_model = self._compute_unrolled_model( 58 | input_train, target_train, eta, network_optimizer) 59 | unrolled_loss = unrolled_model._loss(input_valid, target_valid) 60 | 61 | unrolled_loss.backward() 62 | dalpha = [v.grad for v in unrolled_model.arch_parameters()] 63 | vector = [v.grad.data for v in unrolled_model.parameters()] 64 | implicit_grads = self._hessian_vector_product(vector, input_train, target_train) 65 | 66 | for g, ig in zip(dalpha, implicit_grads): 67 | g.data.sub_(eta, ig.data) 68 | 69 | for v, g in zip(self.model.arch_parameters(), dalpha): 70 | if v.grad is None: 71 | v.grad = Variable(g.data) 72 | else: 73 | v.grad.data.copy_(g.data) 74 | 75 | def _construct_model_from_theta(self, theta): 76 | model_new = self.model.new() 77 | model_dict = self.model.state_dict() 78 | 79 | params, offset = {}, 0 80 | for k, v in self.model.named_parameters(): 81 | v_length = np.prod(v.size()) 82 | params[k] = theta[offset: offset+v_length].view(v.size()) 83 | offset += v_length 84 | 85 | assert offset == len(theta) 86 | model_dict.update(params) 87 | model_new.load_state_dict(model_dict) 88 | if not model_new.args.disable_cuda: 89 | model_new.cuda = model_new.cuda() 90 | return model_new 91 | 92 | def _hessian_vector_product(self, vector, input, target, r=1e-2): 93 | R = r / _concat(vector).norm() 94 | for p, 
v in zip(self.model.parameters(), vector): 95 | p.data.add_(R, v) 96 | loss = self.model._loss(input, target) 97 | grads_p = torch.autograd.grad(loss, self.model.arch_parameters()) 98 | 99 | for p, v in zip(self.model.parameters(), vector): 100 | p.data.sub_(2*R, v) 101 | loss = self.model._loss(input, target) 102 | grads_n = torch.autograd.grad(loss, self.model.arch_parameters()) 103 | 104 | for p, v in zip(self.model.parameters(), vector): 105 | p.data.add_(R, v) 106 | 107 | return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] 108 | -------------------------------------------------------------------------------- /models/cell_search.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from models.operations import OPS 7 | from models.mixed import Mixed 8 | 9 | ''' 10 | cell_arch : 11 | topology: list 12 | (src, dst, weights, ops) 13 | ''' 14 | 15 | class Cell(nn.Module): 16 | 17 | def __init__(self, args, cell_arch): 18 | super().__init__() 19 | self.args = args 20 | self.nb_nodes = args.nb_nodes*3 #! warning 21 | self.cell_arch = cell_arch 22 | self.trans_concat_V = nn.Linear(self.nb_nodes*args.node_dim, args.node_dim, bias = True) 23 | self.batchnorm_V = nn.BatchNorm1d(args.node_dim) 24 | self.activate = nn.LeakyReLU(args.leaky_slope) 25 | self.load_arch() 26 | 27 | 28 | def load_arch(self): 29 | link_para = {} 30 | link_dict = {} 31 | for src, dst, w, ops in self.cell_arch: 32 | if dst not in link_dict: 33 | link_dict[dst] = [] 34 | link_dict[dst].append((src, w)) 35 | link_para[str((src, dst))] = Mixed(self.args, ops) 36 | 37 | self.link_dict = link_dict 38 | self.link_para = nn.ModuleDict(link_para) 39 | 40 | 41 | def forward(self, input, weight): 42 | G, V_in = input['G'], input['V'] 43 | link_para = self.link_para 44 | link_dict = self.link_dict 45 | states = [V_in] 46 | for dst in range(1, self.nb_nodes + 1): 47 | tmp_states = [] 48 | for src, w in link_dict[dst]: 49 | sub_input = {'G': G, 'V': states[src], 'V_in': V_in} 50 | tmp_states.append(link_para[str((src, dst))](sub_input, weight[w])) 51 | states.append(sum(tmp_states)) 52 | 53 | V = self.trans_concat_V(torch.cat(states[1:], dim = 1)) 54 | 55 | if self.batchnorm_V: 56 | V = self.batchnorm_V(V) 57 | 58 | V = self.activate(V) 59 | V = F.dropout(V, self.args.dropout, training = self.training) 60 | V = V + V_in 61 | return {'G': G, 'V': V} 62 | 63 | -------------------------------------------------------------------------------- /models/cell_train.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from models.operations import V_Package, OPS 7 | 8 | 9 | class Cell(nn.Module): 10 | 11 | def __init__(self, args, genotype): 12 | 13 | super().__init__() 14 | self.args = args 15 | self.nb_nodes = args.nb_nodes 16 | self.genotype = genotype 17 | self.trans_concat_V = nn.Linear(args.nb_nodes * args.node_dim, args.node_dim, bias = True) 18 | self.batchnorm_V = nn.BatchNorm1d(args.node_dim) 19 | self.activate = nn.LeakyReLU(args.leaky_slope) 20 | self.load_genotype() 21 | 22 | 23 | def load_genotype(self): 24 | geno = self.genotype 25 | link_dict = {} 26 | module_dict = {} 27 | for edge in geno['topology']: 28 | src, dst, ops = edge['src'], edge['dst'], edge['ops'] 29 | dst = f'{dst}' 30 | 31 | 
if dst not in link_dict: 32 | link_dict[dst] = [] 33 | link_dict[dst].append(src) 34 | 35 | if dst not in module_dict: 36 | module_dict[dst] = nn.ModuleList([]) 37 | module_dict[dst].append(V_Package(self.args, OPS[ops](self.args))) 38 | 39 | self.link_dict = link_dict 40 | self.module_dict = nn.ModuleDict(module_dict) 41 | 42 | 43 | def forward(self, input): 44 | 45 | G, V_in = input['G'], input['V'] 46 | states = [V_in] 47 | for dst in range(1, self.nb_nodes + 1): 48 | dst = f'{dst}' 49 | agg = [] 50 | for i, src in enumerate(self.link_dict[dst]): 51 | sub_input = {'G': G, 'V': states[src], 'V_in': V_in} 52 | agg.append(self.module_dict[dst][i](sub_input)) 53 | states.append(sum(agg)) 54 | 55 | V = self.trans_concat_V(torch.cat(states[1:], dim = 1)) 56 | 57 | if self.batchnorm_V: 58 | V = self.batchnorm_V(V) 59 | 60 | V = self.activate(V) 61 | V = F.dropout(V, self.args.dropout, training = self.training) 62 | V = V + V_in 63 | return { 'G' : G, 'V' : V } 64 | 65 | 66 | if __name__ == '__main__': 67 | import yaml 68 | from easydict import EasyDict as edict 69 | geno = yaml.load(open('example_geno.yaml', 'r')) 70 | geno = geno['Genotype'][0] 71 | args = edict({ 72 | 'nb_nodes': 4, 73 | 'node_dim': 50, 74 | 'leaky_slope': 0.2, 75 | 'batchnorm_op': True, 76 | }) 77 | cell = Cell(args, geno) 78 | -------------------------------------------------------------------------------- /models/mixed.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from models.operations import V_Package, OPS 7 | 8 | 9 | class Mixed(nn.Module): 10 | 11 | def __init__(self, args, ops): 12 | super().__init__() 13 | self.args = args 14 | self.type = type 15 | self.ops = ops 16 | self.candidates = nn.ModuleDict({ 17 | name: V_Package(args, OPS[name](args)) 18 | for name in self.ops 19 | }) 20 | 21 | 22 | def forward(self, input, weight): 23 | ''' 24 | weight: a dict whose 'key' is operation name and 'val' is operation weight 25 | ''' 26 | weight = weight.softmax(0) 27 | output = sum( weight[i] * self.candidates[name](input) for i, name in enumerate(self.ops) ) 28 | # residual = input[1] if self.type == 'V' else input[2] 29 | return output # + residual * DecayScheduler().decay_rate 30 | 31 | -------------------------------------------------------------------------------- /models/model_search.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from models.cell_search import Cell 7 | from models.operations import OPS, First_Stage, Second_Stage, Third_Stage 8 | from models.networks import MLP 9 | from data import TransInput, TransOutput, get_trans_input 10 | 11 | 12 | class Model_Search(nn.Module): 13 | 14 | def __init__(self, args, trans_input_fn, loss_fn): 15 | super().__init__() 16 | self.args = args 17 | self.nb_layers = args.nb_layers 18 | self.cell_arch_topo = self.load_cell_arch() # obtain architecture topology 19 | self.cell_arch_para = self.init_cell_arch_para() # register architecture topology parameters 20 | self.cells = nn.ModuleList([Cell(args, self.cell_arch_topo[i]) for i in range(self.nb_layers)]) 21 | self.loss_fn = loss_fn 22 | self.trans_input_fn = trans_input_fn 23 | self.trans_input = TransInput(trans_input_fn) 24 | self.trans_output = TransOutput(args) 25 | if 
args.pos_encode > 0: 26 | self.position_encoding = nn.Linear(args.pos_encode, args.node_dim) 27 | 28 | 29 | def forward(self, input): 30 | arch_para_dict = self.group_arch_parameters() 31 | input = self.trans_input(input) 32 | G, V = input['G'], input['V'] 33 | if self.args.pos_encode > 0: 34 | V = V + self.position_encoding(G.ndata['pos_enc'].float().cuda()) 35 | output = {'G': G, 'V': V} 36 | for i, cell in enumerate(self.cells): 37 | output = cell(output, arch_para_dict[i]) 38 | output = self.trans_output(output) 39 | return output 40 | 41 | 42 | def load_cell_arch(self): 43 | cell_arch_topo = [] 44 | for _ in range(self.nb_layers): 45 | arch_topo = self.load_cell_arch_by_layer() 46 | cell_arch_topo.append(arch_topo) 47 | return cell_arch_topo 48 | 49 | 50 | def load_cell_arch_by_layer(self): 51 | arch_topo = [] 52 | w = 0 53 | for dst in range(1, self.args.nb_nodes+1): 54 | for src in range(dst): 55 | arch_topo.append((src, dst, w, First_Stage)) 56 | w += 1 57 | for dst in range(self.args.nb_nodes+1, 2*self.args.nb_nodes+1): 58 | src = dst - self.args.nb_nodes 59 | arch_topo.append((src, dst, w, Second_Stage)) 60 | w += 1 61 | for dst in range(2*self.args.nb_nodes+1, 3*self.args.nb_nodes+1): 62 | for src in range(self.args.nb_nodes+1, 2*self.args.nb_nodes+1): 63 | arch_topo.append((src, dst, w, Third_Stage)) 64 | w += 1 65 | return arch_topo 66 | 67 | 68 | def init_cell_arch_para(self): 69 | cell_arch_para = [] 70 | for i_layer in range(self.nb_layers): 71 | arch_para = self.init_arch_para(self.cell_arch_topo[i_layer]) 72 | cell_arch_para.extend(arch_para) 73 | self.nb_cell_topo = len(arch_para) 74 | return cell_arch_para 75 | 76 | 77 | def init_arch_para(self, arch_topo): 78 | arch_para = [] 79 | for src, dst, w, ops in arch_topo: 80 | arch_para.append(Variable(1e-3 * torch.rand(len(ops)).cuda(), requires_grad = True)) 81 | return arch_para 82 | 83 | 84 | def group_arch_parameters(self): 85 | group = [] 86 | start = 0 87 | for _ in range(self.nb_layers): 88 | group.append(self.arch_parameters()[start: start + self.nb_cell_topo]) 89 | start += self.nb_cell_topo 90 | return group 91 | 92 | 93 | # def load_cell_arch_by_layer(self): 94 | # arch_type_dict = [] 95 | # w = 0 96 | # for dst in range(1, self.args.nb_nodes + 1): 97 | # for src in range(dst): 98 | # arch_type_dict.append((src, dst, w)) 99 | # w += 1 100 | # return arch_type_dict 101 | 102 | 103 | # def init_cell_arch_para(self): 104 | # cell_arch_para = [] 105 | # for i_layer in range(self.nb_layers): 106 | # cell_arch_para.append(self.init_arch_para(self.cell_arch_topo[i_layer])) 107 | # return cell_arch_para 108 | 109 | 110 | # def init_arch_para(self, arch_topo): 111 | # arch_para = Variable(1e-3 * torch.rand(len(arch_topo), len(OPS)).cuda(), requires_grad = True) 112 | # return arch_para 113 | 114 | 115 | def new(self): 116 | model_new = Model_Search(self.args, get_trans_input(self.args), self.loss_fn).cuda() 117 | for x, y in zip(model_new.arch_parameters(), self.arch_parameters()): 118 | x.data.copy_(y.data) 119 | return model_new 120 | 121 | def load_alpha(self, alphas): 122 | for x, y in zip(self.arch_parameters(), alphas): 123 | x.data.copy_(y.data) 124 | 125 | def arch_parameters(self): 126 | return self.cell_arch_para 127 | 128 | def _loss(self, input, targets): 129 | scores = self.forward(input) 130 | return self.loss_fn(scores, targets) -------------------------------------------------------------------------------- /models/model_train.py: 
-------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from models.cell_train import Cell 7 | from models.operations import OPS 8 | from models.networks import MLP 9 | from data import TransInput, TransOutput 10 | 11 | 12 | class Model_Train(nn.Module): 13 | 14 | def __init__(self, args, genotypes, trans_input_fn, loss_fn): 15 | super().__init__() 16 | self.args = args 17 | self.nb_layers = args.nb_layers 18 | self.genotypes = genotypes 19 | self.cells = nn.ModuleList([Cell(args, genotypes[i]) for i in range(self.nb_layers)]) 20 | self.loss_fn = loss_fn 21 | self.trans_input = TransInput(trans_input_fn) 22 | self.trans_output = TransOutput(args) 23 | if args.pos_encode > 0: 24 | self.position_encoding = nn.Linear(args.pos_encode, args.node_dim) 25 | 26 | 27 | def forward(self, input): 28 | input = self.trans_input(input) 29 | G, V = input['G'], input['V'] 30 | if self.args.pos_encode > 0: 31 | V = V + self.position_encoding(G.ndata['pos_enc'].float().to("cuda")) 32 | output = {'G': G, 'V': V} 33 | for cell in self.cells: 34 | output = cell(output) 35 | output = self.trans_output(output) 36 | return output 37 | 38 | 39 | def _loss(self, input, targets): 40 | scores = self.forward(input) 41 | return self.loss_fn(scores, targets) -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MLP(nn.Module): 6 | def __init__(self, channel_sequence): 7 | super().__init__() 8 | nb_layers = len(channel_sequence) - 1 9 | self.seq = nn.Sequential() 10 | for i in range(nb_layers): 11 | self.seq.add_module(f"fc{i}", nn.Linear(channel_sequence[i], channel_sequence[i + 1])) 12 | if i != nb_layers - 1: 13 | self.seq.add_module(f"ReLU{i}", nn.ReLU(inplace=True)) 14 | 15 | def forward(self, x): 16 | out = self.seq(x) 17 | return out 18 | 19 | 20 | -------------------------------------------------------------------------------- /models/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dgl.function as fn 3 | import torch.nn as nn 4 | import numpy as np 5 | # from models.networks import * 6 | 7 | 8 | OPS = { 9 | 'V_None' : lambda args: V_None(args), 10 | 'V_I' : lambda args: V_I(args), 11 | 'V_Max' : lambda args: V_Max(args), 12 | 'V_Mean' : lambda args: V_Mean(args), 13 | 'V_Min' : lambda args: V_Min(args), 14 | 'V_Sum' : lambda args: V_Sum(args), 15 | 'V_Sparse': lambda args: V_Sparse(args), 16 | 'V_Dense' : lambda args: V_Dense(args), 17 | } 18 | 19 | 20 | First_Stage = ['V_None', 'V_I', 'V_Sparse', 'V_Dense'] 21 | Second_Stage = ['V_I', 'V_Mean', 'V_Sum', 'V_Max'] 22 | Third_Stage = ['V_None', 'V_I', 'V_Sparse', 'V_Dense'] 23 | 24 | 25 | class V_Package(nn.Module): 26 | 27 | def __init__(self, args, operation): 28 | 29 | super().__init__() 30 | self.args = args 31 | self.operation = operation 32 | if type(operation) in [V_None, V_I]: 33 | self.seq = None 34 | else: 35 | self.seq = nn.Sequential() 36 | self.seq.add_module('fc_bn', nn.Linear(args.node_dim, args.node_dim, bias = True)) 37 | if args.batchnorm_op: 38 | self.seq.add_module('bn', nn.BatchNorm1d(self.args.node_dim)) 39 | self.seq.add_module('act', nn.ReLU()) 40 | 41 | 42 | def forward(self, input): 43 | V = self.operation(input) 44 | 
if self.seq: 45 | V = self.seq(V) 46 | return V 47 | 48 | 49 | class NodePooling(nn.Module): 50 | 51 | def __init__(self, args): 52 | super().__init__() 53 | self.A = nn.Linear(args.node_dim, args.node_dim) 54 | # self.B = nn.Linear(args.node_dim, args.node_dim) 55 | self.activate = nn.ReLU() 56 | 57 | def forward(self, V): 58 | V = self.A(V) 59 | V = self.activate(V) 60 | # V = self.B(V) 61 | return V 62 | 63 | 64 | class V_None(nn.Module): 65 | 66 | def __init__(self, args): 67 | super().__init__() 68 | 69 | def forward(self, input): 70 | V = input['V'] 71 | return 0. * V 72 | 73 | 74 | class V_I(nn.Module): 75 | 76 | def __init__(self, args): 77 | super().__init__() 78 | 79 | def forward(self, input): 80 | V = input['V'] 81 | return V 82 | 83 | 84 | class V_Max(nn.Module): 85 | 86 | def __init__(self, args): 87 | super().__init__() 88 | self.pooling = NodePooling(args) 89 | 90 | def forward(self, input): 91 | G, V = input['G'], input['V'] 92 | # G.ndata['V'] = V 93 | G.ndata['V'] = self.pooling(V) 94 | G.update_all(fn.copy_u('V', 'M'), fn.max('M', 'V')) 95 | return G.ndata['V'] 96 | 97 | 98 | class V_Mean(nn.Module): 99 | 100 | def __init__(self, args): 101 | super().__init__() 102 | self.pooling = NodePooling(args) 103 | 104 | def forward(self, input): 105 | G, V = input['G'], input['V'] 106 | # G.ndata['V'] = V 107 | G.ndata['V'] = self.pooling(V) 108 | G.update_all(fn.copy_u('V', 'M'), fn.mean('M', 'V')) 109 | return G.ndata['V'] 110 | 111 | 112 | class V_Sum(nn.Module): 113 | 114 | def __init__(self, args): 115 | super().__init__() 116 | self.pooling = NodePooling(args) 117 | 118 | def forward(self, input): 119 | G, V = input['G'], input['V'] 120 | # G.ndata['V'] = self.pooling(V) 121 | G.ndata['V'] = V 122 | G.update_all(fn.copy_u('V', 'M'), fn.sum('M', 'V')) 123 | return G.ndata['V'] 124 | 125 | 126 | class V_Min(nn.Module): 127 | 128 | def __init__(self, args): 129 | super().__init__() 130 | self.pooling = NodePooling(args) 131 | 132 | def forward(self, input): 133 | G, V = input['G'], input['V'] 134 | G.ndata['V'] = self.pooling(V) 135 | G.update_all(fn.copy_u('V', 'M'), fn.min('M', 'V')) 136 | return G.ndata['V'] 137 | 138 | 139 | class V_Dense(nn.Module): 140 | 141 | def __init__(self, args): 142 | super().__init__() 143 | self.W = nn.Linear(args.node_dim*2, args.node_dim, bias = True) 144 | 145 | def forward(self, input): 146 | V, V_in = input['V'], input['V_in'] 147 | gates = torch.cat([V, V_in], dim = 1) 148 | gates = self.W(gates) 149 | return torch.sigmoid(gates) * V 150 | 151 | 152 | class V_Sparse(nn.Module): 153 | 154 | def __init__(self, args): 155 | super().__init__() 156 | self.W = nn.Linear(args.node_dim*2, args.node_dim, bias = True) 157 | self.a = nn.Linear(args.node_dim, 1, bias = False) 158 | 159 | def forward(self, input): 160 | V, V_in = input['V'], input['V_in'] 161 | gates = torch.cat([V, V_in], dim = 1) 162 | # gates = self.W(gates) 163 | gates = torch.relu(self.W(gates)) 164 | gates = self.a(gates) 165 | return torch.sigmoid(gates) * V 166 | 167 | 168 | if __name__ == '__main__': 169 | print("test") -------------------------------------------------------------------------------- /scripts/search_molecules_zinc.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | 3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \ 4 | --task 'graph_level' \ 5 | --data 'ZINC' \ 6 | --data_clip 1.0 \ 7 | --in_dim_V 28 \ 8 | --batch 64 \ 9 | --epochs 40 \ 10 | --node_dim 60 \ 11 | --nb_layers 12 \ 12 | --nb_nodes 3 \ 13 | 
--portion 0.9 \ 14 | --dropout 0.0 \ 15 | --pos_encode 0 \ 16 | --nb_workers 0 \ 17 | --report_freq 1 \ 18 | --arch_save 'archs/folder5' \ 19 | --search_mode 'train' \ 20 | --batchnorm_op 21 | 22 | -------------------------------------------------------------------------------- /scripts/search_sbms_cluster.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \ 3 | --task 'node_level' \ 4 | --data 'SBM_CLUSTER' \ 5 | --nb_classes 6 \ 6 | --data_clip 1.0 \ 7 | --in_dim_V 7 \ 8 | --batch 32 \ 9 | --node_dim 70 \ 10 | --pos_encode 0 \ 11 | --nb_layers 4 \ 12 | --nb_nodes 2 \ 13 | --dropout 0.2 \ 14 | --portion 0.5 \ 15 | --search_mode 'train' \ 16 | --nb_workers 0 \ 17 | --report_freq 1 \ 18 | --arch_save 'archs/folder5' \ 19 | --batchnorm_op 20 | -------------------------------------------------------------------------------- /scripts/search_sbms_pattern.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | 3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \ 4 | --task 'node_level' \ 5 | --data 'SBM_PATTERN' \ 6 | --nb_classes 2 \ 7 | --in_dim_V 3 \ 8 | --data_clip 0.5 \ 9 | --batch 8 \ 10 | --node_dim 50 \ 11 | --pos_encode 0 \ 12 | --nb_layers 4 \ 13 | --nb_nodes 3 \ 14 | --dropout 0.0 \ 15 | --portion 0.5 \ 16 | --nb_workers 0 \ 17 | --report_freq 1 \ 18 | --search_mode 'train' \ 19 | --arch_save 'archs/folder5' \ 20 | --batchnorm_op 21 | -------------------------------------------------------------------------------- /scripts/search_superpixels_cifar10.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | # no batchnorm + dropout 0.2 3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \ 4 | --task 'graph_level' \ 5 | --data 'CIFAR10' \ 6 | --nb_classes 10 \ 7 | --in_dim_V 5 \ 8 | --batch 64 \ 9 | --epochs 40 \ 10 | --node_dim 50 \ 11 | --portion 0.5 \ 12 | --nb_layers 1 \ 13 | --nb_nodes 4 \ 14 | --dropout 0.2 \ 15 | --nb_workers 0 \ 16 | --report_freq 1 \ 17 | --search_mode 'train' \ 18 | --arch_save 'archs/folder5' 19 | -------------------------------------------------------------------------------- /scripts/search_superpixels_mnist.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | # no batchnorm + dropout 0.2 3 | CUDA_VISIBLE_DEVICES=$DEVICES python search.py \ 4 | --task 'graph_level' \ 5 | --data 'MNIST' \ 6 | --nb_classes 10 \ 7 | --in_dim_V 3 \ 8 | --batch 64 \ 9 | --epochs 40 \ 10 | --node_dim 50 \ 11 | --portion 0.5 \ 12 | --nb_layers 4 \ 13 | --nb_nodes 3 \ 14 | --dropout 0.2 \ 15 | --nb_workers 0 \ 16 | --report_freq 1 \ 17 | --search_mode 'train' \ 18 | --arch_save 'archs/folder5' 19 | -------------------------------------------------------------------------------- /scripts/train_molecules_zinc.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | GENOTYPE=$2 3 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \ 4 | --task 'graph_level' \ 5 | --data 'ZINC' \ 6 | --in_dim_V 28 \ 7 | --batch 128 \ 8 | --node_dim 60 \ 9 | --dropout 0.0 \ 10 | --pos_encode 0 \ 11 | --batchnorm_op \ 12 | --epochs 200 \ 13 | --lr 1e-3 \ 14 | --weight_decay 0.0 \ 15 | --optimizer 'ADAM' \ 16 | --load_genotypes $GENOTYPE 17 | -------------------------------------------------------------------------------- /scripts/train_sbms_cluster.sh: 
-------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | GENOTYPE=$2 3 | 4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \ 5 | --task 'node_level' \ 6 | --data 'SBM_CLUSTER' \ 7 | --nb_classes 6 \ 8 | --in_dim_V 7 \ 9 | --pos_encode 0 \ 10 | --batch 64 \ 11 | --node_dim 70 \ 12 | --dropout 0.2 \ 13 | --batchnorm_op \ 14 | --epochs 200 \ 15 | --lr 1e-3 \ 16 | --weight_decay 0.0 \ 17 | --optimizer 'ADAM' \ 18 | --load_genotypes $GENOTYPE -------------------------------------------------------------------------------- /scripts/train_sbms_pattern.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | GENOTYPE=$2 3 | 4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \ 5 | --task 'node_level' \ 6 | --data 'SBM_PATTERN' \ 7 | --nb_classes 2 \ 8 | --in_dim_V 3 \ 9 | --batch 64 \ 10 | --node_dim 70 \ 11 | --pos_encode 0 \ 12 | --dropout 0.2 \ 13 | --batchnorm_op \ 14 | --epochs 200 \ 15 | --lr 1e-3 \ 16 | --weight_decay 0.0 \ 17 | --optimizer 'ADAM' \ 18 | --load_genotypes $GENOTYPE -------------------------------------------------------------------------------- /scripts/train_superpixels_cifar10.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | GENOTYPE=$2 3 | # no batchnorm + dropout 0.2 4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \ 5 | --task 'graph_level' \ 6 | --data 'CIFAR10' \ 7 | --nb_classes 10 \ 8 | --in_dim_V 5 \ 9 | --batch 64 \ 10 | --node_dim 50 \ 11 | --pos_encode 0 \ 12 | --dropout 0.2 \ 13 | --epochs 200 \ 14 | --lr 1e-3 \ 15 | --weight_decay 0.0 \ 16 | --optimizer 'ADAM' \ 17 | --load_genotypes $GENOTYPE -------------------------------------------------------------------------------- /scripts/train_superpixels_mnist.sh: -------------------------------------------------------------------------------- 1 | DEVICES=$1 2 | GENOTYPE=$2 3 | # no batchnorm + dropout 0.2 4 | CUDA_VISIBLE_DEVICES=$DEVICES python train.py \ 5 | --task 'graph_level' \ 6 | --data 'MNIST' \ 7 | --nb_classes 10 \ 8 | --in_dim_V 3 \ 9 | --batch 64 \ 10 | --node_dim 50 \ 11 | --pos_encode 0 \ 12 | --dropout 0.2 \ 13 | --epochs 200 \ 14 | --lr 1e-3 \ 15 | --weight_decay 0.0 \ 16 | --optimizer 'ADAM' \ 17 | --load_genotypes $GENOTYPE -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import dgl 4 | import yaml 5 | import torch 6 | import argparse 7 | import numpy as np 8 | import torch.backends.cudnn as cudnn 9 | from tqdm import tqdm 10 | from data import * 11 | from models.model_search import * 12 | from utils.utils import * 13 | from models.architect import Architect 14 | 15 | 16 | class Searcher(object): 17 | 18 | def __init__(self, args): 19 | 20 | self.args = args 21 | self.console = Console() 22 | 23 | self.console.log('=> [1] Initial settings') 24 | np.random.seed(args.seed) 25 | torch.manual_seed(args.seed) 26 | torch.cuda.manual_seed(args.seed) 27 | cudnn.benchmark = True 28 | cudnn.enabled = True 29 | 30 | self.console.log('=> [2] Initial models') 31 | self.metric = load_metric(args) 32 | self.loss_fn = get_loss_fn(args).cuda() 33 | self.model = Model_Search(args, get_trans_input(args), self.loss_fn).cuda() 34 | self.console.log(f'=> Supernet Parameters: {count_parameters_in_MB(self.model)}', style = 'bold red') 35 | 36 | self.console.log(f'=> [3] Preparing dataset') 37 | self.dataset = 
load_data(args) 38 | if args.pos_encode > 0: 39 | #! add positional encoding 40 | self.console.log(f'==> [3.1] Adding positional encodings') 41 | self.dataset._add_positional_encodings(args.pos_encode) 42 | self.search_data = self.dataset.train 43 | self.val_data = self.dataset.val 44 | self.test_data = self.dataset.test 45 | self.load_dataloader() 46 | 47 | self.console.log(f'=> [4] Initial optimizer') 48 | self.optimizer = torch.optim.SGD( 49 | params = self.model.parameters(), 50 | lr = args.lr, 51 | momentum = args.momentum, 52 | weight_decay = args.weight_decay 53 | ) 54 | 55 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 56 | optimizer = self.optimizer, 57 | T_max = float(args.epochs), 58 | eta_min = args.lr_min 59 | ) 60 | 61 | self.architect = Architect(self.model, self.args) 62 | 63 | 64 | def load_dataloader(self): 65 | 66 | num_search = int(len(self.search_data) * self.args.data_clip) 67 | indices = list(range(num_search)) 68 | split = int(np.floor(self.args.portion * num_search)) 69 | self.console.log(f'=> Para set size: {split}, Arch set size: {num_search - split}') 70 | 71 | self.para_queue = torch.utils.data.DataLoader( 72 | dataset = self.search_data, 73 | batch_size = self.args.batch, 74 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split]), 75 | pin_memory = True, 76 | num_workers = self.args.nb_workers, 77 | collate_fn = self.dataset.collate 78 | ) 79 | 80 | self.arch_queue = torch.utils.data.DataLoader( 81 | dataset = self.search_data, 82 | batch_size = self.args.batch, 83 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]), 84 | pin_memory = True, 85 | num_workers = self.args.nb_workers, 86 | collate_fn = self.dataset.collate 87 | ) 88 | 89 | num_valid = int(len(self.val_data) * self.args.data_clip) 90 | indices = list(range(num_valid)) 91 | 92 | self.val_queue = torch.utils.data.DataLoader( 93 | dataset = self.val_data, 94 | batch_size = self.args.batch, 95 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices), 96 | pin_memory = True, 97 | num_workers = self.args.nb_workers, 98 | collate_fn = self.dataset.collate 99 | ) 100 | 101 | num_test = int(len(self.test_data) * self.args.data_clip) 102 | indices = list(range(num_test)) 103 | 104 | self.test_queue = torch.utils.data.DataLoader( 105 | dataset = self.test_data, 106 | batch_size = self.args.batch, 107 | sampler = torch.utils.data.sampler.SubsetRandomSampler(indices), 108 | pin_memory = True, 109 | num_workers = self.args.nb_workers, 110 | collate_fn = self.dataset.collate 111 | ) 112 | 113 | 114 | def run(self): 115 | 116 | self.console.log(f'=> [4] Search & Train') 117 | for i_epoch in range(self.args.epochs): 118 | self.scheduler.step() 119 | self.lr = self.scheduler.get_lr()[0] 120 | if i_epoch % self.args.report_freq == 0: 121 | geno = genotypes( 122 | args = self.args, 123 | arch_paras = self.model.group_arch_parameters(), 124 | arch_topos = self.model.cell_arch_topo, 125 | ) 126 | with open(f'{self.args.arch_save}/{self.args.data}/{i_epoch}.yaml', "w") as f: 127 | yaml.dump(geno, f) 128 | 129 | # => report genotype 130 | self.console.log( geno ) 131 | for i in range(self.args.nb_layers): 132 | for p in self.model.group_arch_parameters()[i]: 133 | self.console.log(p.softmax(0).detach().cpu().numpy()) 134 | 135 | search_result = self.search() 136 | self.console.log(f"[green]=> search result [{i_epoch}] - loss: {search_result['loss']:.4f} - metric : {search_result['metric']:.4f}",) 137 | # DecayScheduler().step(i_epoch) 138 | 139 | with 
torch.no_grad(): 140 | val_result = self.infer(self.val_queue) 141 | self.console.log(f"[yellow]=> valid result [{i_epoch}] - loss: {val_result['loss']:.4f} - metric : {val_result['metric']:.4f}") 142 | 143 | test_result = self.infer(self.test_queue) 144 | self.console.log(f"[red]=> test result [{i_epoch}] - loss: {test_result['loss']:.4f} - metric : {test_result['metric']:.4f}") 145 | 146 | 147 | def search(self): 148 | 149 | self.model.train() 150 | epoch_loss = 0 151 | epoch_metric = 0 152 | desc = '=> searching' 153 | device = torch.device('cuda') 154 | 155 | with tqdm(self.para_queue, desc = desc, leave = False) as t: 156 | for i_step, (batch_graphs, batch_targets) in enumerate(t): 157 | #! 1. preparing training datasets 158 | G = batch_graphs.to(device) 159 | V = batch_graphs.ndata['feat'].to(device) 160 | # E = batch_graphs.edata['feat'].to(device) 161 | batch_targets = batch_targets.to(device) 162 | #! 2. preparing validating datasets 163 | batch_graphs_search, batch_targets_search = next(iter(self.arch_queue)) 164 | GS = batch_graphs_search.to(device) 165 | VS = batch_graphs_search.ndata['feat'].to(device) 166 | # ES = batch_graphs_search.edata['feat'].to(device) 167 | batch_targets_search = batch_targets_search.to(device) 168 | #! 3. optimizing architecture topology parameters 169 | self.architect.step( 170 | input_train = {'G': G, 'V': V}, 171 | target_train = batch_targets, 172 | input_valid = {'G': GS, 'V': VS}, 173 | target_valid = batch_targets_search, 174 | eta = self.lr, 175 | network_optimizer = self.optimizer, 176 | unrolled = self.args.unrolled 177 | ) 178 | #! 4. optimizing model parameters 179 | self.optimizer.zero_grad() 180 | batch_scores = self.model({'G': G, 'V': V}) 181 | loss = self.loss_fn(batch_scores, batch_targets) 182 | loss.backward() 183 | self.optimizer.step() 184 | 185 | epoch_loss += loss.detach().item() 186 | epoch_metric += self.metric(batch_scores, batch_targets) 187 | t.set_postfix(lr = self.lr, 188 | loss = epoch_loss / (i_step + 1), 189 | metric = epoch_metric / (i_step + 1)) 190 | 191 | return {'loss' : epoch_loss / (i_step + 1), 192 | 'metric' : epoch_metric / (i_step + 1)} 193 | 194 | 195 | def infer(self, dataloader): 196 | 197 | self.model.eval() 198 | epoch_loss = 0 199 | epoch_metric = 0 200 | desc = '=> inferring' 201 | device = torch.device('cuda') 202 | 203 | with tqdm(dataloader, desc = desc, leave = False) as t: 204 | for i_step, (batch_graphs, batch_targets) in enumerate(t): 205 | G = batch_graphs.to(device) 206 | V = batch_graphs.ndata['feat'].to(device) 207 | # E = batch_graphs.edata['feat'].to(device) 208 | batch_targets = batch_targets.to(device) 209 | batch_scores = self.model({'G': G, 'V': V}) 210 | loss = self.loss_fn(batch_scores, batch_targets) 211 | 212 | epoch_loss += loss.detach().item() 213 | epoch_metric += self.metric(batch_scores, batch_targets) 214 | t.set_postfix(loss = epoch_loss / (i_step + 1), 215 | metric = epoch_metric / (i_step + 1)) 216 | 217 | return {'loss' : epoch_loss / (i_step + 1), 218 | 'metric' : epoch_metric / (i_step + 1)} 219 | 220 | 221 | if __name__ == '__main__': 222 | 223 | import warnings 224 | from rich.console import Console 225 | from rich.table import Table 226 | from rich.panel import Panel 227 | from rich.syntax import Syntax 228 | warnings.filterwarnings('ignore') 229 | 230 | parser = argparse.ArgumentParser('Rethinking Graph Neural Architecture Search From Message Passing') 231 | parser.add_argument('--task', type = str, default = 'graph_level') 232 | parser.add_argument('--data', type 
= str, default = 'ZINC')
233 |     parser.add_argument('--in_dim_V', type = int, default = 28)
234 |     parser.add_argument('--node_dim', type = int, default = 70)
235 |     parser.add_argument('--nb_classes', type = int, default = 1)
236 |     parser.add_argument('--nb_layers', type = int, default = 4)
237 |     parser.add_argument('--nb_nodes', type = int, default = 3)
238 |     parser.add_argument('--leaky_slope', type = float, default = 0.1)
239 |     parser.add_argument('--batchnorm_op', action = 'store_true', default = False)
240 |     parser.add_argument('--nb_mlp_layer', type = int, default = 4)
241 |     parser.add_argument('--dropout', type = float, default = 0.0)
242 |     parser.add_argument('--pos_encode', type = int, default = 0)
243 | 
244 |     parser.add_argument('--portion', type = float, default = 0.5)
245 |     parser.add_argument('--data_clip', type = float, default = 1.0)
246 |     parser.add_argument('--nb_workers', type = int, default = 0)
247 |     parser.add_argument('--seed', type = int, default = 41)
248 |     parser.add_argument('--epochs', type = int, default = 50)
249 |     parser.add_argument('--batch', type = int, default = 64)
250 |     parser.add_argument('--lr', type = float, default = 0.025)
251 |     parser.add_argument('--lr_min', type = float, default = 0.001)
252 |     parser.add_argument('--momentum', type = float, default = 0.9)
253 |     parser.add_argument('--weight_decay', type = float, default = 3e-4)
254 |     parser.add_argument('--unrolled', action = 'store_true', default = False)
255 |     parser.add_argument('--search_mode', type = str, default = 'train')
256 |     parser.add_argument('--arch_lr', type = float, default = 3e-4)
257 |     parser.add_argument('--arch_weight_decay', type = float, default = 1e-3)
258 |     parser.add_argument('--report_freq', type = int, default = 1)
259 |     parser.add_argument('--arch_save', type = str, default = './save_arch')
260 | 
261 |     console = Console()
262 |     args = parser.parse_args()
263 |     title = "[bold][red]Searching & Training"
264 |     vis = ""
265 |     for key, val in vars(args).items():
266 |         vis += f"{key}: {val}\n"
267 |     vis = Syntax(vis[:-1], "yaml", theme="monokai", line_numbers=True)
268 |     richPanel = Panel.fit(vis, title = title)
269 |     console.print(richPanel)
270 |     data_path = os.path.join(args.arch_save, args.data)
271 |     if not os.path.exists(data_path):
272 |         os.makedirs(data_path)    # makedirs, so nested paths like 'archs/folder5/ZINC' work
273 |     with open(os.path.join(data_path, "configs.yaml"), "w") as f:
274 |         yaml.dump(vars(args), f)
275 |     Searcher(args).run()
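The search loop in `Searcher.search()` above alternates two gradient steps per batch: `architect.step` updates the architecture logits on the held-out `arch_queue` split, then the ordinary optimizer updates the network weights on the `para_queue` split. A compressed, first-order sketch of that alternation (i.e. with `--unrolled` off; the tiny least-squares problem and the three hand-made candidates are illustrative stand-ins, not the repo's classes):

```python
# First-order sketch of the alternating bi-level update in Searcher.search():
# one Adam step on architecture logits (arch split), then one SGD step on
# network weights (para split). All names here are toy stand-ins.
import torch
import torch.nn as nn

net   = nn.Linear(4, 1)                                  # "network weights"
alpha = torch.zeros(3, requires_grad=True)               # "architecture logits"
w_opt = torch.optim.SGD(net.parameters(), lr=0.05, momentum=0.9)
a_opt = torch.optim.Adam([alpha], lr=3e-4, weight_decay=1e-3)  # mirrors --arch_lr / --arch_weight_decay

def loss_fn(x, y):
    # three "candidate operations", blended by softmax(alpha) as in Mixed
    cands = torch.stack([net(x), 0.5 * net(x), torch.zeros_like(y)], dim=0)
    mix   = alpha.softmax(0).view(3, 1, 1)
    return ((cands * mix).sum(0) - y).pow(2).mean()

for step in range(100):
    x_w, y_w = torch.randn(8, 4), torch.randn(8, 1)      # "para set" batch
    x_a, y_a = torch.randn(8, 4), torch.randn(8, 1)      # "arch set" batch
    a_opt.zero_grad(); loss_fn(x_a, y_a).backward(); a_opt.step()   # architect.step
    w_opt.zero_grad(); loss_fn(x_w, y_w).backward(); w_opt.step()   # weight update
print(alpha.softmax(0))   # relaxed architecture after search
```

The repo's `Architect` additionally supports the second-order unrolled update (`--unrolled`), which this sketch omits.
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import dgl
4 | import yaml
5 | import torch
6 | import argparse
7 | import numpy as np
8 | import torch.backends.cudnn as cudnn
9 | from tqdm import tqdm
10 | from data import *
11 | from models.model_train import *
12 | from utils.utils import *
13 | from tensorboardX import SummaryWriter
14 | from utils.record_utils import record_run
15 | 
16 | 
17 | class Trainer(object):
18 | 
19 |     def __init__(self, args):
20 | 
21 |         self.args = args
22 |         self.console = Console()
23 | 
24 |         self.console.log('=> [0] Initial TensorboardX')
25 |         self.writer = SummaryWriter(comment = f'Task: {args.task}, Data: {args.data}, Geno: {args.load_genotypes}')
26 | 
27 |         self.console.log('=> [1] Initial Settings')
28 |         np.random.seed(args.seed)
29 |         torch.manual_seed(args.seed)
30 |         torch.cuda.manual_seed(args.seed)
31 |         cudnn.enabled = True
32 | 
33 |         self.console.log('=> [2] Initial Models')
34 |         if not 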
os.path.isfile(args.load_genotypes):
35 |             raise Exception('Genotype file not found!')
36 |         else:
37 |             with open(args.load_genotypes, "r") as f:
38 |                 genotypes = yaml.safe_load(f)
39 |                 args.nb_layers = len(genotypes['Genotype'])
40 |                 args.nb_nodes = len({ edge['dst'] for edge in genotypes['Genotype'][0]['topology'] })
41 |         self.metric = load_metric(args)
42 |         self.loss_fn = get_loss_fn(args).cuda()
43 |         trans_input_fn = get_trans_input(args)
44 |         self.model = Model_Train(args, genotypes['Genotype'], trans_input_fn, self.loss_fn).to("cuda")
45 |         self.console.log(f'[red]=> Subnet Parameters: {count_parameters_in_MB(self.model)}')
46 | 
47 |         self.console.log(f'=> [3] Preparing Dataset')
48 |         self.dataset = load_data(args)
49 |         if args.pos_encode > 0:
50 |             #! load - position encoding
51 |             self.console.log(f'==> [3.1] Adding positional encodings')
52 |             self.dataset._add_positional_encodings(args.pos_encode)
53 |         self.train_data = self.dataset.train
54 |         self.val_data = self.dataset.val
55 |         self.test_data = self.dataset.test
56 |         self.load_dataloader()
57 | 
58 |         self.console.log(f'=> [4] Initial Optimizers')
59 |         if args.optimizer == 'SGD':
60 |             self.optimizer = torch.optim.SGD(
61 |                 params = self.model.parameters(),
62 |                 lr = args.lr,
63 |                 momentum = args.momentum,
64 |                 weight_decay = args.weight_decay,
65 |             )
66 | 
67 |             self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
68 |                 optimizer = self.optimizer,
69 |                 T_max = float(args.epochs),
70 |                 eta_min = args.lr_min
71 |             )
72 | 
73 |         elif args.optimizer == 'ADAM':
74 |             self.optimizer = torch.optim.Adam(
75 |                 params = self.model.parameters(),
76 |                 lr = args.lr,
77 |                 weight_decay = args.weight_decay,
78 |             )
79 | 
80 |             self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
81 |                 optimizer = self.optimizer,
82 |                 mode = 'min',
83 |                 factor = 0.5,
84 |                 patience = args.patience,
85 |                 verbose = True
86 |             )
87 |         else:
88 |             raise Exception('Unknown optimizer!')
89 | 
90 | 
91 |     def load_dataloader(self):
92 | 
93 |         num_train = int(len(self.train_data) * self.args.data_clip)
94 |         indices = list(range(num_train))
95 | 
96 |         self.train_queue = torch.utils.data.DataLoader(
97 |             dataset = self.train_data,
98 |             batch_size = self.args.batch,
99 |             pin_memory = True,
100 |             sampler = torch.utils.data.sampler.SubsetRandomSampler(indices),
101 |             num_workers = self.args.nb_workers,
102 |             collate_fn = self.dataset.collate,
103 |         )
104 | 
105 |         num_valid = int(len(self.val_data) * self.args.data_clip)
106 |         indices = list(range(num_valid))
107 | 
108 |         self.val_queue = torch.utils.data.DataLoader(
109 |             dataset = self.val_data,
110 |             batch_size = self.args.batch,
111 |             pin_memory = True,
112 |             sampler = torch.utils.data.sampler.SubsetRandomSampler(indices),
113 |             num_workers = self.args.nb_workers,
114 |             collate_fn = self.dataset.collate,
115 |             shuffle = False
116 |         )
117 | 
118 |         num_test = int(len(self.test_data) * self.args.data_clip)
119 |         indices = list(range(num_test))
120 | 
121 |         self.test_queue = torch.utils.data.DataLoader(
122 |             dataset = self.test_data,
123 |             batch_size = self.args.batch,
124 |             pin_memory = True,
125 |             sampler = torch.utils.data.sampler.SubsetRandomSampler(indices),
126 |             num_workers = self.args.nb_workers,
127 |             collate_fn = self.dataset.collate,
128 |             shuffle = False,
129 |         )
130 | 
131 | 
132 |     def scheduler_step(self, valid_loss):
133 | 
134 |         if self.args.optimizer == 'SGD':
135 |             self.scheduler.step()
136 |             lr = self.scheduler.get_lr()[0]
137 |         elif self.args.optimizer == 'ADAM':
138 |             self.scheduler.step(valid_loss)
139 |             lr = 
self.optimizer.param_groups[0]['lr'] 140 | if lr < 1e-5: 141 | self.console.log('=> !! learning rate is smaller than threshold !!') 142 | return lr 143 | 144 | 145 | def run(self): 146 | 147 | self.console.log(f'=> [5] Train Genotypes') 148 | self.lr = self.args.lr 149 | for i_epoch in range(self.args.epochs): 150 | #! training 151 | train_result = self.train(i_epoch, 'train') 152 | self.console.log(f"[green]=> train result [{i_epoch}] - loss: {train_result['loss']:.4f} - metric : {train_result['metric']:.4f}") 153 | with torch.no_grad(): 154 | #! validating 155 | val_result = self.infer(i_epoch, self.val_queue, 'val') 156 | self.console.log(f"[yellow]=> valid result [{i_epoch}] - loss: {val_result['loss']:.4f} - metric : {val_result['metric']:.4f}") 157 | #! testing 158 | test_result = self.infer(i_epoch, self.test_queue, 'test') 159 | self.console.log(f"[underline][red]=> test result [{i_epoch}] - loss: {test_result['loss']:.4f} - metric : {test_result['metric']:.4f}") 160 | self.lr = self.scheduler_step(val_result['loss']) 161 | 162 | self.console.log(f'=> Finished! Genotype = {args.load_genotypes}') 163 | 164 | 165 | @record_run('train') 166 | def train(self, i_epoch, stage = 'train'): 167 | 168 | self.model.train() 169 | epoch_loss = 0 170 | epoch_metric = 0 171 | desc = '=> training' 172 | device = torch.device('cuda') 173 | 174 | with tqdm(self.train_queue, desc = desc, leave = False) as t: 175 | for i_step, (batch_graphs, batch_targets) in enumerate(t): 176 | #! 1. preparing datasets 177 | G = batch_graphs.to(device) 178 | V = batch_graphs.ndata['feat'].to(device) 179 | batch_targets = batch_targets.to(device) 180 | 181 | #! 2. optimizing model parameters 182 | self.optimizer.zero_grad() 183 | input = {'G': G, 'V': V} 184 | batch_scores = self.model(input) 185 | loss = self.loss_fn(batch_scores, batch_targets, graph = batch_graphs, stage = stage) 186 | loss.backward() 187 | self.optimizer.step() 188 | 189 | epoch_loss += loss.detach().item() 190 | epoch_metric += self.metric(batch_scores, batch_targets, graph = batch_graphs, stage = stage) 191 | 192 | loss_avg = epoch_loss / (i_step + 1) 193 | metric_avg = epoch_metric / (i_step + 1) 194 | 195 | result = {'loss' : loss_avg, 'metric' : metric_avg} 196 | t.set_postfix(lr = self.lr, **result) 197 | 198 | return result 199 | 200 | 201 | @record_run('infer') 202 | def infer(self, i_epoch, dataloader, stage = 'infer'): 203 | 204 | self.model.eval() 205 | epoch_loss = 0 206 | epoch_metric = 0 207 | desc = '=> inferring' 208 | device = torch.device('cuda') 209 | 210 | with tqdm(dataloader, desc = desc, leave = False) as t: 211 | for i_step, (batch_graphs, batch_targets) in enumerate(t): 212 | G = batch_graphs.to(device) 213 | V = batch_graphs.ndata['feat'].to(device) 214 | 215 | batch_targets = batch_targets.to(device) 216 | input = {'G': G, 'V': V} 217 | batch_scores = self.model(input) 218 | loss = self.loss_fn(batch_scores, batch_targets, graph = batch_graphs, stage = stage) 219 | 220 | epoch_loss += loss.detach().item() 221 | epoch_metric += self.metric(batch_scores, batch_targets, graph = batch_graphs, stage = stage) 222 | 223 | loss_avg = epoch_loss / (i_step + 1) 224 | metric_avg = epoch_metric / (i_step + 1) 225 | 226 | result = {'loss' : epoch_loss / (i_step + 1), 'metric' : metric_avg} 227 | t.set_postfix(**result) 228 | 229 | return result 230 | 231 | 232 | if __name__ == '__main__': 233 | 234 | import warnings 235 | warnings.filterwarnings('ignore') 236 | 237 | parser = argparse.ArgumentParser('Train_from_Genotype') 238 | 
parser.add_argument('--task', type = str, default = 'graph_level')
239 |     parser.add_argument('--data', type = str, default = 'ZINC')
240 |     parser.add_argument('--extra', type = str, default = '')
241 |     parser.add_argument('--in_dim_V', type = int, default = 28)
242 |     parser.add_argument('--node_dim', type = int, default = 70)
243 |     parser.add_argument('--nb_layers', type = int, default = 4)
244 |     parser.add_argument('--nb_nodes', type = int, default = 4)
245 |     parser.add_argument('--nb_classes', type = int, default = 1)
246 |     parser.add_argument('--leaky_slope', type = float, default = 1e-2)
247 |     parser.add_argument('--batchnorm_op', default = False, action = 'store_true')
248 |     parser.add_argument('--nb_mlp_layer', type = int, default = 4)
249 |     parser.add_argument('--dropout', type = float, default = 0.0)
250 |     parser.add_argument('--pos_encode', type = int, default = 0)
251 | 
252 |     parser.add_argument('--data_clip', type = float, default = 1.0)
253 |     parser.add_argument('--nb_workers', type = int, default = 0)
254 |     parser.add_argument('--seed', type = int, default = 41)
255 |     parser.add_argument('--epochs', type = int, default = 100)
256 |     parser.add_argument('--batch', type = int, default = 64)
257 |     parser.add_argument('--lr', type = float, default = 0.025)
258 |     parser.add_argument('--lr_min', type = float, default = 0.001)
259 |     parser.add_argument('--momentum', type = float, default = 0.9)
260 |     parser.add_argument('--weight_decay', type = float, default = 3e-4)
261 |     parser.add_argument('--optimizer', type = str, default = 'ADAM')
262 |     parser.add_argument('--patience', type = int, default = 10)
263 |     parser.add_argument('--load_genotypes', type = str, required = True)
264 | 
265 |     from rich.console import Console
266 |     from rich.table import Table
267 |     from rich.panel import Panel
268 |     from rich.syntax import Syntax
269 | 
270 |     console = Console()
271 |     args = parser.parse_args()
272 |     title = "[bold][red]Training from Genotype"
273 |     vis = '\n'.join([f"{key}: {val}" for key, val in vars(args).items()])
274 |     vis = Syntax(vis, "yaml", theme="monokai", line_numbers=True)
275 |     richPanel = Panel.fit(vis, title = title)
276 |     console.print(richPanel)
277 |     Trainer(args).run()
278 |     # - end - #
279 | 
--------------------------------------------------------------------------------
/utils/record_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tensorboardX import SummaryWriter
3 | 
4 | class record_run:
5 | 
6 |     def __init__(self, comment = ''):
7 |         self.comment = comment
8 | 
9 | 
10 |     def __call__(self, func):
11 | 
12 |         def wrapped_func(*args, **kwargs):
13 |             writer = args[0].writer
14 |             # geno_path = args[0].args.load_genotypes
15 |             i_epoch = args[1]
16 |             stage = args[-1]
17 | 
18 |             result = func(*args, **kwargs)
19 |             writer.add_scalar(f'{stage}_loss', result['loss'], global_step = i_epoch)
20 |             writer.add_scalar(f'{stage}_metric', result['metric'], global_step = i_epoch)
21 | 
22 |             return result
23 | 
24 |         return wrapped_func
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os, shutil    # shutil: needed by save_checkpoint / create_exp_dir below
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | from sklearn.metrics import f1_score
8 | from sklearn.metrics import confusion_matrix
9 | from models.operations import OPS
10 | 
11 | class DotDict(dict):
12 |     def 
__init__(self, **kwds): 13 | self.update(kwds) 14 | self.__dict__ = self 15 | 16 | def load_alpha(genotype): 17 | alpha_cell = torch.Tensor(genotype.alpha_cell) 18 | alpha_edge = torch.Tensor(genotype.alpha_edge) 19 | return [alpha_cell, alpha_edge] 20 | 21 | class AvgrageMeter(object): 22 | def __init__(self): 23 | self.reset() 24 | 25 | def reset(self): 26 | self.avg = 0 27 | self.sum = 0 28 | self.cnt = 0 29 | 30 | def update(self, val, n=1): 31 | self.sum += val * n 32 | self.cnt += n 33 | self.avg = self.sum / self.cnt 34 | 35 | def count_parameters_in_MB(model): 36 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) 37 | 38 | def save_checkpoint(state, is_best, save): 39 | filename = os.path.join(save, 'checkpoint.pth.tar') 40 | torch.save(state, filename) 41 | if is_best: 42 | best_filename = os.path.join(save, 'model_best.pth.tar') 43 | shutil.copyfile(filename, best_filename) 44 | 45 | def save(model, model_path): 46 | torch.save(model.state_dict(), model_path) 47 | 48 | def load(model, model_path): 49 | model.load_state_dict(torch.load(model_path)) 50 | 51 | def drop_path(x, drop_prob): 52 | if drop_prob > 0.: 53 | keep_prob = 1. - drop_prob 54 | mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 55 | x.div_(keep_prob) 56 | x.mul_(mask) 57 | return x 58 | 59 | def create_exp_dir(path, scripts_to_save=None): 60 | if not os.path.exists(path): 61 | os.mkdir(path) 62 | print('Experiment dir : {}'.format(path)) 63 | 64 | if scripts_to_save is not None: 65 | os.mkdir(os.path.join(path, 'scripts')) 66 | for script in scripts_to_save: 67 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 68 | shutil.copyfile(script, dst_file) 69 | 70 | # ---------------------------------------------------------------- 71 | # metrics 72 | class mask: 73 | 74 | def __init__(self, data_type, mask): 75 | 76 | self.data_type = data_type 77 | self.mask = mask 78 | 79 | def __call__(self, func): 80 | 81 | def wrapped_func(*args, **kwargs): 82 | 83 | if self.mask: 84 | args = list(args) 85 | graph = kwargs['graph'] 86 | stage = kwargs['stage'] 87 | 88 | graph_data = graph.ndata if self.data_type == 'V' else graph.edata 89 | 90 | if f'{stage}_mask' in graph_data: 91 | mask = graph_data[f'{stage}_mask'] 92 | args[-1] = args[-1][mask] 93 | args[-2] = args[-2][mask] 94 | return func(*args) 95 | 96 | return wrapped_func 97 | 98 | def accuracy(output, target, topk=(1,)): 99 | maxk = max(topk) 100 | batch_size = target.size(0) 101 | 102 | _, pred = output.topk(maxk, 1, True, True) 103 | pred = pred.t() 104 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 105 | 106 | res = [] 107 | for k in topk: 108 | correct_k = correct[:k].view(-1).float().sum(0) 109 | res.append(correct_k.mul_(100.0 / batch_size)) 110 | 111 | return res 112 | 113 | @mask('E', False) 114 | def binary_f1_score(scores, targets): 115 | """Computes the F1 score using scikit-learn for binary class labels. 116 | Returns the F1 score for the positive class, i.e. labelled '1'. 
117 | """ 118 | y_true = targets.cpu().numpy() 119 | y_pred = scores.argmax(dim=1).cpu().numpy() 120 | return f1_score(y_true, y_pred, average='binary') 121 | 122 | @mask('V', False) 123 | def accuracy_SBM(scores, targets): 124 | 125 | S = targets.cpu().numpy() 126 | C = np.argmax(torch.nn.Softmax(dim=1)(scores).cpu().detach().numpy(), axis=1) 127 | CM = confusion_matrix(S, C).astype(np.float32) 128 | nb_classes = CM.shape[0] 129 | targets = targets.cpu().detach().numpy() 130 | nb_non_empty_classes = 0 131 | pr_classes = np.zeros(nb_classes) 132 | for r in range(nb_classes): 133 | cluster = np.where(targets == r)[0] 134 | if cluster.shape[0] != 0: 135 | pr_classes[r] = CM[r, r] / float(cluster.shape[0]) 136 | if CM[r, r] > 0: 137 | nb_non_empty_classes += 1 138 | else: 139 | pr_classes[r] = 0.0 140 | acc = 100. * np.sum(pr_classes) / float(nb_classes) 141 | return acc 142 | 143 | @mask('G', False) 144 | def accuracy_MNIST_CIFAR(scores, targets): 145 | scores = scores.detach().argmax(dim=1) 146 | acc = (scores == targets).float().mean().item() 147 | return acc 148 | 149 | @mask('G', False) 150 | def MAE(scores, targets): 151 | MAE = F.l1_loss(scores, targets) 152 | MAE = MAE.detach().item() 153 | return MAE 154 | 155 | @mask('V', True) 156 | def CoraAccuracy(scores, targets): 157 | return (scores.argmax(1) == targets).float().mean().item() 158 | 159 | # ---------------------------------------------------------------- 160 | # loss functions 161 | 162 | class MoleculesCriterion(nn.Module): 163 | 164 | def __init__(self): 165 | super().__init__() 166 | self.loss_fn = nn.L1Loss() 167 | 168 | @mask('G', False) 169 | def forward(self, pred, label): 170 | return self.loss_fn(pred, label) 171 | 172 | class TSPCriterion(nn.Module): 173 | 174 | def __init__(self): 175 | super().__init__() 176 | self.loss_fn = nn.CrossEntropyLoss(weight=None) 177 | 178 | @mask('E', False) 179 | def forward(self, pred, label): 180 | return self.loss_fn(pred, label) 181 | 182 | class SuperPixCriterion(nn.Module): 183 | 184 | def __init__(self): 185 | super().__init__() 186 | self.loss_fn = nn.CrossEntropyLoss(weight=None) 187 | 188 | @mask('G', False) 189 | def forward(self, pred, label): 190 | return self.loss_fn(pred, label) 191 | 192 | class SBMsCriterion(nn.Module): 193 | 194 | def __init__(self, num_classes): 195 | super().__init__() 196 | self.n_classes = num_classes 197 | 198 | @mask('V', False) 199 | def forward(self, pred, label): 200 | V = label.size(0) 201 | label_count = torch.bincount(label) 202 | label_count = label_count[label_count.nonzero()].squeeze() 203 | cluster_sizes = torch.zeros(self.n_classes).long().cuda() 204 | cluster_sizes[torch.unique(label)] = label_count 205 | weight = (V - cluster_sizes).float() / V 206 | weight *= (cluster_sizes > 0).float() 207 | # weighted cross-entropy for unbalanced classes 208 | criterion = nn.CrossEntropyLoss(weight=weight) 209 | loss = criterion(pred, label) 210 | return loss 211 | 212 | class CiteCriterion(nn.Module): 213 | 214 | def __init__(self): 215 | super().__init__() 216 | self.loss_fn = nn.CrossEntropyLoss(weight=None) 217 | 218 | @mask('V', True) 219 | def forward(self, pred, label): 220 | return self.loss_fn(pred, label) 221 | 222 | 223 | # ---------------------------------------------------------------- 224 | def cell_genotype(args, id, arch_para, arch_topo): 225 | result = {'id': id, 'topology': []} 226 | link = [ [] for i in range(args.nb_nodes*3+1) ] 227 | 228 | for src, dst, w, ops in arch_topo: 229 | link[dst].append((src, ops, arch_para[w])) 230 
|     for dst in range(1, args.nb_nodes*3+1):
231 |         nb_link = len(link[dst])
232 |         best_links = sorted(
233 |             range(nb_link),
234 |             key = lambda lk: -max(
235 |                 link[dst][lk][-1][j]
236 |                 for j in range(len(link[dst][lk][-2]))
237 |                 if 'V_None' not in link[dst][lk][-2] or j != link[dst][lk][-2].index('V_None')
238 |             )
239 |         )[:1] #! number of top-scoring incoming links kept per destination node
240 |         for blink in best_links:
241 |             blink = link[dst][blink]
242 |             if 'V_None' in blink[-2]:
243 |                 best_op = torch.argmax(blink[-1][1:]).item() + 1
244 |             else:
245 |                 best_op = torch.argmax(blink[-1]).item()
246 |             src = blink[0]
247 |             ops = blink[1][best_op]
248 |             result['topology'].append({'src': src, 'dst': dst, 'ops': ops})
249 |     return result
250 | 
251 | def genotypes(args, arch_paras, arch_topos):
252 |     result = {'Genotype': []}
253 |     for id in range(args.nb_layers):
254 |         cell_result = cell_genotype(args, id, arch_paras[id], arch_topos[id])
255 |         result['Genotype'].append(cell_result)
256 |     return result
257 | 
258 | # ----------------------------------------------------------------
259 | import math
260 | def Singleton(cls):
261 |     _instance = {}
262 | 
263 |     def _singleton(*args, **kargs):
264 |         if cls not in _instance:
265 |             _instance[cls] = cls(*args, **kargs)
266 |         return _instance[cls]
267 | 
268 |     return _singleton
269 | 
270 | @Singleton
271 | class DecayScheduler(object):
272 |     def __init__(self, base_lr=1.0, last_iter=-1, T_max=50, T_start=0, T_stop=50, decay_type='cosine'):
273 |         self.base_lr = base_lr
274 |         self.T_max = T_max
275 |         self.T_start = T_start
276 |         self.T_stop = T_stop
277 |         self.cnt = 0
278 |         self.decay_type = decay_type
279 |         self.decay_rate = 1.0
280 | 
281 |     def step(self, epoch):
282 |         if epoch >= self.T_start:
283 |             if self.decay_type == "cosine":
284 |                 self.decay_rate = self.base_lr * (1 + math.cos(math.pi * epoch / (self.T_max - self.T_start))) / 2.0 if epoch <= self.T_stop else self.decay_rate
285 |             elif self.decay_type == "slow_cosine":
286 |                 self.decay_rate = self.base_lr * math.cos((math.pi/2) * epoch / (self.T_max - self.T_start)) if epoch <= self.T_stop else self.decay_rate
287 |             elif self.decay_type == "linear":
288 |                 self.decay_rate = self.base_lr * (self.T_max - epoch) / (self.T_max - self.T_start) if epoch <= self.T_stop else self.decay_rate
289 |             else:
290 |                 self.decay_rate = self.base_lr
291 |         else:
292 |             self.decay_rate = self.base_lr
293 | 
294 | 
--------------------------------------------------------------------------------
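A final note on discretization: after search, `cell_genotype` above keeps the top-scoring incoming link per destination node and then argmaxes the logits over candidate operations, skipping `'V_None'` when it is present so the null op can never enter the final genotype. A toy illustration of that argmax rule (the logit values here are made up):

```python
# Toy illustration of the argmax rule in cell_genotype: 'V_None' sits at
# index 0 of the First/Third stage lists, so the argmax runs over logits[1:].
import torch

ops    = ['V_None', 'V_I', 'V_Sparse', 'V_Dense']   # a First_Stage edge
logits = torch.tensor([2.0, 0.1, 0.7, 0.3])         # made-up searched logits

best = torch.argmax(logits[1:]).item() + 1          # skip the null op at index 0
print(ops[best])                                    # -> 'V_Sparse'
```

Even though `'V_None'` carries the largest raw logit in this example, it is excluded, and the strongest real operation is written into the genotype that `train.py` later loads.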