├── .gitignore
├── GraphReader
│   ├── README.md
│   ├── __init__.py
│   └── graph_reader.py
├── LICENSE
├── LogMetric.py
├── MessageFunction.py
├── README.md
├── ReadoutFunction.py
├── UpdateFunction.py
├── data
│   ├── README.md
│   └── download.py
├── datasets
│   ├── README.md
│   ├── __init__.py
│   ├── grec.py
│   ├── gwhistograph.py
│   ├── letter.py
│   ├── mutag.py
│   ├── qm9.py
│   └── utils.py
├── demos
│   ├── demo_grec_duvenaud.py
│   ├── demo_grec_intnet.py
│   ├── demo_grec_mpnn.py
│   ├── demo_gwhist_duvenaud.py
│   ├── demo_gwhist_ggnn.py
│   ├── demo_letter_duvenaud.py
│   ├── demo_letter_ggnn.py
│   ├── demo_letter_intnet.py
│   ├── demo_qm9_duvenaud.py
│   ├── demo_qm9_ggnn.py
│   ├── demo_qm9_intnet.py
│   └── demo_qm9_mpnn.py
├── main.py
├── models
│   ├── MPNN.py
│   ├── MPNN_Duvenaud.py
│   ├── MPNN_GGNN.py
│   ├── MPNN_IntNet.py
│   ├── README.md
│   ├── __init__.py
│   └── nnet.py
├── requirements.txt
└── visualization
    ├── Plotter.py
    └── __init__.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

# Data folders
data/mutag/*
data/enzymes/*
data/qm9/*
data/data_graphml/*
data/Washington_DB_1.0/*
data/GWHistoGraphs/*
data/Letter/*
data/GREC/*

# log
log/*
# plot
plot/*
# Trained models
checkpoint/*
# PyCharm
.idea
--------------------------------------------------------------------------------
/GraphReader/README.md:
--------------------------------------------------------------------------------
# GraphReader

Collection of graph reader functions for different datasets.

- xyz_graph_reader: read a graph from the [QM9 dataset](http://quantum-machine.org/datasets/).
- create_graph_letter: read a graph from the Letter dataset from the [IAM dataset](http://www.fki.inf.unibe.ch/databases/iam-graph-database).
- create_graph_grec: read a graph from the GREC dataset from the [IAM dataset](http://www.fki.inf.unibe.ch/databases/iam-graph-database).
- create_graph_gwhist: read a graph from the [HistoGraph dataset](http://www.histograph.ch/).
9 | - create_graph_mutag: read a graph from the [MUTAG dataset](https://figshare.com/articles/MUTAG_and_ENZYMES_DataSet/899875). 10 | - create_graph_enzymes: read a graph from the [ENZYMES dataset](https://figshare.com/articles/MUTAG_and_ENZYMES_DataSet/899875). 11 | 12 | ## TODO 13 | 14 | - [ ] Proteins. 15 | - [ ] PTC. 16 | 17 | ## Authors 18 | 19 | * Pau Riba (@priba) 20 | * Anjan Dutta (@AnjanDutta) 21 | -------------------------------------------------------------------------------- /GraphReader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/priba/nmp_qc/177db7ea738a7a91f1262ce954f9c7a4a2b98849/GraphReader/__init__.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 priba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/LogMetric.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-

import numpy as np
import os
from tensorboard_logger import configure, log_value

__author__ = "Pau Riba, Anjan Dutta"
__email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"


def error_ratio(pred, target):
    if type(pred) is not np.ndarray:
        pred = np.array(pred)
    if type(target) is not np.ndarray:
        target = np.array(target)

    return np.mean(np.divide(np.abs(pred - target), np.abs(target)))


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Logger(object):
    def __init__(self, log_dir):
        if not os.path.isdir(log_dir):
            # if the directory does not exist we create it
            os.makedirs(log_dir)
        else:
            # clean previously logged data under the same directory name
            self._remove(log_dir)

        # configure the project
        configure(log_dir)

        self.global_step = 0

    def log_value(self, name, value):
        log_value(name, value, self.global_step)
        return self

    def step(self):
        self.global_step += 1

    @staticmethod
    def _remove(path):
        """ path can be either relative or absolute. """
        if os.path.isfile(path):
            os.remove(path)  # remove the file
        elif os.path.isdir(path):
            import shutil
            shutil.rmtree(path)  # remove the directory and all its contents
--------------------------------------------------------------------------------
/MessageFunction.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
MessageFunction.py: Propagates a message depending on two nodes and their common edge.
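
A message function maps a pair of node states (h_v, h_w) and their edge
feature e_vw to a message vector. As a minimal sketch (the tensor sizes are
illustrative assumptions), the Duvenaud variant implemented below simply
concatenates the neighbour state with the edge feature:

    import torch
    h_w = torch.rand(1, 1, 5)       # neighbour node state
    e_vw = torch.rand(1, 1, 3)      # edge feature
    m = torch.cat([h_w, e_vw], 2)   # Duvenaud message, size 1 x 1 x 8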
6 | 7 | Usage: 8 | 9 | """ 10 | 11 | from __future__ import print_function 12 | 13 | # Own modules 14 | import datasets 15 | from models.nnet import NNet 16 | 17 | import numpy as np 18 | import os 19 | import argparse 20 | import time 21 | import torch 22 | 23 | import torch.nn as nn 24 | from torch.autograd.variable import Variable 25 | 26 | 27 | __author__ = "Pau Riba, Anjan Dutta" 28 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 29 | 30 | 31 | class MessageFunction(nn.Module): 32 | 33 | # Constructor 34 | def __init__(self, message_def='duvenaud', args={}): 35 | super(MessageFunction, self).__init__() 36 | self.m_definition = '' 37 | self.m_function = None 38 | self.args = {} 39 | self.__set_message(message_def, args) 40 | 41 | # Message from h_v to h_w through e_vw 42 | def forward(self, h_v, h_w, e_vw, args=None): 43 | return self.m_function(h_v, h_w, e_vw, args) 44 | 45 | # Set a message function 46 | def __set_message(self, message_def, args={}): 47 | self.m_definition = message_def.lower() 48 | 49 | self.m_function = { 50 | 'duvenaud': self.m_duvenaud, 51 | 'ggnn': self.m_ggnn, 52 | 'intnet': self.m_intnet, 53 | 'mpnn': self.m_mpnn, 54 | 'mgc': self.m_mgc, 55 | 'bruna': self.m_bruna, 56 | 'defferrard': self.m_deff, 57 | 'kipf': self.m_kipf 58 | }.get(self.m_definition, None) 59 | 60 | if self.m_function is None: 61 | print('WARNING!: Message Function has not been set correctly\n\tIncorrect definition ' + message_def) 62 | quit() 63 | 64 | init_parameters = { 65 | 'duvenaud': self.init_duvenaud, 66 | 'ggnn': self.init_ggnn, 67 | 'intnet': self.init_intnet, 68 | 'mpnn': self.init_mpnn 69 | }.get(self.m_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {})) 70 | 71 | self.learn_args, self.learn_modules, self.args = init_parameters(args) 72 | 73 | self.m_size = { 74 | 'duvenaud': self.out_duvenaud, 75 | 'ggnn': self.out_ggnn, 76 | 'intnet': self.out_intnet, 77 | 'mpnn': self.out_mpnn 78 | }.get(self.m_definition, None) 79 | 80 | # Get the name of the used message function 81 | def get_definition(self): 82 | return self.m_definition 83 | 84 | # Get the message function arguments 85 | def get_args(self): 86 | return self.args 87 | 88 | # Get Output size 89 | def get_out_size(self, size_h, size_e, args=None): 90 | return self.m_size(size_h, size_e, args) 91 | 92 | # Definition of various state of the art message functions 93 | 94 | # Duvenaud et al. (2015), Convolutional Networks for Learning Molecular Fingerprints 95 | def m_duvenaud(self, h_v, h_w, e_vw, args): 96 | m = torch.cat([h_w, e_vw], 2) 97 | return m 98 | 99 | def out_duvenaud(self, size_h, size_e, args): 100 | return size_h + size_e 101 | 102 | def init_duvenaud(self, params): 103 | learn_args = [] 104 | learn_modules = [] 105 | args = {} 106 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 107 | 108 | # Li et al. 
(2016), Gated Graph Neural Networks (GG-NN) 109 | def m_ggnn(self, h_v, h_w, e_vw, opt={}): 110 | 111 | m = Variable(torch.zeros(h_w.size(0), h_w.size(1), self.args['out']).type_as(h_w.data)) 112 | 113 | for w in range(h_w.size(1)): 114 | if torch.nonzero(e_vw[:, w, :].data).size(): 115 | for i, el in enumerate(self.args['e_label']): 116 | ind = (el == e_vw[:,w,:]).type_as(self.learn_args[0][i]) 117 | 118 | parameter_mat = self.learn_args[0][i][None, ...].expand(h_w.size(0), self.learn_args[0][i].size(0), 119 | self.learn_args[0][i].size(1)) 120 | 121 | m_w = torch.transpose(torch.bmm(torch.transpose(parameter_mat, 1, 2), 122 | torch.transpose(torch.unsqueeze(h_w[:, w, :], 1), 123 | 1, 2)), 1, 2) 124 | m_w = torch.squeeze(m_w) 125 | m[:,w,:] = ind.expand_as(m_w)*m_w 126 | return m 127 | 128 | def out_ggnn(self, size_h, size_e, args): 129 | return self.args['out'] 130 | 131 | def init_ggnn(self, params): 132 | learn_args = [] 133 | learn_modules = [] 134 | args = {} 135 | 136 | args['e_label'] = params['e_label'] 137 | args['in'] = params['in'] 138 | args['out'] = params['out'] 139 | 140 | # Define a parameter matrix A for each edge label. 141 | learn_args.append(nn.Parameter(torch.randn(len(params['e_label']), params['in'], params['out']))) 142 | 143 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 144 | 145 | # Battaglia et al. (2016), Interaction Networks 146 | def m_intnet(self, h_v, h_w, e_vw, args): 147 | m = torch.cat([h_v[:, None, :].expand_as(h_w), h_w, e_vw], 2) 148 | b_size = m.size() 149 | 150 | m = m.view(-1, b_size[2]) 151 | 152 | m = self.learn_modules[0](m) 153 | m = m.view(b_size[0], b_size[1], -1) 154 | return m 155 | 156 | def out_intnet(self, size_h, size_e, args): 157 | return self.args['out'] 158 | 159 | def init_intnet(self, params): 160 | learn_args = [] 161 | learn_modules = [] 162 | args = {} 163 | args['in'] = params['in'] 164 | args['out'] = params['out'] 165 | learn_modules.append(NNet(n_in=params['in'], n_out=params['out'])) 166 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 167 | 168 | # Gilmer et al. (2017), Neural Message Passing for Quantum Chemistry 169 | def m_mpnn(self, h_v, h_w, e_vw, opt={}): 170 | # Matrices for each edge 171 | edge_output = self.learn_modules[0](e_vw) 172 | edge_output = edge_output.view(-1, self.args['out'], self.args['in']) 173 | 174 | h_w_rows = h_w[..., None].expand(h_w.size(0), h_v.size(1), h_w.size(1)).contiguous() 175 | 176 | h_w_rows = h_w_rows.view(-1, self.args['in']) 177 | 178 | h_multiply = torch.bmm(edge_output, torch.unsqueeze(h_w_rows,2)) 179 | 180 | m_new = torch.squeeze(h_multiply) 181 | 182 | return m_new 183 | 184 | def out_mpnn(self, size_h, size_e, args): 185 | return self.args['out'] 186 | 187 | def init_mpnn(self, params): 188 | learn_args = [] 189 | learn_modules = [] 190 | args = {} 191 | 192 | args['in'] = params['in'] 193 | args['out'] = params['out'] 194 | 195 | # Define a parameter matrix A for each edge label. 196 | learn_modules.append(NNet(n_in=params['edge_feat'], n_out=(params['in']*params['out']))) 197 | 198 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 199 | 200 | # Kearnes et al. (2016), Molecular Graph Convolutions 201 | def m_mgc(self, h_v, h_w, e_vw, args): 202 | m = e_vw 203 | return m 204 | 205 | # Laplacian based methods 206 | # Bruna et al. (2013) 207 | def m_bruna(self, h_v, h_w, e_vw, args): 208 | # TODO 209 | m = [] 210 | return m 211 | 212 | # Defferrard et al. 
(2016)
    def m_deff(self, h_v, h_w, e_vw, args):
        # TODO
        m = []
        return m

    # Kipf & Welling (2016)
    def m_kipf(self, h_v, h_w, e_vw, args):
        # TODO
        m = []
        return m

if __name__ == '__main__':
    # Parse options for downloading
    parser = argparse.ArgumentParser(description='QM9 Object.')
    # Optional argument
    parser.add_argument('--root', nargs=1, help='Specify the data directory.', default=['./data/qm9/dsgdb9nsd/'])

    args = parser.parse_args()
    root = args.root[0]

    files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))]

    idx = np.random.permutation(len(files))
    idx = idx.tolist()

    valid_ids = [files[i] for i in idx[0:10000]]
    test_ids = [files[i] for i in idx[10000:20000]]
    train_ids = [files[i] for i in idx[20000:]]

    data_train = datasets.Qm9(root, train_ids)
    data_valid = datasets.Qm9(root, valid_ids)
    data_test = datasets.Qm9(root, test_ids)

    # Define message
    m = MessageFunction('duvenaud')

    print(m.get_definition())

    start = time.time()

    # Select one graph
    g_tuple, l = data_train[0]
    g, h_t, e = g_tuple

    m_t = {}
    for v in g.nodes_iter():
        neigh = g.neighbors(v)
        m_neigh = []  # accumulator for the aggregated message
        for w in neigh:
            if (v, w) in e:
                e_vw = e[(v, w)]
            else:
                e_vw = e[(w, v)]
            m_v = m.forward(h_t[v], h_t[w], e_vw)
            if len(m_neigh):
                m_neigh += m_v
            else:
                m_neigh = m_v

        m_t[v] = m_neigh

    end = time.time()

    print('Input nodes')
    print(h_t)
    print('Message')
    print(m_t)
    print('Time')
    print(end - start)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Neural Message Passing for Quantum Chemistry

Implementation of different models of neural networks on graphs, as described in the article by Gilmer *et al.* [1].

## Installation

    $ pip install -r requirements.txt
    $ python main.py

## Installation of rdkit

Running any experiment that uses the QM9 dataset requires installing the [rdkit](http://www.rdkit.org/) package, which can be done following the instructions available [here](http://www.rdkit.org/docs/Install.html).

## Data

The data used in this project can be downloaded [here](https://github.com/priba/nmp_qc/tree/master/data).

## Bibliography

- [1] Gilmer *et al.*, [Neural Message Passing for Quantum Chemistry](https://arxiv.org/pdf/1704.01212.pdf), arXiv, 2017.
- [2] Duvenaud *et al.*, [Convolutional Networks on Graphs for Learning Molecular Fingerprints](https://arxiv.org/abs/1509.09292), NIPS, 2015.
- [3] Li *et al.*, [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493), ICLR, 2016.
- [4] Battaglia *et al.*, [Interaction Networks for Learning about Objects, Relations and Physics](https://arxiv.org/abs/1612.00222), NIPS, 2016.
- [5] Kipf *et al.*, [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907), ICLR, 2017.
- [6] Defferrard *et al.*, [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375), NIPS, 2016.
- [7] Kearnes *et al.*, [Molecular Graph Convolutions: Moving Beyond Fingerprints](https://arxiv.org/abs/1603.00856), JCAMD, 2016.
- [8] Bruna *et al.*, [Spectral Networks and Locally Connected Networks on Graphs](https://arxiv.org/abs/1312.6203), ICLR, 2014.

## Cite

```
@Article{Gilmer2017,
  author  = {Justin Gilmer and Samuel S. Schoenholz and Patrick F. Riley and Oriol Vinyals and George E. Dahl},
  title   = {Neural Message Passing for Quantum Chemistry},
  journal = {CoRR},
  year    = {2017}
}
```

## Authors

* Pau Riba (@priba) [Webpage](http://www.cvc.uab.es/people/priba/)
* Anjan Dutta (@AnjanDutta) [Webpage](https://sites.google.com/site/2adutta/)
--------------------------------------------------------------------------------
/ReadoutFunction.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
ReadoutFunction.py: Computes a graph-level representation from the node states of the last layer.

Usage:

"""

from __future__ import print_function

# Own modules
import datasets
from MessageFunction import MessageFunction
from UpdateFunction import UpdateFunction
from models.nnet import NNet

import time
import torch
import torch.nn as nn
import os
import argparse
import numpy as np

from torch.autograd.variable import Variable

#dtype = torch.cuda.FloatTensor
dtype = torch.FloatTensor

__author__ = "Pau Riba, Anjan Dutta"
__email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"


class ReadoutFunction(nn.Module):

    # Constructor
    def __init__(self, readout_def='nn', args={}):
        super(ReadoutFunction, self).__init__()
        self.r_definition = ''
        self.r_function = None
        self.args = {}
        self.__set_readout(readout_def, args)

    # Readout of the graph given the node values at the last layer
    def forward(self, h_v):
        return self.r_function(h_v)

    # Set a readout function
    def __set_readout(self, readout_def, args):
        self.r_definition = readout_def.lower()

        self.r_function = {
            'duvenaud': self.r_duvenaud,
            'ggnn': self.r_ggnn,
            'intnet': self.r_intnet,
            'mpnn': self.r_mpnn
        }.get(self.r_definition, None)

        if self.r_function is None:
            print('WARNING!: Readout Function has not been set correctly\n\tIncorrect definition ' + readout_def)
            quit()

        init_parameters = {
            'duvenaud': self.init_duvenaud,
            'ggnn': self.init_ggnn,
            'intnet': self.init_intnet,
            'mpnn': self.init_mpnn
        }.get(self.r_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {}))

        self.learn_args, self.learn_modules, self.args = init_parameters(args)

    # Get the name of the used readout function
    def get_definition(self):
        return self.r_definition

    ## Definition of various state of the art readout functions

    # Duvenaud
    def r_duvenaud(self, h):
        # layers
        aux = []
        for l in range(len(h)):
            param_sz = self.learn_args[l].size()
            parameter_mat = torch.t(self.learn_args[l])[None, ...].expand(h[l].size(0), param_sz[1],
                                                                          param_sz[0])

            aux.append(torch.transpose(torch.bmm(parameter_mat, torch.transpose(h[l], 1, 2)), 1, 2))

            for j in range(0, aux[l].size(1)):
                # Mask whole 0 vectors
                aux[l][:, j, :] =
nn.Softmax()(aux[l][:, j, :].clone())*(torch.sum(aux[l][:, j, :] != 0, 1) > 0).expand_as(aux[l][:, j, :]).type_as(aux[l]) 93 | 94 | aux = torch.sum(torch.sum(torch.stack(aux, 3), 3), 1) 95 | return self.learn_modules[0](torch.squeeze(aux)) 96 | 97 | def init_duvenaud(self, params): 98 | learn_args = [] 99 | learn_modules = [] 100 | args = {} 101 | 102 | args['out'] = params['out'] 103 | 104 | # Define a parameter matrix W for each layer. 105 | for l in range(params['layers']): 106 | learn_args.append(nn.Parameter(torch.randn(params['in'][l], params['out']))) 107 | 108 | # learn_modules.append(nn.Linear(params['out'], params['target'])) 109 | 110 | learn_modules.append(NNet(n_in=params['out'], n_out=params['target'])) 111 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 112 | 113 | # GG-NN, Li et al. 114 | def r_ggnn(self, h): 115 | 116 | aux = Variable( torch.Tensor(h[0].size(0), self.args['out']).type_as(h[0].data).zero_() ) 117 | # For each graph 118 | for i in range(h[0].size(0)): 119 | nn_res = nn.Sigmoid()(self.learn_modules[0](torch.cat([h[0][i,:,:], h[-1][i,:,:]], 1)))*self.learn_modules[1](h[-1][i,:,:]) 120 | 121 | # Delete virtual nodes 122 | nn_res = (torch.sum(h[0][i,:,:],1).expand_as(nn_res)>0).type_as(nn_res)* nn_res 123 | 124 | aux[i,:] = torch.sum(nn_res,0) 125 | 126 | return aux 127 | 128 | def init_ggnn(self, params): 129 | learn_args = [] 130 | learn_modules = [] 131 | args = {} 132 | 133 | # i 134 | learn_modules.append(NNet(n_in=2*params['in'], n_out=params['target'])) 135 | 136 | # j 137 | learn_modules.append(NNet(n_in=params['in'], n_out=params['target'])) 138 | 139 | args['out'] = params['target'] 140 | 141 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 142 | 143 | 144 | # Battaglia et al. 
(2016), Interaction Networks 145 | def r_intnet(self, h): 146 | 147 | aux = torch.sum(h[-1],1) 148 | 149 | return self.learn_modules[0](aux) 150 | 151 | def init_intnet(self, params): 152 | learn_args = [] 153 | learn_modules = [] 154 | args = {} 155 | 156 | learn_modules.append(NNet(n_in=params['in'], n_out=params['target'])) 157 | 158 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 159 | 160 | def r_mpnn(self, h): 161 | 162 | aux = Variable( torch.Tensor(h[0].size(0), self.args['out']).type_as(h[0].data).zero_() ) 163 | # For each graph 164 | for i in range(h[0].size(0)): 165 | nn_res = nn.Sigmoid()(self.learn_modules[0](torch.cat([h[0][i,:,:], h[-1][i,:,:]], 1)))*self.learn_modules[1](h[-1][i,:,:]) 166 | 167 | # Delete virtual nodes 168 | nn_res = (torch.sum(h[0][i,:,:],1).expand_as(nn_res)>0).type_as(nn_res)* nn_res 169 | 170 | aux[i,:] = torch.sum(nn_res,0) 171 | 172 | return aux 173 | 174 | def init_mpnn(self, params): 175 | learn_args = [] 176 | learn_modules = [] 177 | args = {} 178 | 179 | # i 180 | learn_modules.append(NNet(n_in=2*params['in'], n_out=params['target'])) 181 | 182 | # j 183 | learn_modules.append(NNet(n_in=params['in'], n_out=params['target'])) 184 | 185 | args['out'] = params['target'] 186 | 187 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 188 | 189 | if __name__ == '__main__': 190 | # Parse optios for downloading 191 | parser = argparse.ArgumentParser(description='QM9 Object.') 192 | # Optional argument 193 | parser.add_argument('--root', nargs=1, help='Specify the data directory.', default=['./data/qm9/dsgdb9nsd/']) 194 | 195 | args = parser.parse_args() 196 | root = args.root[0] 197 | 198 | files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))] 199 | 200 | idx = np.random.permutation(len(files)) 201 | idx = idx.tolist() 202 | 203 | valid_ids = [files[i] for i in idx[0:10000]] 204 | test_ids = [files[i] for i in idx[10000:20000]] 205 | train_ids = [files[i] for i in idx[20000:]] 206 | 207 | data_train = datasets.Qm9(root, train_ids) 208 | data_valid = datasets.Qm9(root, valid_ids) 209 | data_test = datasets.Qm9(root, test_ids) 210 | 211 | # d = datasets.utils.get_graph_stats(data_train, 'degrees') 212 | d = [1, 2, 3, 4] 213 | 214 | ## Define message 215 | m = MessageFunction('duvenaud') 216 | 217 | ## Parameters for the update function 218 | # Select one graph 219 | g_tuple, l = data_train[0] 220 | g, h_t, e = g_tuple 221 | 222 | m_v = m.forward(h_t[0], h_t[1], e[list(e.keys())[0]]) 223 | 224 | in_n = len(m_v) 225 | out_n = 30 226 | 227 | ## Define Update 228 | u = UpdateFunction('duvenaud', args={'deg': d, 'in': in_n, 'out': out_n}) 229 | 230 | in_n = len(h_t[0]) 231 | 232 | ## Define Readout 233 | r = ReadoutFunction('duvenaud', args={'layers': 2, 'in': [in_n, out_n], 'out': 50, 'target': len(l)}) 234 | 235 | print(m.get_definition()) 236 | print(u.get_definition()) 237 | print(r.get_definition()) 238 | 239 | start = time.time() 240 | 241 | # Layers 242 | h = [] 243 | 244 | # Select one graph 245 | g_tuple, l = data_train[0] 246 | g, h_in, e = g_tuple 247 | 248 | h.append(h_in) 249 | 250 | # Layer 251 | t = 1 252 | h.append({}) 253 | for v in g.nodes_iter(): 254 | neigh = g.neighbors(v) 255 | m_neigh = dtype() 256 | for w in neigh: 257 | if (v, w) in e: 258 | e_vw = e[(v, w)] 259 | else: 260 | e_vw = e[(w, v)] 261 | m_v = m.forward(h[t-1][v], h[t-1][w], e_vw) 262 | if len(m_neigh): 263 | m_neigh += m_v 264 | else: 265 | m_neigh = m_v 266 | 267 | # Duvenaud 268 | opt = {'deg': 
len(neigh)}
        h[t][v] = u.forward(h[t-1][v], m_neigh, opt)

    # Readout
    res = r.forward(h)

    end = time.time()

    print(res)
    print('Time')
    print(end - start)
--------------------------------------------------------------------------------
/UpdateFunction.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
UpdateFunction.py: Updates the nodes using the previous state and the message.

Usage:

"""

from __future__ import print_function

# Own modules
import datasets
from MessageFunction import MessageFunction
from models.nnet import NNet

import numpy as np
import time
import os
import argparse
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.variable import Variable

#dtype = torch.cuda.FloatTensor
dtype = torch.FloatTensor

__author__ = "Pau Riba, Anjan Dutta"
__email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"


class UpdateFunction(nn.Module):

    # Constructor
    def __init__(self, update_def='nn', args={}):
        super(UpdateFunction, self).__init__()
        self.u_definition = ''
        self.u_function = None
        self.args = {}
        self.__set_update(update_def, args)

    # Update node h_v given message m_v
    def forward(self, h_v, m_v, opt={}):
        return self.u_function(h_v, m_v, opt)

    # Set update function
    def __set_update(self, update_def, args):
        self.u_definition = update_def.lower()

        self.u_function = {
            'duvenaud': self.u_duvenaud,
            'ggnn': self.u_ggnn,
            'intnet': self.u_intnet,
            'mpnn': self.u_mpnn
        }.get(self.u_definition, None)

        if self.u_function is None:
            print('WARNING!: Update Function has not been set correctly\n\tIncorrect definition ' + update_def)
            quit()

        init_parameters = {
            'duvenaud': self.init_duvenaud,
            'ggnn': self.init_ggnn,
            'intnet': self.init_intnet,
            'mpnn': self.init_mpnn
        }.get(self.u_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {}))

        self.learn_args, self.learn_modules, self.args = init_parameters(args)

    # Get the name of the used update function
    def get_definition(self):
        return self.u_definition

    # Get the update function arguments
    def get_args(self):
        return self.args

    ## Definition of various state of the art update functions

    # Duvenaud
    def u_duvenaud(self, h_v, m_v, opt):

        param_sz = self.learn_args[0][opt['deg']].size()
        parameter_mat = torch.t(self.learn_args[0][opt['deg']])[None, ...].expand(m_v.size(0), param_sz[1], param_sz[0])

        aux = torch.bmm(parameter_mat, torch.transpose(m_v, 1, 2))

        return torch.transpose(torch.nn.Sigmoid()(aux), 1, 2)

    def init_duvenaud(self, params):
        learn_args = []
        learn_modules = []
        args = {}

        # Filter out degree 0 (the message will be 0 and therefore there is no update)
        args['deg'] = [i for i in params['deg'] if i != 0]
        args['in'] = params['in']
        args['out'] = params['out']

        # Define a parameter matrix H for each degree.
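        # The single stacked Parameter below holds one (in x out) weight matrix
        # per node degree; u_duvenaud above then selects the slice that matches
        # the degree of the node being updated. A sketch of the intended
        # indexing (names and sizes are illustrative):
        #   H = torch.randn(len(args['deg']), n_in, n_out)  # one matrix per degree
        #   H_d = H[d]                                      # used for nodes of degree d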
103 | learn_args.append(torch.nn.Parameter(torch.randn(len(args['deg']), args['in'], args['out']))) 104 | 105 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 106 | 107 | # GG-NN, Li et al. 108 | def u_ggnn(self, h_v, m_v, opt={}): 109 | h_v.contiguous() 110 | m_v.contiguous() 111 | h_new = self.learn_modules[0](torch.transpose(m_v, 0, 1), torch.unsqueeze(h_v, 0))[0] # 0 or 1??? 112 | return torch.transpose(h_new, 0, 1) 113 | 114 | def init_ggnn(self, params): 115 | learn_args = [] 116 | learn_modules = [] 117 | args = {} 118 | 119 | args['in_m'] = params['in_m'] 120 | args['out'] = params['out'] 121 | 122 | # GRU 123 | learn_modules.append(nn.GRU(params['in_m'], params['out'])) 124 | 125 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 126 | 127 | # Battaglia et al. (2016), Interaction Networks 128 | def u_intnet(self, h_v, m_v, opt): 129 | if opt['x_v'].ndimension(): 130 | input_tensor = torch.cat([h_v, opt['x_v'], torch.squeeze(m_v)], 1) 131 | else: 132 | input_tensor = torch.cat([h_v, torch.squeeze(m_v)], 1) 133 | 134 | return self.learn_modules[0](input_tensor) 135 | 136 | def init_intnet(self, params): 137 | learn_args = [] 138 | learn_modules = [] 139 | args = {} 140 | 141 | args['in'] = params['in'] 142 | args['out'] = params['out'] 143 | 144 | learn_modules.append(NNet(n_in=params['in'], n_out=params['out'])) 145 | 146 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 147 | 148 | def u_mpnn(self, h_v, m_v, opt={}): 149 | h_in = h_v.view(-1,h_v.size(2)) 150 | m_in = m_v.view(-1,m_v.size(2)) 151 | h_new = self.learn_modules[0](m_in[None,...],h_in[None,...])[0] # 0 or 1??? 152 | return torch.squeeze(h_new).view(h_v.size()) 153 | 154 | def init_mpnn(self, params): 155 | learn_args = [] 156 | learn_modules = [] 157 | args = {} 158 | 159 | args['in_m'] = params['in_m'] 160 | args['out'] = params['out'] 161 | 162 | # GRU 163 | learn_modules.append(nn.GRU(params['in_m'], params['out'])) 164 | 165 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 166 | 167 | 168 | if __name__ == '__main__': 169 | 170 | # Parse optios for downloading 171 | parser = argparse.ArgumentParser(description='QM9 Object.') 172 | # Optional argument 173 | parser.add_argument('--root', nargs=1, help='Specify the data directory.', default=['./data/qm9/dsgdb9nsd/']) 174 | 175 | args = parser.parse_args() 176 | root = args.root[0] 177 | 178 | files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))] 179 | 180 | idx = np.random.permutation(len(files)) 181 | idx = idx.tolist() 182 | 183 | valid_ids = [files[i] for i in idx[0:10000]] 184 | test_ids = [files[i] for i in idx[10000:20000]] 185 | train_ids = [files[i] for i in idx[20000:]] 186 | 187 | data_train = datasets.Qm9(root, train_ids) 188 | data_valid = datasets.Qm9(root, valid_ids) 189 | data_test = datasets.Qm9(root, test_ids) 190 | 191 | print('STATS') 192 | # d = datasets.utils.get_graph_stats(data_test, 'degrees') 193 | d = [1, 2, 3, 4] 194 | 195 | print('Message') 196 | ## Define message 197 | m = MessageFunction('duvenaud') 198 | 199 | ## Parameters for the update function 200 | # Select one graph 201 | g_tuple, l = data_train[0] 202 | g, h_t, e = g_tuple 203 | 204 | m_v = m.forward(h_t[0], h_t[1], e[list(e.keys())[0]]) 205 | in_n = len(m_v) 206 | out_n = 30 207 | 208 | print('Update') 209 | ## Define Update 210 | u = UpdateFunction('duvenaud', args={'deg': d, 'in': in_n , 'out': out_n}) 211 | 212 | print(m.get_definition()) 213 | 
print(u.get_definition())

    start = time.time()

    # Select one graph
    g_tuple, l = data_train[0]
    g, h_t, e = g_tuple

    h_t1 = {}
    for v in g.nodes_iter():
        neigh = g.neighbors(v)
        m_neigh = dtype()
        for w in neigh:
            if (v, w) in e:
                e_vw = e[(v, w)]
            else:
                e_vw = e[(w, v)]
            m_v = m.forward(h_t[v], h_t[w], e_vw)
            if len(m_neigh):
                m_neigh += m_v
            else:
                m_neigh = m_v

        # Duvenaud
        opt = {'deg': len(neigh)}
        h_t1[v] = u.forward(h_t[v], m_neigh, opt)

    end = time.time()

    print('Input nodes')
    print(h_t)
    print('Updated nodes')
    print(h_t1)
    print('Time')
    print(end - start)
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
# Data

Download the datasets.

## Usage

    $ python download.py qm9 -p ./
    $ ./download.py qm9 mutag enzymes

For more information:

    $ ./download.py -h

## TODO

- [x] Figshare downloader.
- [x] QM9.
- [x] MUTAG.
- [x] ENZYMES.

## Datasets

### MUTAG & ENZYMES

Downloaded from [here](https://figshare.com/articles/MUTAG_and_ENZYMES_DataSet/899875).

```
@article{shervashidze2011weisfeiler,
  title={Weisfeiler-lehman graph kernels},
  author={Shervashidze, Nino and Schweitzer, Pascal and Leeuwen, Erik Jan van and Mehlhorn, Kurt and Borgwardt, Karsten M},
  journal={Journal of Machine Learning Research},
  volume={12},
  number={Sep},
  pages={2539--2561},
  year={2011}
}
```

### QM9

Downloaded from [here](https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904).

```
@article{ramakrishnan2014quantum,
  title={Quantum chemistry structures and properties of 134 kilo molecules},
  author={Ramakrishnan, Raghunathan and Dral, Pavlo O and Rupp, Matthias and Von Lilienfeld, O Anatole},
  journal={Scientific data},
  volume={1},
  year={2014},
  publisher={Nature Publishing Group}
}
```

## Authors

* Pau Riba (@priba)
* Anjan Dutta (@AnjanDutta)
--------------------------------------------------------------------------------
/data/download.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
download.py: Download the needed datasets.

Usage:
    download.py [-h] [-p dir] D [D ...]
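    where each D is one of the dataset names accepted below:
    qm9, mutag, enzymes, graph_kernels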
9 | Example: 10 | $ ./download.py qm9 mutag enzymes -p ./ 11 | $ python download.py qm9 mutag enzymes -p ./ 12 | 13 | """ 14 | 15 | __author__ = "Pau Riba, Anjan Dutta" 16 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 17 | 18 | import argparse 19 | import os 20 | import wget 21 | import zipfile 22 | import tarfile 23 | 24 | # Download file 25 | def download_file(url, file_ext, dir_path='./'): 26 | file_name = wget.download(url, out=dir_path) 27 | file_path = os.path.join(dir_path, file_name) 28 | if file_ext == '.zip': 29 | zip_ref = zipfile.ZipFile(file_path,'r') 30 | zip_ref.extractall(dir_path) 31 | zip_ref.close() 32 | os.remove(file_path) 33 | 34 | # Download data from figshare 35 | def download_figshare(file_name, file_ext, dir_path='./', change_name = None): 36 | prepare_data_dir(dir_path) 37 | url = 'https://ndownloader.figshare.com/files/' + file_name 38 | wget.download(url, out=dir_path) 39 | file_path = os.path.join(dir_path, file_name) 40 | 41 | if file_ext == '.zip': 42 | zip_ref = zipfile.ZipFile(file_path,'r') 43 | if change_name is not None: 44 | dir_path = os.path.join(dir_path, change_name) 45 | zip_ref.extractall(dir_path) 46 | zip_ref.close() 47 | os.remove(file_path) 48 | elif file_ext == '.tar.bz2': 49 | tar_ref = tarfile.open(file_path,'r:bz2') 50 | if change_name is not None: 51 | dir_path = os.path.join(dir_path, change_name) 52 | tar_ref.extractall(dir_path) 53 | tar_ref.close() 54 | os.remove(file_path) 55 | elif change_name is not None: 56 | os.rename(file_path, os.path.join(dir_path, change_name)) 57 | 58 | # Download QM9 dataset 59 | def download_qm9(data_dir): 60 | data_dir = os.path.join(data_dir, 'qm9') 61 | if os.path.exists(data_dir): 62 | print('Found QM9 dataset - SKIP!') 63 | return 64 | 65 | prepare_data_dir(data_dir) 66 | 67 | # README 68 | download_figshare('3195392', '.txt', data_dir, 'readme.txt') 69 | # atomref 70 | download_figshare('3195395', '.txt', data_dir, 'atomref.txt') 71 | # Validation 72 | download_figshare('3195401', '.txt', data_dir, 'validation.txt') 73 | # Uncharacterized 74 | download_figshare('3195404', '.txt', data_dir, 'uncharacterized.txt') 75 | # dsgdb9nsd.xyz.tar.bz2 76 | download_figshare('3195389', '.tar.bz2', data_dir, 'dsgdb9nsd') 77 | # dsC7O2H10nsd.xyz.tar.bz2 78 | download_figshare('3195398', '.tar.bz2', data_dir, 'dsC702H10nsd') 79 | 80 | # If not exists creates the specified folder 81 | def prepare_data_dir(path): 82 | if not os.path.exists(path): 83 | os.mkdir(path) 84 | 85 | if __name__ == '__main__': 86 | 87 | # Parse optios for downloading 88 | parser = argparse.ArgumentParser(description='Download dataset for Message Passing Algorithm.') 89 | # Positional arguments 90 | parser.add_argument('datasets', metavar='D', type=str.lower, nargs='+', choices=['qm9','mutag', 91 | 'enzymes', 'graph_kernels'], help='Name of dataset to download [QM9,MUTAG,ENZYMES,GRAPH_KERNELS]') 92 | # I/O 93 | parser.add_argument('-p', '--path', metavar='dir', type=str, nargs=1, 94 | help='path to store the data (default ./)') 95 | 96 | args = parser.parse_args() 97 | 98 | # Check parameters 99 | if args.path is None: 100 | args.path = './' 101 | else: 102 | args.path = args.path[0] 103 | 104 | # Init folder 105 | prepare_data_dir(args.path) 106 | 107 | # Select datasets 108 | if 'qm9' in args.datasets: 109 | download_qm9(args.path) 110 | if 'mutag' in args.datasets: 111 | download_figshare('3132449', '.zip', args.path) 112 | if 'enzymes' in args.datasets: 113 | download_figshare('3132446', '.zip', args.path) 114 | if 
'graph_kernels' in args.datasets: 115 | download_file('https://www.ethz.ch/content/dam/ethz/special-interest/bsse/borgwardt-lab/Projects/GraphKernels/data_graphml.zip', '.zip', args.path) 116 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | Collection of classes of different datasets, implementing a data generator in PyTorch style 4 | 5 | - QM9: data generator for the [QM9](http://quantum-machine.org/datasets/) dataset. 6 | - LETTER: data generator for the [Letter](http://www.fki.inf.unibe.ch/databases/iam-graph-database) dataset. 7 | - GREC: data generator for the [GREC](http://www.fki.inf.unibe.ch/databases/iam-graph-database) dataset. 8 | - GWHISTOGRAPH: data generator for the [HistoGraph](http://www.histograph.ch/) dataset. 9 | - MUTAG: data generator for the [MUTAG](https://figshare.com/articles/MUTAG_and_ENZYMES_DataSet/899875) dataset. 10 | - ENZYMES: data generator for the [ENZYMES](https://figshare.com/articles/MUTAG_and_ENZYMES_DataSet/899875) dataset. 11 | 12 | ## Authors 13 | 14 | * Pau Riba (@priba) 15 | * Anjan Dutta (@AnjanDutta) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .grec import GREC 2 | from .gwhistograph import GWHISTOGRAPH 3 | from .letter import LETTER 4 | from .mutag import MUTAG 5 | from .qm9 import Qm9 6 | 7 | __all__ = ('GREC', 'GWHISTOGRAPH', 'LETTER', 'MUTAG', 'Qm9') 8 | -------------------------------------------------------------------------------- /datasets/grec.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os, sys 3 | import argparse 4 | import networkx as nx 5 | 6 | reader_folder = os.path.realpath(os.path.abspath('../GraphReader')) 7 | if reader_folder not in sys.path: 8 | sys.path.insert(1, reader_folder) 9 | 10 | from GraphReader.graph_reader import read_2cols_set_files, create_numeric_classes, read_cxl, create_graph_grec 11 | 12 | __author__ = "Pau Riba, Anjan Dutta" 13 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 14 | 15 | 16 | class GREC(data.Dataset): 17 | def __init__(self, root_path, ids, classes): 18 | self.root = root_path 19 | self.subdir = 'data' 20 | self.classes = classes 21 | self.ids = ids 22 | 23 | def __getitem__(self, index): 24 | g = create_graph_grec(os.path.join(self.root, self.subdir, self.ids[index])) 25 | target = self.classes[index] 26 | h = self.vertex_transform(g) 27 | g, e = self.edge_transform(g) 28 | target = self.target_transform(target) 29 | return (g, h, e), target 30 | 31 | def __len__(self): 32 | return len(self.ids) 33 | 34 | def target_transform(self, target): 35 | return [int(target)-1] 36 | 37 | def vertex_transform(self, g): 38 | h = [] 39 | for n, d in g.nodes_iter(data=True): 40 | h_t = [] 41 | h_t += [float(x) for x in d['labels']] 42 | h.append(h_t) 43 | return h 44 | 45 | def edge_transform(self, g): 46 | e = {} 47 | for n1, n2, d in g.edges_iter(data=True): 48 | e_t = [] 49 | e_t += [float(x) for x in list(d.values())] 50 | e[(n1, n2)] = e_t 51 | return nx.to_numpy_matrix(g), e 52 | 53 | 54 | if __name__ == '__main__': 55 | # Parse optios for downloading 56 | parser = argparse.ArgumentParser(description='GREC Object.') 57 | # Optional argument 58 | parser.add_argument('--root', nargs=1, help='Specify the data 
directory.',
                        default=['/home/adutta/Workspace/Datasets/Graphs/GREC'])

    args = parser.parse_args()
    root = args.root[0]

    train_classes, train_ids = read_cxl(os.path.join(root, 'data/train.cxl'))
    test_classes, test_ids = read_cxl(os.path.join(root, 'data/test.cxl'))
    valid_classes, valid_ids = read_cxl(os.path.join(root, 'data/valid.cxl'))

    num_classes = len(list(set(train_classes+valid_classes+test_classes)))

    data_train = GREC(root, train_ids, train_classes)
    data_valid = GREC(root, valid_ids, valid_classes)
    data_test = GREC(root, test_ids, test_classes)

    print(len(data_train))
    print(len(data_valid))
    print(len(data_test))

    for i in range(len(train_ids)):
        print(data_train[i])

    for i in range(len(valid_ids)):
        print(data_valid[i])

    for i in range(len(test_ids)):
        print(data_test[i])

    print(data_train[61])
    print(data_valid[1])
    print(data_test[1])
--------------------------------------------------------------------------------
/datasets/gwhistograph.py:
--------------------------------------------------------------------------------
"""
gwhistograph.py:

Usage:

"""

import torch.utils.data as data
import os, sys
import argparse
import networkx as nx

reader_folder = os.path.realpath(os.path.abspath('../GraphReader'))
if reader_folder not in sys.path:
    sys.path.insert(1, reader_folder)

from GraphReader.graph_reader import read_2cols_set_files, create_numeric_classes, create_graph_gwhist

__author__ = "Pau Riba, Anjan Dutta"
__email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"


class GWHISTOGRAPH(data.Dataset):

    def __init__(self, root_path, subset, ids, classes, max_class_num):

        self.root = root_path
        self.subdir = 'Data/Word_Graphs/01_Skew'
        self.subset = subset
        self.classes = classes
        self.ids = ids
        self.max_class_num = max_class_num

    def __getitem__(self, index):

        g = create_graph_gwhist(os.path.join(self.root, self.subdir, self.subset, self.ids[index]))

        target = self.classes[index]

        h = self.vertex_transform(g)

        g, e = self.edge_transform(g)

        target = self.target_transform(target)

        return (g, h, e), target

    def __len__(self):
        return len(self.ids)

    def target_transform(self, target):
        # [int(i == target-1) for i in range(self.max_class_num)]
        # return target_one_hot
        return [target]

    def vertex_transform(self, g):
        h = []
        for n, d in g.nodes_iter(data=True):
            h_t = []
            h_t += [float(x) for x in d['labels']]
            h.append(h_t)
        return h

    def edge_transform(self, g):
        e = {}
        for n1, n2, d in g.edges_iter(data=True):
            e_t = []
            e_t += [1]
            e[(n1, n2)] = e_t
        return nx.to_numpy_matrix(g), e

if __name__ == '__main__':

    # Parse options for downloading
    parser = argparse.ArgumentParser(description='GWHISTOGRAPH Object.')
    # Optional argument
    parser.add_argument('--root', nargs=1, help='Specify the data directory.', default=['/home/adutta/Workspace/Datasets/GWHistoGraphs'])
    parser.add_argument('--subset', nargs=1, help='Specify the sub dataset.', default=['01_Keypoint'])

    args = parser.parse_args()
    root = args.root[0]
    subset = args.subset[0]

    train_classes,
train_ids = read_2cols_set_files(os.path.join(root, 'Set/Train.txt')) 85 | test_classes, test_ids = read_2cols_set_files(os.path.join(root, 'Set/Test.txt')) 86 | valid_classes, valid_ids = read_2cols_set_files(os.path.join(root, 'Set/Valid.txt')) 87 | 88 | train_classes, valid_classes, test_classes = create_numeric_classes(train_classes, valid_classes, test_classes) 89 | 90 | num_classes = max(train_classes + valid_classes + test_classes) 91 | 92 | data_train = GWHISTOGRAPH(root, subset, train_ids, train_classes, num_classes) 93 | data_valid = GWHISTOGRAPH(root, subset, valid_ids, valid_classes, num_classes) 94 | data_test = GWHISTOGRAPH(root, subset, test_ids, test_classes, num_classes) 95 | 96 | print(len(data_train)) 97 | print(len(data_valid)) 98 | print(len(data_test)) 99 | 100 | print(data_train[1]) 101 | print(data_valid[1]) 102 | print(data_test[1]) 103 | -------------------------------------------------------------------------------- /datasets/letter.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os, sys 3 | import argparse 4 | import networkx as nx 5 | 6 | reader_folder = os.path.realpath(os.path.abspath('../GraphReader')) 7 | if reader_folder not in sys.path: 8 | sys.path.insert(1, reader_folder) 9 | 10 | from GraphReader.graph_reader import read_cxl, create_graph_letter 11 | 12 | __author__ = "Pau Riba, Anjan Dutta" 13 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 14 | 15 | 16 | class LETTER(data.Dataset): 17 | def __init__(self, root_path, subset, ids, classes, class_list): 18 | self.root = root_path 19 | self.subset = subset 20 | self.classes = classes 21 | self.ids = ids 22 | self.class_list = class_list 23 | 24 | def __getitem__(self, index): 25 | g = create_graph_letter(os.path.join(self.root, self.subset, self.ids[index])) 26 | target = self.classes[index] 27 | h = self.vertex_transform(g) 28 | g, e = self.edge_transform(g) 29 | target = self.target_transform(target) 30 | return (g, h, e), target 31 | 32 | def __len__(self): 33 | return len(self.ids) 34 | 35 | def target_transform(self, target): 36 | return [self.class_list.index(target)] 37 | 38 | def vertex_transform(self, g): 39 | h = [] 40 | for n, d in g.nodes_iter(data=True): 41 | h_t = [] 42 | h_t += [float(x) for x in d['labels']] 43 | h.append(h_t) 44 | return h 45 | 46 | def edge_transform(self, g): 47 | e = {} 48 | for n1, n2, d in g.edges_iter(data=True): 49 | e_t = [] 50 | e_t += [1] 51 | e[(n1, n2)] = e_t 52 | return nx.to_numpy_matrix(g), e 53 | 54 | 55 | if __name__ == '__main__': 56 | # Parse optios for downloading 57 | parser = argparse.ArgumentParser(description='Letter Object.') 58 | # Optional argument 59 | parser.add_argument('--root', nargs=1, help='Specify the data directory.', 60 | default=['/home/adutta/Workspace/Datasets/STDGraphs/Letter']) 61 | parser.add_argument('--subset', nargs=1, help='Specify the sub dataset.', default=['LOW']) 62 | 63 | args = parser.parse_args() 64 | root = args.root[0] 65 | subset = args.subset[0] 66 | 67 | train_classes, train_ids = read_cxl(os.path.join(root, subset, 'train.cxl')) 68 | test_classes, test_ids = read_cxl(os.path.join(root, subset, 'test.cxl')) 69 | valid_classes, valid_ids = read_cxl(os.path.join(root, subset, 'validation.cxl')) 70 | 71 | num_classes = len(list(set(train_classes + valid_classes + test_classes))) 72 | 73 | data_train = LETTER(root, subset, train_ids, train_classes, num_classes) 74 | data_valid = LETTER(root, subset, valid_ids, valid_classes, 
num_classes) 75 | data_test = LETTER(root, subset, test_ids, test_classes, num_classes) 76 | 77 | print(len(data_train)) 78 | print(len(data_valid)) 79 | print(len(data_test)) 80 | 81 | for i in range(len(train_ids)): 82 | print(data_train[i]) 83 | 84 | for i in range(len(valid_ids)): 85 | print(data_valid[i]) 86 | 87 | for i in range(len(test_ids)): 88 | print(data_test[i]) 89 | 90 | print(data_train[1]) 91 | print(data_valid[1]) 92 | print(data_test[1]) 93 | -------------------------------------------------------------------------------- /datasets/mutag.py: -------------------------------------------------------------------------------- 1 | """ 2 | mutag.py: 3 | 4 | Usage: 5 | 6 | """ 7 | import networkx as nx 8 | 9 | import torch.utils.data as data 10 | import os, sys 11 | import argparse 12 | 13 | import datasets.utils as utils 14 | 15 | reader_folder = os.path.realpath( os.path.abspath('../GraphReader')) 16 | if reader_folder not in sys.path: 17 | sys.path.insert(1, reader_folder) 18 | 19 | from GraphReader.graph_reader import divide_datasets 20 | 21 | __author__ = "Pau Riba, Anjan Dutta" 22 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 23 | 24 | class MUTAG(data.Dataset): 25 | 26 | def __init__(self, root_path, ids, classes): 27 | 28 | self.root = root_path 29 | self.classes = classes 30 | self.ids = ids 31 | 32 | def __getitem__(self, index): 33 | 34 | #TODO: Manually have to check the convert_node_labels_to_integers function 35 | g = nx.convert_node_labels_to_integers(nx.read_graphml(os.path.join(self.root, self.ids[index]))) 36 | 37 | target = self.classes[index] 38 | 39 | h = self.vertex_transform(g) 40 | 41 | g, e = self.edge_transform(g) 42 | 43 | target = self.target_transform(target) 44 | 45 | return (g, h, e), target 46 | 47 | def __len__(self): 48 | return len(self.ids) 49 | 50 | def vertex_transform(self, g): 51 | h = [] 52 | for n, d in g.nodes_iter(data=True): 53 | h_t = [] 54 | h_t.append(d['label']) 55 | h.append(h_t) 56 | return h 57 | 58 | def edge_transform(self, g): 59 | e = {} 60 | for n1, n2, d in g.edges_iter(data=True): 61 | e_t = [] 62 | e_t.append(d['label']) 63 | e[(n1, n2)] = e_t 64 | return nx.to_numpy_matrix(g), e 65 | 66 | def target_transform(self, target): 67 | return [target] 68 | 69 | if __name__ == '__main__': 70 | 71 | # Parse optios for downloading 72 | parser = argparse.ArgumentParser(description='MUTAG Object.') 73 | # Optional argument 74 | parser.add_argument('--root', nargs=1, help='Specify the data directory.', default=['/home/adutta/Workspace/Datasets/Graphs/MUTAG']) 75 | 76 | args = parser.parse_args() 77 | root = args.root[0] 78 | 79 | label_file = 'MUTAG.label' 80 | list_file = 'MUTAG.list' 81 | with open(os.path.join(root, label_file), 'r') as f: 82 | l = f.read() 83 | classes = [int(s) for s in l.split() if s.isdigit()] 84 | with open(os.path.join(root, list_file), 'r') as f: 85 | files = f.read().splitlines() 86 | 87 | train_ids, train_classes, valid_ids, valid_classes, test_ids, test_classes = divide_datasets(files, classes) 88 | 89 | data_train = MUTAG(root, train_ids, train_classes) 90 | data_valid = MUTAG(root, valid_ids, valid_classes) 91 | data_test = MUTAG(root, test_ids, test_classes) 92 | 93 | print(len(data_train)) 94 | print(len(data_valid)) 95 | print(len(data_test)) 96 | 97 | print(data_train[1]) 98 | print(data_valid[1]) 99 | print(data_test[1]) -------------------------------------------------------------------------------- /datasets/qm9.py: 
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
qm9.py:

Usage:

"""

# Networkx should be imported before torch
import networkx as nx

import torch.utils.data as data
import numpy as np
import argparse

import datasets.utils as utils
import time
import os, sys

import torch

reader_folder = os.path.realpath(os.path.abspath('../GraphReader'))
if reader_folder not in sys.path:
    sys.path.insert(1, reader_folder)

from GraphReader.graph_reader import xyz_graph_reader

__author__ = "Pau Riba, Anjan Dutta"
__email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"

class Qm9(data.Dataset):

    # Constructor
    def __init__(self, root_path, ids, vertex_transform=utils.qm9_nodes, edge_transform=utils.qm9_edges,
                 target_transform=None, e_representation='raw_distance'):
        self.root = root_path
        self.ids = ids
        self.vertex_transform = vertex_transform
        self.edge_transform = edge_transform
        self.target_transform = target_transform
        self.e_representation = e_representation

    def __getitem__(self, index):
        g, target = xyz_graph_reader(os.path.join(self.root, self.ids[index]))
        if self.vertex_transform is not None:
            h = self.vertex_transform(g)

        if self.edge_transform is not None:
            g, e = self.edge_transform(g, self.e_representation)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return (g, h, e), target

    def __len__(self):
        return len(self.ids)

    def set_target_transform(self, target_transform):
        self.target_transform = target_transform

if __name__ == '__main__':

    # Parse options for downloading
    parser = argparse.ArgumentParser(description='QM9 Object.')
    # Optional argument
    parser.add_argument('--root', nargs=1, help='Specify the data directory.', default=['../data/qm9/dsgdb9nsd'])

    args = parser.parse_args()
    root = args.root[0]

    files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))]

    idx = np.random.permutation(len(files))
    idx = idx.tolist()

    valid_ids = [files[i] for i in idx[0:10000]]
    test_ids = [files[i] for i in idx[10000:20000]]
    train_ids = [files[i] for i in idx[20000:]]

    data_train = Qm9(root, train_ids, vertex_transform=utils.qm9_nodes,
                     edge_transform=lambda g, e_rep: utils.qm9_edges(g, e_representation='raw_distance'))
    data_valid = Qm9(root, valid_ids)
    data_test = Qm9(root, test_ids)

    print(len(data_train))
    print(len(data_valid))
    print(len(data_test))

    print(data_train[1])
    print(data_valid[1])
    print(data_test[1])

    start = time.time()
    print(utils.get_graph_stats(data_valid, 'degrees'))
    end = time.time()
    print('Time Statistics Par')
    print(end - start)
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
utils.py: Functions to process dataset graphs.
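
A minimal usage sketch (the QM9 path and batch size below are illustrative
assumptions, not fixed by this module): collate_g can be plugged into a
PyTorch DataLoader so that variable-size graphs are zero-padded to a common
size within each batch:

    import os
    import torch.utils.data
    import datasets
    from datasets import utils

    root = './data/qm9/dsgdb9nsd/'          # assumed download location
    ids = os.listdir(root)
    loader = torch.utils.data.DataLoader(datasets.Qm9(root, ids),
                                         batch_size=20, shuffle=True,
                                         collate_fn=utils.collate_g)
    g, h, e, target = next(iter(loader))    # padded adjacency, node features,
                                            # edge features, regression targets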
6 | 7 | Usage: 8 | 9 | """ 10 | 11 | from __future__ import print_function 12 | 13 | import rdkit 14 | import torch 15 | from joblib import Parallel, delayed 16 | import multiprocessing 17 | import networkx as nx 18 | import numpy as np 19 | import shutil 20 | import os 21 | 22 | __author__ = "Pau Riba, Anjan Dutta" 23 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 24 | 25 | 26 | def qm9_nodes(g, hydrogen=False): 27 | h = [] 28 | for n, d in g.nodes_iter(data=True): 29 | h_t = [] 30 | # Atom type (One-hot H, C, N, O, F) 31 | h_t += [int(d['a_type'] == x) for x in ['H', 'C', 'N', 'O', 'F']] 32 | # Atomic number 33 | h_t.append(d['a_num']) 34 | # Partial Charge 35 | h_t.append(d['pc']) 36 | # Acceptor 37 | h_t.append(d['acceptor']) 38 | # Donor 39 | h_t.append(d['donor']) 40 | # Aromatic 41 | h_t.append(int(d['aromatic'])) 42 | # Hybridization (one-hot SP, SP2, SP3) 43 | h_t += [int(d['hybridization'] == x) for x in [rdkit.Chem.rdchem.HybridizationType.SP, rdkit.Chem.rdchem.HybridizationType.SP2, rdkit.Chem.rdchem.HybridizationType.SP3]] 44 | # Optionally, the number of attached hydrogens as an extra feature 45 | if hydrogen: 46 | h_t.append(d['num_h']) 47 | h.append(h_t) 48 | return h 49 | 50 | 51 | def qm9_edges(g, e_representation='raw_distance'): 52 | remove_edges = [] 53 | e = {} 54 | for n1, n2, d in g.edges_iter(data=True): 55 | e_t = [] 56 | # Build the edge feature according to the chosen representation 57 | if e_representation == 'chem_graph': 58 | if d['b_type'] is None: 59 | remove_edges += [(n1, n2)] 60 | else: 61 | e_t += [i+1 for i, x in enumerate([rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.DOUBLE, 62 | rdkit.Chem.rdchem.BondType.TRIPLE, rdkit.Chem.rdchem.BondType.AROMATIC]) 63 | if x == d['b_type']] 64 | elif e_representation == 'distance_bin': 65 | if d['b_type'] is None: 66 | step = (6-2)/8.0 67 | start = 2 68 | b = 9 69 | for i in range(0, 9): 70 | if d['distance'] < (start+i*step): 71 | b = i 72 | break 73 | e_t.append(b+5) 74 | else: 75 | e_t += [i+1 for i, x in enumerate([rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.DOUBLE, 76 | rdkit.Chem.rdchem.BondType.TRIPLE, rdkit.Chem.rdchem.BondType.AROMATIC]) 77 | if x == d['b_type']] 78 | elif e_representation == 'raw_distance': 79 | if d['b_type'] is None: 80 | remove_edges += [(n1, n2)] 81 | else: 82 | e_t.append(d['distance']) 83 | e_t += [int(d['b_type'] == x) for x in [rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.DOUBLE, 84 | rdkit.Chem.rdchem.BondType.TRIPLE, rdkit.Chem.rdchem.BondType.AROMATIC]] 85 | else: 86 | print('Incorrect edge representation transform') 87 | quit() 88 | if e_t: 89 | e[(n1, n2)] = e_t 90 | for edg in remove_edges: 91 | g.remove_edge(*edg) 92 | return nx.to_numpy_matrix(g), e 93 | 94 | 95 | def normalize_data(data, mean, std): 96 | data_norm = (data-mean)/std 97 | return data_norm 98 | 99 | 100 | def get_values(obj, start, end, prop): 101 | vals = [] 102 | for i in range(start, end): 103 | v = {} 104 | if 'degrees' in prop: 105 | v['degrees'] = set(sum(obj[i][0][0].sum(axis=0, dtype='int').tolist(), [])) 106 | if 'edge_labels' in prop: 107 | v['edge_labels'] = set(sum(list(obj[i][0][2].values()), [])) 108 | if 'target_mean' in prop or 'target_std' in prop: 109 | v['params'] = obj[i][1] 110 | vals.append(v) 111 | return vals 112 | 113 | 114 | def get_graph_stats(graph_obj_handle, prop='degrees'): 115 | 116 | num_cores = multiprocessing.cpu_count() 117 | inputs = [int(i*len(graph_obj_handle)/num_cores) for i in range(num_cores)] + [len(graph_obj_handle)] 118 | res = 
Parallel(n_jobs=num_cores)(delayed(get_values)(graph_obj_handle, inputs[i], inputs[i+1], prop) for i in range(num_cores)) 119 | 120 | stat_dict = {} 121 | 122 | if 'degrees' in prop: 123 | stat_dict['degrees'] = list(set([d for core_res in res for file_res in core_res for d in file_res['degrees']])) 124 | if 'edge_labels' in prop: 125 | stat_dict['edge_labels'] = list(set([d for core_res in res for file_res in core_res for d in file_res['edge_labels']])) 126 | if 'target_mean' in prop or 'target_std' in prop: 127 | param = np.array([file_res['params'] for core_res in res for file_res in core_res]) 128 | if 'target_mean' in prop: 129 | stat_dict['target_mean'] = np.mean(param, axis=0) 130 | if 'target_std' in prop: 131 | stat_dict['target_std'] = np.std(param, axis=0) 132 | 133 | return stat_dict 134 | 135 | 136 | def accuracy(output, target, topk=(1,)): 137 | """Computes the precision@k for the specified values of k""" 138 | maxk = max(topk) 139 | batch_size = target.size(0) 140 | _, pred = output.topk(maxk, 1, True, True) 141 | pred = pred.t() 142 | pred = pred.type_as(target) 143 | target = target.type_as(pred) 144 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 145 | res = [] 146 | for k in topk: 147 | correct_k = correct[:k].view(-1).float().sum(0) 148 | res.append(correct_k.mul_(100.0 / batch_size)) 149 | return res 150 | 151 | 152 | def collate_g(batch): 153 | 154 | batch_sizes = np.max(np.array([[len(input_b[1]), len(input_b[1][0]), len(input_b[2]), 155 | len(list(input_b[2].values())[0])] 156 | if input_b[2] else 157 | [len(input_b[1]), len(input_b[1][0]), 0,0] 158 | for (input_b, target_b) in batch]), axis=0) 159 | 160 | g = np.zeros((len(batch), batch_sizes[0], batch_sizes[0])) 161 | h = np.zeros((len(batch), batch_sizes[0], batch_sizes[1])) 162 | e = np.zeros((len(batch), batch_sizes[0], batch_sizes[0], batch_sizes[3])) 163 | 164 | target = np.zeros((len(batch), len(batch[0][1]))) 165 | 166 | for i in range(len(batch)): 167 | 168 | num_nodes = len(batch[i][0][1]) 169 | 170 | # Adjacency matrix 171 | g[i, 0:num_nodes, 0:num_nodes] = batch[i][0][0] 172 | 173 | # Node features 174 | h[i, 0:num_nodes, :] = batch[i][0][1] 175 | 176 | # Edges 177 | for edge in batch[i][0][2].keys(): 178 | e[i, edge[0], edge[1], :] = batch[i][0][2][edge] 179 | e[i, edge[1], edge[0], :] = batch[i][0][2][edge] 180 | 181 | # Target 182 | target[i, :] = batch[i][1] 183 | 184 | g = torch.FloatTensor(g) 185 | h = torch.FloatTensor(h) 186 | e = torch.FloatTensor(e) 187 | target = torch.FloatTensor(target) 188 | 189 | return g, h, e, target 190 | 191 | 192 | def save_checkpoint(state, is_best, directory): 193 | 194 | if not os.path.isdir(directory): 195 | os.makedirs(directory) 196 | checkpoint_file = os.path.join(directory, 'checkpoint.pth') 197 | best_model_file = os.path.join(directory, 'model_best.pth') 198 | torch.save(state, checkpoint_file) 199 | if is_best: 200 | shutil.copyfile(checkpoint_file, best_model_file) 201 | 202 | 203 | -------------------------------------------------------------------------------- /demos/demo_grec_duvenaud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in: 6 | 7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017) 8 | Neural Message Passing for Quantum Chemistry. 
9 | arXiv preprint arXiv:1704.01212 [cs.LG] 10 | """ 11 | 12 | # Torch 13 | import torch 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | from torch.autograd import Variable 17 | 18 | import time 19 | import argparse 20 | import os 21 | import sys 22 | 23 | # Our Modules 24 | reader_folder = os.path.realpath(os.path.abspath('..')) 25 | if reader_folder not in sys.path: 26 | sys.path.append(reader_folder) 27 | import datasets 28 | from datasets import utils 29 | from models.MPNN_Duvenaud import MpnnDuvenaud 30 | from LogMetric import AverageMeter, Logger 31 | from GraphReader.graph_reader import read_cxl 32 | 33 | __author__ = "Pau Riba, Anjan Dutta" 34 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 35 | 36 | torch.multiprocessing.set_sharing_strategy('file_system') 37 | 38 | 39 | # Parser check 40 | def restricted_float(x, inter): 41 | x = float(x) 42 | if x < inter[0] or x > inter[1]: 43 | raise argparse.ArgumentTypeError("%r not in range [%r, %r]" % (x, inter[0], inter[1])) 44 | return x 45 | 46 | # Argument parser 47 | parser = argparse.ArgumentParser(description='Neural message passing') 48 | 49 | parser.add_argument('--dataset', default='GREC', help='GREC') 50 | parser.add_argument('--datasetPath', default='../data/GREC/', help='dataset path') 51 | parser.add_argument('--logPath', default='../log/grec/duvenaud/', help='log path') 52 | parser.add_argument('--resume', default='../checkpoint/grec/duvenaud', help='path to latest checkpoint') 53 | # Optimization Options 54 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 55 | help='Input batch size for training (default: 20)') 56 | parser.add_argument('--no-cuda', action='store_true', default=False, 57 | help='Disables CUDA training') 58 | parser.add_argument('--epochs', type=int, default=360, metavar='N', 59 | help='Number of epochs to train (default: 360)') 60 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 0.5]), default=0.001, metavar='LR', 61 | help='Initial learning rate [1e-5, 0.5] (default: 1e-3)') 62 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY', 63 | help='Learning rate decay factor [.01, 1] (default: 0.6)') 64 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 65 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 66 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 67 | help='SGD momentum (default: 0.9)') 68 | # i/o 69 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 70 | help='How many batches to wait before logging training status') 71 | # Accelerating 72 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 73 | 74 | best_acc1 = 0 75 | 76 | 77 | def main(): 78 | 79 | global args, best_acc1 80 | args = parser.parse_args() 81 | 82 | # Check if CUDA is enabled 83 | args.cuda = not args.no_cuda and torch.cuda.is_available() 84 | 85 | # Load data 86 | root = args.datasetPath 87 | 88 | print('Prepare files') 89 | 90 | train_classes, train_ids = read_cxl(os.path.join(root, 'data/train.cxl')) 91 | valid_classes, valid_ids = read_cxl(os.path.join(root, 'data/valid.cxl')) 92 | test_classes, test_ids = read_cxl(os.path.join(root, 'data/test.cxl')) 93 | 94 | num_classes = len(list(set(train_classes + test_classes + valid_classes))) 95 | 96 | data_train = datasets.GREC(root, train_ids, train_classes) 97 | data_valid = datasets.GREC(root, valid_ids, 
valid_classes) 98 | data_test = datasets.GREC(root, test_ids, test_classes) 99 | 100 | # Define model and optimizer 101 | print('Define model') 102 | # Select one graph 103 | g_tuple, l = data_train[0] 104 | g, h_t, e = g_tuple 105 | 106 | print('\tStatistics') 107 | stat_dict = datasets.utils.get_graph_stats(data_train, ['degrees']) 108 | 109 | # Data Loader 110 | train_loader = torch.utils.data.DataLoader(data_train, 111 | batch_size=args.batch_size, shuffle=True, 112 | collate_fn=datasets.utils.collate_g, num_workers=args.prefetch, 113 | pin_memory=True) 114 | valid_loader = torch.utils.data.DataLoader(data_valid, 115 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 116 | num_workers=args.prefetch, pin_memory=True) 117 | test_loader = torch.utils.data.DataLoader(data_test, 118 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 119 | num_workers=args.prefetch, 120 | pin_memory=True) 121 | 122 | print('\tCreate model') 123 | model = MpnnDuvenaud(stat_dict['degrees'], [len(h_t[0]), len(list(e.values())[0])], [5, 15, 15], 30, num_classes, 124 | type='classification') 125 | 126 | print('Optimizer') 127 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 128 | 129 | criterion = nn.NLLLoss() 130 | 131 | evaluation = utils.accuracy 132 | 133 | print('Logger') 134 | logger = Logger(args.logPath) 135 | 136 | lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0]) 137 | 138 | # get the best checkpoint if available without training 139 | if args.resume: 140 | checkpoint_dir = args.resume 141 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 142 | if not os.path.isdir(checkpoint_dir): 143 | os.makedirs(checkpoint_dir) 144 | if os.path.isfile(best_model_file): 145 | print("=> loading best model '{}'".format(best_model_file)) 146 | checkpoint = torch.load(best_model_file) 147 | args.start_epoch = checkpoint['epoch'] 148 | best_acc1 = checkpoint['best_acc1'] 149 | model.load_state_dict(checkpoint['state_dict']) 150 | optimizer.load_state_dict(checkpoint['optimizer']) 151 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 152 | best_acc1)) 153 | else: 154 | print("=> no best model found at '{}'".format(best_model_file)) 155 | 156 | print('Check cuda') 157 | if args.cuda: 158 | print('\t* Cuda') 159 | model = model.cuda() 160 | criterion = criterion.cuda() 161 | 162 | # Epoch for loop 163 | for epoch in range(0, args.epochs): 164 | 165 | if epoch > args.epochs*args.schedule[0] and epoch < args.epochs*args.schedule[1]: 166 | args.lr -= lr_step 167 | for param_group in optimizer.param_groups: 168 | param_group['lr'] = args.lr 169 | 170 | # train for one epoch 171 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 172 | 173 | # evaluate on test set 174 | acc1 = validate(valid_loader, model, criterion, evaluation, logger) 175 | 176 | is_best = acc1 > best_acc1 177 | best_acc1 = max(acc1, best_acc1) 178 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 179 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 180 | 181 | # Logger step 182 | logger.log_value('learning_rate', args.lr).step() 183 | 184 | # get the best checkpoint and test it with test set 185 | if args.resume: 186 | checkpoint_dir = args.resume 187 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 188 | if not os.path.isdir(checkpoint_dir): 189 | 
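# (the checkpoint directory may not exist on a fresh run; it is created here before probing for model_best.pth, and os.path.isfile below simply returns False when nothing has been saved yet)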
os.makedirs(checkpoint_dir) 190 | if os.path.isfile(best_model_file): 191 | print("=> loading best model '{}'".format(best_model_file)) 192 | checkpoint = torch.load(best_model_file) 193 | args.start_epoch = checkpoint['epoch'] 194 | best_acc1 = checkpoint['best_acc1'] 195 | model.load_state_dict(checkpoint['state_dict']) 196 | optimizer.load_state_dict(checkpoint['optimizer']) 197 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 198 | best_acc1)) 199 | else: 200 | print("=> no best model found at '{}'".format(best_model_file)) 201 | 202 | # For testing 203 | validate(test_loader, model, criterion, evaluation) 204 | 205 | 206 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 207 | batch_time = AverageMeter() 208 | data_time = AverageMeter() 209 | losses = AverageMeter() 210 | accuracies = AverageMeter() 211 | 212 | # switch to train mode 213 | model.train() 214 | 215 | end = time.time() 216 | for i, (g, h, e, target) in enumerate(train_loader): 217 | 218 | # Prepare input data 219 | target = torch.squeeze(target).type(torch.LongTensor) 220 | if args.cuda: 221 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 222 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 223 | 224 | # Measure data loading time 225 | data_time.update(time.time() - end) 226 | 227 | def closure(): 228 | optimizer.zero_grad() 229 | 230 | # Compute output 231 | output = model(g, h, e) 232 | train_loss = criterion(output, target) 233 | 234 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 235 | 236 | # Logs 237 | losses.update(train_loss.data[0], g.size(0)) 238 | accuracies.update(acc.data[0], g.size(0)) 239 | # compute gradient and do SGD step 240 | train_loss.backward() 241 | return train_loss 242 | 243 | optimizer.step(closure) 244 | 245 | # Measure elapsed time 246 | batch_time.update(time.time() - end) 247 | end = time.time() 248 | 249 | if i % args.log_interval == 0 and i > 0: 250 | 251 | print('Epoch: [{0}][{1}/{2}]\t' 252 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 253 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 254 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 255 | 'Accuracy {acc.val:.4f} ({acc.avg:.4f})' 256 | .format(epoch, i, len(train_loader), batch_time=batch_time, 257 | data_time=data_time, loss=losses, acc=accuracies)) 258 | 259 | logger.log_value('train_epoch_loss', losses.avg) 260 | logger.log_value('train_epoch_accuracy', accuracies.avg) 261 | 262 | print('Epoch: [{0}] Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}' 263 | .format(epoch, acc=accuracies, loss=losses, b_time=batch_time)) 264 | 265 | 266 | def validate(val_loader, model, criterion, evaluation, logger=None): 267 | losses = AverageMeter() 268 | accuracies = AverageMeter() 269 | 270 | # switch to evaluate mode 271 | model.eval() 272 | 273 | for i, (g, h, e, target) in enumerate(val_loader): 274 | 275 | # Prepare input data 276 | target = torch.squeeze(target).type(torch.LongTensor) 277 | if args.cuda: 278 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 279 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 280 | 281 | # Compute output 282 | output = model(g, h, e) 283 | 284 | # Logs 285 | test_loss = criterion(output, target) 286 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 287 | 288 | losses.update(test_loss.data[0], g.size(0)) 289 | accuracies.update(acc.data[0], 
g.size(0)) 290 | 291 | print(' * Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}' 292 | .format(acc=accuracies, loss=losses)) 293 | 294 | if logger is not None: 295 | logger.log_value('test_epoch_loss', losses.avg) 296 | logger.log_value('test_epoch_accuracy', accuracies.avg) 297 | 298 | return accuracies.avg 299 | 300 | if __name__ == '__main__': 301 | main() 302 | -------------------------------------------------------------------------------- /demos/demo_grec_intnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in: 6 | 7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017) 8 | Neural Message Passing for Quantum Chemistry. 9 | arXiv preprint arXiv:1704.01212 [cs.LG] 10 | """ 11 | 12 | # Torch 13 | import torch 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | from torch.autograd import Variable 17 | 18 | import time 19 | import argparse 20 | import os 21 | import sys 22 | 23 | # Our Modules 24 | reader_folder = os.path.realpath(os.path.abspath('..')) 25 | if reader_folder not in sys.path: 26 | sys.path.append(reader_folder) 27 | import datasets 28 | from datasets import utils 29 | from models.MPNN_IntNet import MpnnIntNet 30 | from LogMetric import AverageMeter, Logger 31 | from GraphReader.graph_reader import read_cxl 32 | 33 | __author__ = "Pau Riba, Anjan Dutta" 34 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 35 | 36 | torch.multiprocessing.set_sharing_strategy('file_system') 37 | 38 | 39 | # Parser check 40 | def restricted_float(x, inter): 41 | x = float(x) 42 | if x < inter[0] or x > inter[1]: 43 | raise argparse.ArgumentTypeError("%r not in range [%r, %r]" % (x, inter[0], inter[1])) 44 | return x 45 | 46 | # Argument parser 47 | parser = argparse.ArgumentParser(description='Neural message passing') 48 | 49 | parser.add_argument('--dataset', default='GREC', help='GREC') 50 | parser.add_argument('--datasetPath', default='../data/GREC/', help='dataset path') 51 | parser.add_argument('--logPath', default='../log/grec/intnet/', help='log path') 52 | parser.add_argument('--plotLr', default=False, help='allow plotting the data') 53 | parser.add_argument('--plotPath', default='../plot/grec/intnet/', help='plot path') 54 | parser.add_argument('--resume', default='../checkpoint/grec/intnet/', 55 | help='path to latest checkpoint') 56 | # Optimization Options 57 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 58 | help='Input batch size for training (default: 20)') 59 | parser.add_argument('--no-cuda', action='store_true', default=False, 60 | help='Disables CUDA training') 61 | parser.add_argument('--epochs', type=int, default=360, metavar='N', 62 | help='Number of epochs to train (default: 360)') 63 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 0.5]), default=0.001, metavar='LR', 64 | help='Initial learning rate [1e-5, 0.5] (default: 1e-3)') 65 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY', 66 | help='Learning rate decay factor [.01, 1] (default: 0.6)') 67 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 68 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 69 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 70 | help='SGD momentum (default: 0.9)') 71 | # 
i/o 72 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 73 | help='How many batches to wait before logging training status') 74 | # Accelerating 75 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 76 | 77 | best_acc1 = 0 78 | 79 | 80 | def main(): 81 | 82 | global args, best_acc1 83 | args = parser.parse_args() 84 | 85 | # Check if CUDA is enabled 86 | args.cuda = not args.no_cuda and torch.cuda.is_available() 87 | 88 | # Load data 89 | root = args.datasetPath 90 | 91 | print('Prepare files') 92 | 93 | train_classes, train_ids = read_cxl(os.path.join(root, 'data/train.cxl')) 94 | test_classes, test_ids = read_cxl(os.path.join(root, 'data/test.cxl')) 95 | valid_classes, valid_ids = read_cxl(os.path.join(root, 'data/valid.cxl')) 96 | 97 | num_classes = len(list(set(train_classes + test_classes + valid_classes))) 98 | 99 | data_train = datasets.GREC(root, train_ids, train_classes) 100 | data_valid = datasets.GREC(root, valid_ids, valid_classes) 101 | data_test = datasets.GREC(root, test_ids, test_classes) 102 | 103 | # Define model and optimizer 104 | print('Define model') 105 | # Select one graph 106 | g_tuple, l = data_train[0] 107 | g, h_t, e = g_tuple 108 | 109 | #TODO: Need attention 110 | print('\tStatistics') 111 | stat_dict = {} 112 | stat_dict = datasets.utils.get_graph_stats(data_train, ['edge_labels']) 113 | 114 | # Data Loader 115 | train_loader = torch.utils.data.DataLoader(data_train, 116 | batch_size=args.batch_size, shuffle=True, collate_fn=datasets.utils.collate_g, 117 | num_workers=args.prefetch, pin_memory=True) 118 | valid_loader = torch.utils.data.DataLoader(data_valid, 119 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 120 | num_workers=args.prefetch, pin_memory=True) 121 | test_loader = torch.utils.data.DataLoader(data_test, 122 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 123 | num_workers=args.prefetch, pin_memory=True) 124 | 125 | print('\tCreate model') 126 | model = MpnnIntNet([len(h_t[0]), len(list(e.values())[0])], [15, 25, 20], [10, 20, 20], num_classes, 127 | type='classification') 128 | 129 | print('Optimizer') 130 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 131 | 132 | criterion = nn.NLLLoss() 133 | 134 | evaluation = utils.accuracy 135 | 136 | print('Logger') 137 | logger = Logger(args.logPath) 138 | 139 | lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0]) 140 | 141 | # get the best checkpoint if available without training 142 | if args.resume: 143 | checkpoint_dir = args.resume 144 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 145 | if not os.path.isdir(checkpoint_dir): 146 | os.makedirs(checkpoint_dir) 147 | if os.path.isfile(best_model_file): 148 | print("=> loading best model '{}'".format(best_model_file)) 149 | checkpoint = torch.load(best_model_file) 150 | args.start_epoch = checkpoint['epoch'] 151 | best_acc1 = checkpoint['best_acc1'] 152 | model.load_state_dict(checkpoint['state_dict']) 153 | optimizer.load_state_dict(checkpoint['optimizer']) 154 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 155 | best_acc1)) 156 | else: 157 | print("=> no best model found at '{}'".format(best_model_file)) 158 | 159 | print('Check cuda') 160 | if args.cuda: 161 | print('\t* Cuda') 162 | model = model.cuda() 163 | criterion = criterion.cuda() 164 | 165 | # Epoch for loop 166 | for epoch in range(0, args.epochs): 167 | 
168 | if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]: 169 | args.lr -= lr_step 170 | for param_group in optimizer.param_groups: 171 | param_group['lr'] = args.lr 172 | 173 | # train for one epoch 174 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 175 | 176 | # evaluate on test set 177 | acc1 = validate(valid_loader, model, criterion, evaluation, logger) 178 | 179 | is_best = acc1 > best_acc1 180 | best_acc1 = max(acc1, best_acc1) 181 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 182 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 183 | 184 | # Logger step 185 | logger.log_value('learning_rate', args.lr).step() 186 | 187 | # get the best checkpoint and test it with test set 188 | if args.resume: 189 | checkpoint_dir = args.resume 190 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 191 | if not os.path.isdir(checkpoint_dir): 192 | os.makedirs(checkpoint_dir) 193 | if os.path.isfile(best_model_file): 194 | print("=> loading best model '{}'".format(best_model_file)) 195 | checkpoint = torch.load(best_model_file) 196 | args.start_epoch = checkpoint['epoch'] 197 | best_acc1 = checkpoint['best_acc1'] 198 | model.load_state_dict(checkpoint['state_dict']) 199 | optimizer.load_state_dict(checkpoint['optimizer']) 200 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 201 | best_acc1)) 202 | else: 203 | print("=> no best model found at '{}'".format(best_model_file)) 204 | 205 | # For testing 206 | validate(test_loader, model, criterion, evaluation) 207 | 208 | 209 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 210 | batch_time = AverageMeter() 211 | data_time = AverageMeter() 212 | losses = AverageMeter() 213 | accuracies = AverageMeter() 214 | 215 | # switch to train mode 216 | model.train() 217 | 218 | end = time.time() 219 | for i, (g, h, e, target) in enumerate(train_loader): 220 | 221 | # Prepare input data 222 | target = torch.squeeze(target).type(torch.LongTensor) 223 | if args.cuda: 224 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 225 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 226 | 227 | # Measure data loading time 228 | data_time.update(time.time() - end) 229 | 230 | def closure(): 231 | optimizer.zero_grad() 232 | 233 | # Compute output 234 | output = model(g, h, e) 235 | train_loss = criterion(output, target) 236 | 237 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 238 | 239 | # Logs 240 | losses.update(train_loss.data[0], g.size(0)) 241 | accuracies.update(acc.data[0], g.size(0)) 242 | # compute gradient and do SGD step 243 | train_loss.backward() 244 | return train_loss 245 | 246 | optimizer.step(closure) 247 | 248 | # Measure elapsed time 249 | batch_time.update(time.time() - end) 250 | end = time.time() 251 | 252 | if i % args.log_interval == 0 and i > 0: 253 | 254 | print('Epoch: [{0}][{1}/{2}]\t' 255 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 256 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 257 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 258 | 'Accuracy {acc.val:.4f} ({acc.avg:.4f})' 259 | .format(epoch, i, len(train_loader), batch_time=batch_time, 260 | data_time=data_time, loss=losses, acc=accuracies)) 261 | 262 | logger.log_value('train_epoch_loss', losses.avg) 263 | logger.log_value('train_epoch_accuracy', accuracies.avg) 
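# (Assumed contract of LogMetric.AverageMeter, following the usual PyTorch-example pattern: update(val, n) accumulates sum += val * n and count += n, so .val is the latest batch value and .avg the running mean printed in the epoch summary below.)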
264 | 265 | print('Epoch: [{0}] Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}' 266 | .format(epoch, acc=accuracies, loss=losses, b_time=batch_time)) 267 | 268 | 269 | def validate(val_loader, model, criterion, evaluation, logger=None): 270 | losses = AverageMeter() 271 | accuracies = AverageMeter() 272 | 273 | # switch to evaluate mode 274 | model.eval() 275 | 276 | for i, (g, h, e, target) in enumerate(val_loader): 277 | 278 | # Prepare input data 279 | target = torch.squeeze(target).type(torch.LongTensor) 280 | if args.cuda: 281 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 282 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 283 | 284 | # Compute output 285 | output = model(g, h, e) 286 | 287 | # Logs 288 | test_loss = criterion(output, target) 289 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 290 | 291 | losses.update(test_loss.data[0], g.size(0)) 292 | accuracies.update(acc.data[0], g.size(0)) 293 | 294 | print(' * Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}' 295 | .format(acc=accuracies, loss=losses)) 296 | 297 | if logger is not None: 298 | logger.log_value('test_epoch_loss', losses.avg) 299 | logger.log_value('test_epoch_accuracy', accuracies.avg) 300 | 301 | return accuracies.avg 302 | 303 | if __name__ == '__main__': 304 | main() 305 | -------------------------------------------------------------------------------- /demos/demo_grec_mpnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in: 6 | 7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017) 8 | Neural Message Passing for Quantum Chemistry. 
9 | arXiv preprint arXiv:1704.01212 [cs.LG] 10 | """ 11 | 12 | # Torch 13 | import torch 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | from torch.autograd import Variable 17 | 18 | import time 19 | import argparse 20 | import os 21 | import sys 22 | 23 | # Our Modules 24 | reader_folder = os.path.realpath(os.path.abspath('..')) 25 | if reader_folder not in sys.path: 26 | sys.path.append(reader_folder) 27 | import datasets 28 | from datasets import utils 29 | from models.MPNN import MPNN 30 | from LogMetric import AverageMeter, Logger 31 | from GraphReader.graph_reader import read_cxl 32 | 33 | __author__ = "Pau Riba, Anjan Dutta" 34 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 35 | 36 | torch.multiprocessing.set_sharing_strategy('file_system') 37 | 38 | 39 | # Parser check 40 | def restricted_float(x, inter): 41 | x = float(x) 42 | if x < inter[0] or x > inter[1]: 43 | raise argparse.ArgumentTypeError("%r not in range [%r, %r]" % (x, inter[0], inter[1])) 44 | return x 45 | 46 | # Argument parser 47 | parser = argparse.ArgumentParser(description='Neural message passing') 48 | 49 | parser.add_argument('--dataset', default='GREC', help='GREC') 50 | parser.add_argument('--datasetPath', default='../data/GREC/', help='dataset path') 51 | parser.add_argument('--logPath', default='../log/grec/mpnn/', help='log path') 52 | parser.add_argument('--plotLr', default=False, help='allow plotting the data') 53 | parser.add_argument('--plotPath', default='../plot/grec/mpnn/', help='plot path') 54 | parser.add_argument('--resume', default='../checkpoint/grec/mpnn', 55 | help='path to latest checkpoint') 56 | 57 | # Optimization Options 58 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 59 | help='Input batch size for training (default: 20)') 60 | parser.add_argument('--no-cuda', action='store_true', default=False, 61 | help='Disables CUDA training') 62 | parser.add_argument('--epochs', type=int, default=360, metavar='N', 63 | help='Number of epochs to train (default: 360)') 64 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 0.5]), default=0.001, metavar='LR', 65 | help='Initial learning rate [1e-5, 0.5] (default: 1e-3)') 66 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY', 67 | help='Learning rate decay factor [.01, 1] (default: 0.6)') 68 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 69 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 70 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 71 | help='SGD momentum (default: 0.9)') 72 | # i/o 73 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 74 | help='How many batches to wait before logging training status') 75 | # Accelerating 76 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 77 | 78 | best_acc1 = 0 79 | 80 | 81 | def main(): 82 | 83 | global args, best_acc1 84 | args = parser.parse_args() 85 | 86 | # Check if CUDA is enabled 87 | args.cuda = not args.no_cuda and torch.cuda.is_available() 88 | 89 | # Load data 90 | root = args.datasetPath 91 | 92 | print('Prepare files') 93 | 94 | train_classes, train_ids = read_cxl(os.path.join(root, 'data/train.cxl')) 95 | test_classes, test_ids = read_cxl(os.path.join(root, 'data/test.cxl')) 96 | valid_classes, valid_ids = read_cxl(os.path.join(root, 'data/valid.cxl')) 97 | 98 | num_classes = len(list(set(train_classes + valid_classes + 
test_classes))) 99 | 100 | data_train = datasets.GREC(root, train_ids, train_classes) 101 | data_valid = datasets.GREC(root, valid_ids, valid_classes) 102 | data_test = datasets.GREC(root, test_ids, test_classes) 103 | 104 | # Define model and optimizer 105 | print('Define model') 106 | # Select one graph 107 | g_tuple, l = data_train[0] 108 | g, h_t, e = g_tuple 109 | 110 | # Data Loader 111 | train_loader = torch.utils.data.DataLoader(data_train, 112 | batch_size=args.batch_size, shuffle=True, collate_fn=datasets.utils.collate_g, 113 | num_workers=args.prefetch, pin_memory=True) 114 | valid_loader = torch.utils.data.DataLoader(data_valid, 115 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 116 | num_workers=args.prefetch, pin_memory=True) 117 | test_loader = torch.utils.data.DataLoader(data_test, 118 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 119 | num_workers=args.prefetch, pin_memory=True) 120 | 121 | print('\tCreate model') 122 | model = MPNN([len(h_t[0]), len(list(e.values())[0])], 25, 15, 2, num_classes, type='classification') 123 | 124 | print('Optimizer') 125 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 126 | 127 | criterion = nn.NLLLoss() 128 | 129 | evaluation = utils.accuracy 130 | 131 | print('Logger') 132 | logger = Logger(args.logPath) 133 | 134 | lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0]) 135 | 136 | # get the best checkpoint if available without training 137 | if args.resume: 138 | checkpoint_dir = args.resume 139 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 140 | if not os.path.isdir(checkpoint_dir): 141 | os.makedirs(checkpoint_dir) 142 | if os.path.isfile(best_model_file): 143 | print("=> loading best model '{}'".format(best_model_file)) 144 | checkpoint = torch.load(best_model_file) 145 | args.start_epoch = checkpoint['epoch'] 146 | best_acc1 = checkpoint['best_acc1'] 147 | model.load_state_dict(checkpoint['state_dict']) 148 | optimizer.load_state_dict(checkpoint['optimizer']) 149 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 150 | best_acc1)) 151 | else: 152 | print("=> no best model found at '{}'".format(best_model_file)) 153 | 154 | print('Check cuda') 155 | if args.cuda: 156 | print('\t* Cuda') 157 | model = model.cuda() 158 | criterion = criterion.cuda() 159 | 160 | # Epoch for loop 161 | for epoch in range(0, args.epochs): 162 | 163 | if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]: 164 | args.lr -= lr_step 165 | for param_group in optimizer.param_groups: 166 | param_group['lr'] = args.lr 167 | 168 | # train for one epoch 169 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 170 | 171 | # evaluate on test set 172 | acc1 = validate(valid_loader, model, criterion, evaluation, logger) 173 | 174 | is_best = acc1 > best_acc1 175 | best_acc1 = max(acc1, best_acc1) 176 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 177 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 178 | 179 | # Logger step 180 | logger.log_value('learning_rate', args.lr).step() 181 | 182 | # get the best checkpoint and test it with test set 183 | if args.resume: 184 | checkpoint_dir = args.resume 185 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 186 | if not os.path.isdir(checkpoint_dir): 187 | os.makedirs(checkpoint_dir) 188 | if 
os.path.isfile(best_model_file): 189 | print("=> loading best model '{}'".format(best_model_file)) 190 | checkpoint = torch.load(best_model_file) 191 | args.start_epoch = checkpoint['epoch'] 192 | best_acc1 = checkpoint['best_acc1'] 193 | model.load_state_dict(checkpoint['state_dict']) 194 | optimizer.load_state_dict(checkpoint['optimizer']) 195 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 196 | best_acc1)) 197 | else: 198 | print("=> no best model found at '{}'".format(best_model_file)) 199 | 200 | # For testing 201 | validate(test_loader, model, criterion, evaluation) 202 | 203 | 204 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 205 | batch_time = AverageMeter() 206 | data_time = AverageMeter() 207 | losses = AverageMeter() 208 | accuracies = AverageMeter() 209 | 210 | # switch to train mode 211 | model.train() 212 | 213 | end = time.time() 214 | for i, (g, h, e, target) in enumerate(train_loader): 215 | 216 | # Prepare input data 217 | target = torch.squeeze(target).type(torch.LongTensor) 218 | if args.cuda: 219 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 220 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 221 | 222 | # Measure data loading time 223 | data_time.update(time.time() - end) 224 | 225 | def closure(): 226 | optimizer.zero_grad() 227 | 228 | # Compute output 229 | output = model(g, h, e) 230 | train_loss = criterion(output, target) 231 | 232 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 233 | 234 | # Logs 235 | losses.update(train_loss.data[0], g.size(0)) 236 | accuracies.update(acc.data[0], g.size(0)) 237 | # compute gradient and do SGD step 238 | train_loss.backward() 239 | return train_loss 240 | 241 | optimizer.step(closure) 242 | 243 | # Measure elapsed time 244 | batch_time.update(time.time() - end) 245 | end = time.time() 246 | 247 | if i % args.log_interval == 0 and i > 0: 248 | 249 | print('Epoch: [{0}][{1}/{2}]\t' 250 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 251 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 252 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 253 | 'Accuracy {acc.val:.4f} ({acc.avg:.4f})' 254 | .format(epoch, i, len(train_loader), batch_time=batch_time, 255 | data_time=data_time, loss=losses, acc=accuracies)) 256 | 257 | logger.log_value('train_epoch_loss', losses.avg) 258 | logger.log_value('train_epoch_accuracy', accuracies.avg) 259 | 260 | print('Epoch: [{0}] Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}' 261 | .format(epoch, acc=accuracies, loss=losses, b_time=batch_time)) 262 | 263 | 264 | def validate(val_loader, model, criterion, evaluation, logger=None): 265 | losses = AverageMeter() 266 | accuracies = AverageMeter() 267 | 268 | # switch to evaluate mode 269 | model.eval() 270 | 271 | for i, (g, h, e, target) in enumerate(val_loader): 272 | 273 | # Prepare input data 274 | target = torch.squeeze(target).type(torch.LongTensor) 275 | if args.cuda: 276 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 277 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 278 | 279 | # Compute output 280 | output = model(g, h, e) 281 | 282 | # Logs 283 | test_loss = criterion(output, target) 284 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 285 | 286 | losses.update(test_loss.data[0], g.size(0)) 287 | accuracies.update(acc.data[0], g.size(0)) 288 | 289 | print(' * Average 
Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}' 290 | .format(acc=accuracies, loss=losses)) 291 | 292 | if logger is not None: 293 | logger.log_value('test_epoch_loss', losses.avg) 294 | logger.log_value('test_epoch_accuracy', accuracies.avg) 295 | 296 | return accuracies.avg 297 | 298 | if __name__ == '__main__': 299 | main() 300 | -------------------------------------------------------------------------------- /demos/demo_gwhist_duvenaud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in: 6 | 7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017) 8 | Neural Message Passing for Quantum Chemistry. 9 | arXiv preprint arXiv:1704.01212 [cs.LG] 10 | """ 11 | 12 | # Torch 13 | import torch 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | from torch.autograd import Variable 17 | 18 | import time 19 | import argparse 20 | import os 21 | import sys 22 | 23 | # Our Modules 24 | reader_folder = os.path.realpath(os.path.abspath('..')) 25 | if reader_folder not in sys.path: 26 | sys.path.append(reader_folder) 27 | import datasets 28 | from datasets import utils 29 | from models.MPNN_Duvenaud import MpnnDuvenaud 30 | from LogMetric import AverageMeter, Logger 31 | from GraphReader.graph_reader import read_2cols_set_files, create_numeric_classes 32 | 33 | __author__ = "Pau Riba, Anjan Dutta" 34 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 35 | 36 | torch.multiprocessing.set_sharing_strategy('file_system') 37 | 38 | 39 | # Parser check 40 | def restricted_float(x, inter): 41 | x = float(x) 42 | if x < inter[0] or x > inter[1]: 43 | raise argparse.ArgumentTypeError("%r not in range [1e-5, 1e-4]"%(x,)) 44 | return x 45 | 46 | # Argument parser 47 | parser = argparse.ArgumentParser(description='Neural message passing') 48 | 49 | parser.add_argument('--dataset', default='gwhistograph', help='GWHISTOGRAPH') 50 | parser.add_argument('--datasetPath', default='../data/GWHistoGraphs/', help='dataset path') 51 | parser.add_argument('--subSet', default='01_Keypoint', help='sub dataset') 52 | parser.add_argument('--logPath', default='../log/gwhist/duvenaud/', help='log path') 53 | # Optimization Options 54 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 55 | help='Input batch size for training (default: 20)') 56 | parser.add_argument('--no-cuda', action='store_true', default=False, 57 | help='Enables CUDA training') 58 | parser.add_argument('--epochs', type=int, default=360, metavar='N', 59 | help='Number of epochs to train (default: 360)') 60 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 0.5]), default=0.001, metavar='LR', 61 | help='Initial learning rate [1e-5, 5e-4] (default: 1e-4)') 62 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY', 63 | help='Learning rate decay factor [.01, 1] (default: 0.6)') 64 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 65 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 66 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 67 | help='SGD momentum (default: 0.9)') 68 | # i/o 69 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 70 | help='How many batches to wait before logging training status') 71 | # Accelerating 72 | 
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 73 | parser.add_argument('--resume', default='../checkpoint/gwhist/duvenaud/', help='path to latest checkpoint') 74 | best_acc1 = 0 75 | def main(): 76 | global args, best_acc1 77 | args = parser.parse_args() 78 | 79 | # Check if CUDA is enabled 80 | args.cuda = not args.no_cuda and torch.cuda.is_available() 81 | 82 | # Load data 83 | root = args.datasetPath 84 | subset = args.subSet 85 | 86 | print('Prepare files') 87 | 88 | train_classes, train_ids = read_2cols_set_files(os.path.join(root, 'Set/Train.txt')) 89 | valid_classes, valid_ids = read_2cols_set_files(os.path.join(root, 'Set/Valid.txt')) 90 | test_classes, test_ids = read_2cols_set_files(os.path.join(root,'Set/Test.txt')) 91 | 92 | train_classes, valid_classes, test_classes = create_numeric_classes(train_classes, valid_classes, test_classes) 93 | 94 | num_classes = max(train_classes + test_classes) + 1 95 | data_train = datasets.GWHISTOGRAPH(root, subset, train_ids, train_classes, num_classes) 96 | data_valid = datasets.GWHISTOGRAPH(root, subset, valid_ids, valid_classes, num_classes) 97 | data_test = datasets.GWHISTOGRAPH(root, subset, test_ids, test_classes, num_classes) 98 | 99 | # Define model and optimizer 100 | print('Define model') 101 | # Select one graph 102 | g_tuple, l = data_train[0] 103 | g, h_t, e = g_tuple 104 | 105 | print('\tStatistics') 106 | stat_dict = datasets.utils.get_graph_stats(data_train, ['degrees']) 107 | 108 | # Data Loader 109 | train_loader = torch.utils.data.DataLoader(data_train, 110 | batch_size=args.batch_size, shuffle=True, collate_fn=datasets.utils.collate_g, 111 | num_workers=args.prefetch, pin_memory=True) 112 | valid_loader = torch.utils.data.DataLoader(data_valid, 113 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 114 | num_workers=args.prefetch, pin_memory=True) 115 | test_loader = torch.utils.data.DataLoader(data_test, 116 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 117 | num_workers=args.prefetch, pin_memory=True) 118 | 119 | print('\tCreate model') 120 | model = MpnnDuvenaud(stat_dict['degrees'], [len(h_t[0]), len(list(e.values())[0])], [5, 15, 15], 30, num_classes, type='classification') 121 | 122 | 123 | 124 | print('Optimizer') 125 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 126 | 127 | criterion = nn.NLLLoss() 128 | 129 | evaluation = utils.accuracy 130 | 131 | print('Logger') 132 | logger = Logger(args.logPath) 133 | 134 | lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0]) 135 | 136 | # get the best checkpoint if available without training 137 | if args.resume: 138 | checkpoint_dir = args.resume 139 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 140 | if not os.path.isdir(checkpoint_dir): 141 | os.makedirs(checkpoint_dir) 142 | if os.path.isfile(best_model_file): 143 | print("=> loading best model '{}'".format(best_model_file)) 144 | checkpoint = torch.load(best_model_file) 145 | args.start_epoch = checkpoint['epoch'] 146 | best_acc1 = checkpoint['best_acc1'] 147 | model.load_state_dict(checkpoint['state_dict']) 148 | optimizer.load_state_dict(checkpoint['optimizer']) 149 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 150 | best_acc1)) 151 | else: 152 | print("=> no best model found at '{}'".format(best_model_file)) 153 | 154 | print('Check cuda') 155 | if args.cuda: 156 | print('\t* Cuda') 157 | model = model.cuda() 158 | criterion = criterion.cuda() 159 | 160 | # Epoch for loop 161 | for epoch in range(0, args.epochs): 162 
| 163 | if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]: 164 | args.lr -= lr_step 165 | for param_group in optimizer.param_groups: 166 | param_group['lr'] = args.lr 167 | 168 | # train for one epoch 169 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 170 | 171 | # evaluate on test set 172 | acc1 = validate(valid_loader, model, criterion, evaluation, logger) 173 | 174 | is_best = acc1 > best_acc1 175 | best_acc1 = max(acc1, best_acc1) 176 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 177 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 178 | 179 | # Logger step 180 | logger.log_value('learning_rate', args.lr).step() 181 | 182 | # get the best checkpoint and test it with test set 183 | if args.resume: 184 | checkpoint_dir = args.resume 185 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 186 | if not os.path.isdir(checkpoint_dir): 187 | os.makedirs(checkpoint_dir) 188 | if os.path.isfile(best_model_file): 189 | print("=> loading best model '{}'".format(best_model_file)) 190 | checkpoint = torch.load(best_model_file) 191 | args.start_epoch = checkpoint['epoch'] 192 | best_acc1 = checkpoint['best_acc1'] 193 | model.load_state_dict(checkpoint['state_dict']) 194 | optimizer.load_state_dict(checkpoint['optimizer']) 195 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 196 | best_acc1)) 197 | else: 198 | print("=> no best model found at '{}'".format(best_model_file)) 199 | 200 | # For testing 201 | validate(test_loader, model, criterion, evaluation) 202 | 203 | 204 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 205 | batch_time = AverageMeter() 206 | data_time = AverageMeter() 207 | losses = AverageMeter() 208 | accuracies = AverageMeter() 209 | 210 | # switch to train mode 211 | model.train() 212 | 213 | end = time.time() 214 | for i, (g, h, e, target) in enumerate(train_loader): 215 | 216 | # Prepare input data 217 | target = torch.squeeze(target).type(torch.LongTensor) 218 | if args.cuda: 219 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 220 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 221 | 222 | # Measure data loading time 223 | data_time.update(time.time() - end) 224 | 225 | def closure(): 226 | optimizer.zero_grad() 227 | 228 | # Compute output 229 | output = model(g, h, e) 230 | train_loss = criterion(output, target) 231 | 232 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 233 | 234 | # Logs 235 | losses.update(train_loss.data[0], g.size(0)) 236 | accuracies.update(acc.data[0], g.size(0)) 237 | # compute gradient and do SGD step 238 | train_loss.backward() 239 | return train_loss 240 | 241 | optimizer.step(closure) 242 | 243 | # Measure elapsed time 244 | batch_time.update(time.time() - end) 245 | end = time.time() 246 | 247 | if i % args.log_interval == 0 and i > 0: 248 | 249 | print('Epoch: [{0}][{1}/{2}]\t' 250 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 251 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 252 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 253 | 'Accuracy {acc.val:.4f} ({acc.avg:.4f})' 254 | .format(epoch, i, len(train_loader), batch_time=batch_time, 255 | data_time=data_time, loss=losses, acc=accuracies)) 256 | 257 | logger.log_value('train_epoch_loss', losses.avg) 258 | logger.log_value('train_epoch_accuracy', accuracies.avg) 
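# (Note on the closure pattern used in this loop: optimizer.step(closure) is PyTorch's API for optimizers such as LBFGS that may re-evaluate the loss several times; Adam, used here, calls the closure exactly once per step, so the effect is the usual zero_grad()/forward/backward/step sequence.)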
259 | 260 | print('Epoch: [{0}] Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}' 261 | .format(epoch, acc=accuracies, loss=losses, b_time=batch_time)) 262 | 263 | 264 | def validate(val_loader, model, criterion, evaluation, logger=None): 265 | losses = AverageMeter() 266 | accuracies = AverageMeter() 267 | 268 | # switch to evaluate mode 269 | model.eval() 270 | 271 | for i, (g, h, e, target) in enumerate(val_loader): 272 | 273 | # Prepare input data 274 | target = torch.squeeze(target).type(torch.LongTensor) 275 | if args.cuda: 276 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 277 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 278 | 279 | # Compute output 280 | output = model(g, h, e) 281 | 282 | # Logs 283 | test_loss = criterion(output, target) 284 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 285 | 286 | losses.update(test_loss.data[0], g.size(0)) 287 | accuracies.update(acc.data[0], g.size(0)) 288 | 289 | print(' * Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}' 290 | .format(acc=accuracies, loss=losses)) 291 | 292 | if logger is not None: 293 | logger.log_value('test_epoch_loss', losses.avg) 294 | logger.log_value('test_epoch_accuracy', accuracies.avg) 295 | 296 | return accuracies.avg 297 | 298 | if __name__ == '__main__': 299 | main() 300 | -------------------------------------------------------------------------------- /demos/demo_gwhist_ggnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in: 6 | 7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017) 8 | Neural Message Passing for Quantum Chemistry. 
9 | arXiv preprint arXiv:1704.01212 [cs.LG] 10 | """ 11 | 12 | # Torch 13 | import torch 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | from torch.autograd import Variable 17 | 18 | import time 19 | import argparse 20 | import os 21 | import sys 22 | 23 | # Our Modules 24 | reader_folder = os.path.realpath(os.path.abspath('..')) 25 | if reader_folder not in sys.path: 26 | sys.path.append(reader_folder) 27 | import datasets 28 | from datasets import utils 29 | from models.MPNN_GGNN import MpnnGGNN 30 | from LogMetric import AverageMeter, Logger 31 | from GraphReader.graph_reader import read_2cols_set_files, create_numeric_classes 32 | 33 | __author__ = "Pau Riba, Anjan Dutta" 34 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 35 | 36 | torch.multiprocessing.set_sharing_strategy('file_system') 37 | 38 | 39 | # Parser check 40 | def restricted_float(x, inter): 41 | x = float(x) 42 | if x < inter[0] or x > inter[1]: 43 | raise argparse.ArgumentTypeError("%r not in range [1e-5, 1e-4]"%(x,)) 44 | return x 45 | 46 | # Argument parser 47 | parser = argparse.ArgumentParser(description='Neural message passing') 48 | 49 | parser.add_argument('--dataset', default='gwhistograph', help='GWHISTOGRAPH') 50 | parser.add_argument('--datasetPath', default='../data/GWHistoGraphs/', help='dataset path') 51 | parser.add_argument('--subSet', default='01_Keypoint', help='sub dataset') 52 | parser.add_argument('--logPath', default='../log/gwhist/ggnn/', help='log path') 53 | parser.add_argument('--plotLr', default=False, help='allow plotting the data') 54 | parser.add_argument('--plotPath', default='../plot/gwhist/ggnn/', help='plot path') 55 | parser.add_argument('--resume', default='../checkpoint/gwhist/ggnn/', 56 | help='path to latest checkpoint') 57 | # Optimization Options 58 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 59 | help='Input batch size for training (default: 20)') 60 | parser.add_argument('--no-cuda', action='store_true', default=False, 61 | help='Enables CUDA training') 62 | parser.add_argument('--epochs', type=int, default=360, metavar='N', 63 | help='Number of epochs to train (default: 360)') 64 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 0.5]), default=0.001, metavar='LR', 65 | help='Initial learning rate [1e-5, 5e-4] (default: 1e-4)') 66 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY', 67 | help='Learning rate decay factor [.01, 1] (default: 0.6)') 68 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 69 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 70 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 71 | help='SGD momentum (default: 0.9)') 72 | # i/o 73 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 74 | help='How many batches to wait before logging training status') 75 | # Accelerating 76 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 77 | 78 | best_acc1 = 0 79 | 80 | 81 | def main(): 82 | 83 | global args, best_acc1 84 | args = parser.parse_args() 85 | 86 | # Check if CUDA is enabled 87 | args.cuda = not args.no_cuda and torch.cuda.is_available() 88 | 89 | # Load data 90 | root = args.datasetPath 91 | subset = args.subSet 92 | 93 | print('Prepare files') 94 | 95 | train_classes, train_ids = read_2cols_set_files(os.path.join(root, 'Set/Train.txt')) 96 | valid_classes, valid_ids = 
read_2cols_set_files(os.path.join(root, 'Set/Valid.txt')) 97 | test_classes, test_ids = read_2cols_set_files(os.path.join(root,'Set/Test.txt')) 98 | 99 | train_classes, valid_classes, test_classes = create_numeric_classes(train_classes, valid_classes, test_classes) 100 | 101 | # keep the three splits disjoint; valid_classes/valid_ids are still needed below to build the validation split 102 | 103 | 104 | 105 | 106 | num_classes = max(train_classes + valid_classes + test_classes) + 1 107 | data_train = datasets.GWHISTOGRAPH(root, subset, train_ids, train_classes, num_classes) 108 | data_valid = datasets.GWHISTOGRAPH(root, subset, valid_ids, valid_classes, num_classes) 109 | data_test = datasets.GWHISTOGRAPH(root, subset, test_ids, test_classes, num_classes) 110 | 111 | # Define model and optimizer 112 | print('Define model') 113 | # Select one graph 114 | g_tuple, l = data_train[0] 115 | g, h_t, e = g_tuple 116 | 117 | print('\tStatistics') 118 | stat_dict = datasets.utils.get_graph_stats(data_train, ['edge_labels']) 119 | 120 | # Data Loader 121 | train_loader = torch.utils.data.DataLoader(data_train, 122 | batch_size=args.batch_size, shuffle=True, collate_fn=datasets.utils.collate_g, 123 | num_workers=args.prefetch, pin_memory=True) 124 | valid_loader = torch.utils.data.DataLoader(data_valid, 125 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 126 | num_workers=args.prefetch, pin_memory=True) 127 | test_loader = torch.utils.data.DataLoader(data_test, 128 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 129 | num_workers=args.prefetch, pin_memory=True) 130 | 131 | print('\tCreate model') 132 | model = MpnnGGNN(stat_dict['edge_labels'], 25, 15, 2, num_classes, type='classification') 133 | 134 | print('Optimizer') 135 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 136 | 137 | criterion = nn.NLLLoss() 138 | 139 | evaluation = utils.accuracy 140 | 141 | print('Logger') 142 | logger = Logger(args.logPath) 143 | 144 | lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0]) 145 | 146 | # get the best checkpoint if available without training 147 | if args.resume: 148 | checkpoint_dir = args.resume 149 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 150 | if not os.path.isdir(checkpoint_dir): 151 | os.makedirs(checkpoint_dir) 152 | if os.path.isfile(best_model_file): 153 | print("=> loading best model '{}'".format(best_model_file)) 154 | checkpoint = torch.load(best_model_file) 155 | args.start_epoch = checkpoint['epoch'] 156 | best_acc1 = checkpoint['best_acc1'] 157 | model.load_state_dict(checkpoint['state_dict']) 158 | optimizer.load_state_dict(checkpoint['optimizer']) 159 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 160 | best_acc1)) 161 | else: 162 | print("=> no best model found at '{}'".format(best_model_file)) 163 | 164 | print('Check cuda') 165 | if args.cuda: 166 | print('\t* Cuda') 167 | model = model.cuda() 168 | criterion = criterion.cuda() 169 | 170 | # Epoch for loop 171 | for epoch in range(0, args.epochs): 172 | 173 | if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]: 174 | args.lr -= lr_step 175 | for param_group in optimizer.param_groups: 176 | param_group['lr'] = args.lr 177 | 178 | # train for one epoch 179 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 180 | 181 | # evaluate on validation set 182 | acc1 = validate(valid_loader, model, criterion, evaluation, logger) 
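# (model selection happens here: is_best below tags checkpoints whose validation accuracy improves on best_acc1, and after training the best snapshot is reloaded and scored once on the held-out test loader)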
183 | 184 | is_best = acc1 > best_acc1 185 | best_acc1 = max(acc1, best_acc1) 186 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 187 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 188 | 189 | # Logger step 190 | logger.log_value('learning_rate', args.lr).step() 191 | 192 | # get the best checkpoint and test it with test set 193 | if args.resume: 194 | checkpoint_dir = args.resume 195 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 196 | if not os.path.isdir(checkpoint_dir): 197 | os.makedirs(checkpoint_dir) 198 | if os.path.isfile(best_model_file): 199 | print("=> loading best model '{}'".format(best_model_file)) 200 | checkpoint = torch.load(best_model_file) 201 | args.start_epoch = checkpoint['epoch'] 202 | best_acc1 = checkpoint['best_acc1'] 203 | model.load_state_dict(checkpoint['state_dict']) 204 | optimizer.load_state_dict(checkpoint['optimizer']) 205 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 206 | best_acc1)) 207 | else: 208 | print("=> no best model found at '{}'".format(best_model_file)) 209 | 210 | # For testing 211 | validate(test_loader, model, criterion, evaluation) 212 | 213 | 214 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 215 | batch_time = AverageMeter() 216 | data_time = AverageMeter() 217 | losses = AverageMeter() 218 | accuracies = AverageMeter() 219 | 220 | # switch to train mode 221 | model.train() 222 | 223 | end = time.time() 224 | for i, (g, h, e, target) in enumerate(train_loader): 225 | 226 | # Prepare input data 227 | target = torch.squeeze(target).type(torch.LongTensor) 228 | if args.cuda: 229 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 230 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 231 | 232 | # Measure data loading time 233 | data_time.update(time.time() - end) 234 | 235 | def closure(): 236 | optimizer.zero_grad() 237 | 238 | # Compute output 239 | output = model(g, h, e) 240 | train_loss = criterion(output, target) 241 | 242 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 243 | 244 | # Logs 245 | losses.update(train_loss.data[0], g.size(0)) 246 | accuracies.update(acc.data[0], g.size(0)) 247 | # compute gradient and do SGD step 248 | train_loss.backward() 249 | return train_loss 250 | 251 | optimizer.step(closure) 252 | 253 | # Measure elapsed time 254 | batch_time.update(time.time() - end) 255 | end = time.time() 256 | 257 | if i % args.log_interval == 0 and i > 0: 258 | 259 | print('Epoch: [{0}][{1}/{2}]\t' 260 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 261 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 262 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 263 | 'Accuracy {acc.val:.4f} ({acc.avg:.4f})' 264 | .format(epoch, i, len(train_loader), batch_time=batch_time, 265 | data_time=data_time, loss=losses, acc=accuracies)) 266 | 267 | logger.log_value('train_epoch_loss', losses.avg) 268 | logger.log_value('train_epoch_accuracy', accuracies.avg) 269 | 270 | print('Epoch: [{0}] Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}' 271 | .format(epoch, acc=accuracies, loss=losses, b_time=batch_time)) 272 | 273 | 274 | def validate(val_loader, model, criterion, evaluation, logger=None): 275 | losses = AverageMeter() 276 | accuracies = AverageMeter() 277 | 278 | # switch to evaluate mode 279 | model.eval() 280 | 281 | for i, 
(g, h, e, target) in enumerate(val_loader):
282 | 
283 |         # Prepare input data
284 |         target = torch.squeeze(target).type(torch.LongTensor)
285 |         if args.cuda:
286 |             g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
287 |         g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)
288 | 
289 |         # Compute output
290 |         output = model(g, h, e)
291 | 
292 |         # Logs
293 |         test_loss = criterion(output, target)
294 |         acc = Variable(evaluation(output.data, target.data, topk=(1,))[0])
295 | 
296 |         losses.update(test_loss.data[0], g.size(0))
297 |         accuracies.update(acc.data[0], g.size(0))
298 | 
299 |     print(' * Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}'
300 |           .format(acc=accuracies, loss=losses))
301 | 
302 |     if logger is not None:
303 |         logger.log_value('test_epoch_loss', losses.avg)
304 |         logger.log_value('test_epoch_accuracy', accuracies.avg)
305 | 
306 |     return accuracies.avg
307 | 
308 | if __name__ == '__main__':
309 |     main()
310 | 
--------------------------------------------------------------------------------
/demos/demo_letter_ggnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | 
4 | """
5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in:
6 | 
7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017)
8 | Neural Message Passing for Quantum Chemistry.
9 | arXiv preprint arXiv:1704.01212 [cs.LG]
10 | """
11 | 
12 | # Torch
13 | import torch
14 | import torch.optim as optim
15 | import torch.nn as nn
16 | from torch.autograd import Variable
17 | 
18 | import time
19 | import argparse
20 | import os
21 | import sys
22 | 
23 | # Our Modules
24 | reader_folder = os.path.realpath(os.path.abspath('..'))
25 | if reader_folder not in sys.path:
26 |     sys.path.append(reader_folder)
27 | import datasets
28 | from datasets import utils
29 | from models.MPNN_GGNN import MpnnGGNN
30 | from LogMetric import AverageMeter, Logger
31 | from GraphReader.graph_reader import read_cxl
32 | from visualization.Plotter import Plotter
33 | 
34 | __author__ = "Pau Riba, Anjan Dutta"
35 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"
36 | 
37 | torch.multiprocessing.set_sharing_strategy('file_system')
38 | 
39 | 
40 | # Parser check
41 | def restricted_float(x, inter):
42 |     x = float(x)
43 |     if x < inter[0] or x > inter[1]:
44 |         raise argparse.ArgumentTypeError("%r not in range [%r, %r]" % (x, inter[0], inter[1]))
45 |     return x
46 | 
47 | # Argument parser
48 | parser = argparse.ArgumentParser(description='Neural message passing')
49 | 
50 | parser.add_argument('--dataset', default='Letter', help='letter')
51 | parser.add_argument('--datasetPath', default='../data/Letter/', help='dataset path')
52 | parser.add_argument('--subSet', default='LOW', help='sub dataset')
53 | parser.add_argument('--logPath', default='../log/letter/ggnn/', help='log path')
54 | parser.add_argument('--plotLr', default=False, help='allow plotting the data')
55 | parser.add_argument('--plotPath', default='../plot/letter/ggnn/', help='plot path')
56 | parser.add_argument('--resume', default='../checkpoint/letter/ggnn/',
57 |                     help='path to latest checkpoint')
58 | 
59 | # Optimization Options
60 | parser.add_argument('--batch-size', type=int, default=20, metavar='N',
61 |                     help='Input batch size for training (default: 20)')
62 | parser.add_argument('--no-cuda', action='store_true', default=True,
63 |                     help='Disables CUDA training')
64 | parser.add_argument('--epochs', type=int, default=360, metavar='N',
65 |                     help='Number of epochs to train (default: 360)')
66 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 0.5]), default=0.001, metavar='LR',
67 |                     help='Initial learning rate [1e-5, 0.5] (default: 1e-3)')
68 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY',
69 |                     help='Learning rate decay factor [.01, 1] (default: 0.6)')
70 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S',
71 |                     help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])')
72 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
73 |                     help='SGD momentum (default: 0.9)')
74 | # i/o
75 | parser.add_argument('--log-interval', type=int, default=100, metavar='N',
76 |                     help='How many batches to wait before logging training status')
77 | # Accelerating
78 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
79 | 
80 | best_acc1 = 0
81 | 
82 | 
83 | def main():
84 | 
85 |     global args, best_acc1
86 |     args = parser.parse_args()
87 | 
88 |     # Check if CUDA is enabled
89 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
90 | 
91 |     # Load data
92 |     root = args.datasetPath
93 |     subset = args.subSet
94 | 
95 |     print('Prepare files')
96 | 
97 |     train_classes, train_ids = read_cxl(os.path.join(root, subset, 'train.cxl'))
98 |     test_classes, test_ids = read_cxl(os.path.join(root, subset, 'test.cxl'))
99 |     valid_classes, valid_ids = read_cxl(os.path.join(root, subset, 'validation.cxl'))
100 | 
101 |     class_list = list(set(train_classes + test_classes))
102 |     num_classes = len(class_list)
103 |     data_train = datasets.LETTER(root, subset, train_ids, train_classes, class_list)
104 |     data_valid = datasets.LETTER(root, subset, valid_ids, valid_classes, class_list)
105 |     data_test = datasets.LETTER(root, subset, test_ids, test_classes, class_list)
106 | 
107 |     # Define model and optimizer
108 |     print('Define model')
109 |     # Select one graph
110 |     g_tuple, l = data_train[0]
111 |     g, h_t, e = g_tuple
112 | 
113 |     # TODO: Needs attention
114 |     print('\tStatistics')
115 |     stat_dict = {}
116 |     # stat_dict = datasets.utils.get_graph_stats(data_train, ['edge_labels'])
117 |     stat_dict['edge_labels'] = [1]
118 | 
119 | 
120 |     # Data Loader
121 |     train_loader = torch.utils.data.DataLoader(data_train,
122 |                                                batch_size=args.batch_size, shuffle=True,
123 |                                                collate_fn=datasets.utils.collate_g,
124 |                                                num_workers=args.prefetch, pin_memory=True)
125 |     valid_loader = torch.utils.data.DataLoader(data_valid,
126 |                                                batch_size=args.batch_size, collate_fn=datasets.utils.collate_g,
127 |                                                num_workers=args.prefetch, pin_memory=True)
128 |     test_loader = torch.utils.data.DataLoader(data_test,
129 |                                               batch_size=args.batch_size, collate_fn=datasets.utils.collate_g,
130 |                                               num_workers=args.prefetch, pin_memory=True)
131 | 
132 |     print('\tCreate model')
133 |     model = MpnnGGNN(stat_dict['edge_labels'], 25, 15, 2, num_classes, type='classification')
134 | 
135 |     print('Optimizer')
136 |     optimizer = optim.Adam(model.parameters(), lr=args.lr)
137 | 
138 |     criterion = nn.NLLLoss()
139 | 
140 |     evaluation = utils.accuracy
141 | 
142 |     print('Logger')
143 |     logger = Logger(args.logPath)
144 | 
145 |     lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0])
146 | 
147 |     # get the best checkpoint if available without training
148 |     if args.resume:
149 |         checkpoint_dir = args.resume
150 |         best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
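        # The checkpoint written by utils.save_checkpoint is a plain dict, so a
        # minimal round-trip looks like this sketch (hypothetical values, same
        # keys as used in this script):
        #   torch.save({'epoch': 1, 'state_dict': model.state_dict(),
        #               'best_acc1': 0.0, 'optimizer': optimizer.state_dict()},
        #              best_model_file)
        #   state = torch.load(best_model_file)
        #   model.load_state_dict(state['state_dict'])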
151 | if not os.path.isdir(checkpoint_dir): 152 | os.makedirs(checkpoint_dir) 153 | if os.path.isfile(best_model_file): 154 | print("=> loading best model '{}'".format(best_model_file)) 155 | checkpoint = torch.load(best_model_file) 156 | args.start_epoch = checkpoint['epoch'] 157 | best_acc1 = checkpoint['best_acc1'] 158 | model.load_state_dict(checkpoint['state_dict']) 159 | optimizer.load_state_dict(checkpoint['optimizer']) 160 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 161 | best_acc1)) 162 | else: 163 | print("=> no best model found at '{}'".format(best_model_file)) 164 | 165 | print('Check cuda') 166 | if args.cuda: 167 | print('\t* Cuda') 168 | model = model.cuda() 169 | criterion = criterion.cuda() 170 | 171 | # Epoch for loop 172 | for epoch in range(0, args.epochs): 173 | 174 | if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]: 175 | args.lr -= lr_step 176 | for param_group in optimizer.param_groups: 177 | param_group['lr'] = args.lr 178 | 179 | # train for one epoch 180 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 181 | 182 | # evaluate on test set 183 | acc1 = validate(valid_loader, model, criterion, evaluation, logger) 184 | 185 | is_best = acc1 > best_acc1 186 | best_acc1 = max(acc1, best_acc1) 187 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 188 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 189 | 190 | # Logger step 191 | logger.log_value('learning_rate', args.lr).step() 192 | 193 | # get the best checkpoint and test it with test set 194 | if args.resume: 195 | checkpoint_dir = args.resume 196 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 197 | if not os.path.isdir(checkpoint_dir): 198 | os.makedirs(checkpoint_dir) 199 | if os.path.isfile(best_model_file): 200 | print("=> loading best model '{}'".format(best_model_file)) 201 | checkpoint = torch.load(best_model_file) 202 | args.start_epoch = checkpoint['epoch'] 203 | best_acc1 = checkpoint['best_acc1'] 204 | model.load_state_dict(checkpoint['state_dict']) 205 | optimizer.load_state_dict(checkpoint['optimizer']) 206 | print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'], 207 | best_acc1)) 208 | else: 209 | print("=> no best model found at '{}'".format(best_model_file)) 210 | 211 | # For testing 212 | validate(test_loader, model, criterion, evaluation) 213 | 214 | 215 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 216 | batch_time = AverageMeter() 217 | data_time = AverageMeter() 218 | losses = AverageMeter() 219 | accuracies = AverageMeter() 220 | 221 | # switch to train mode 222 | model.train() 223 | 224 | end = time.time() 225 | for i, (g, h, e, target) in enumerate(train_loader): 226 | 227 | # Prepare input data 228 | target = torch.squeeze(target).type(torch.LongTensor) 229 | if args.cuda: 230 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 231 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 232 | 233 | # Measure data loading time 234 | data_time.update(time.time() - end) 235 | 236 | def closure(): 237 | optimizer.zero_grad() 238 | 239 | # Compute output 240 | output = model(g, h, e) 241 | train_loss = criterion(output, target) 242 | 243 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 244 | 245 | # Logs 246 | 
losses.update(train_loss.data[0], g.size(0)) 247 | accuracies.update(acc.data[0], g.size(0)) 248 | # compute gradient and do SGD step 249 | train_loss.backward() 250 | return train_loss 251 | 252 | optimizer.step(closure) 253 | 254 | # Measure elapsed time 255 | batch_time.update(time.time() - end) 256 | end = time.time() 257 | 258 | if i % args.log_interval == 0 and i > 0: 259 | 260 | print('Epoch: [{0}][{1}/{2}]\t' 261 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 262 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 263 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 264 | 'Accuracy {acc.val:.4f} ({acc.avg:.4f})' 265 | .format(epoch, i, len(train_loader), batch_time=batch_time, 266 | data_time=data_time, loss=losses, acc=accuracies)) 267 | 268 | logger.log_value('train_epoch_loss', losses.avg) 269 | logger.log_value('train_epoch_accuracy', accuracies.avg) 270 | 271 | print('Epoch: [{0}] Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}' 272 | .format(epoch, acc=accuracies, loss=losses, b_time=batch_time)) 273 | 274 | 275 | def validate(val_loader, model, criterion, evaluation, logger=None): 276 | losses = AverageMeter() 277 | accuracies = AverageMeter() 278 | 279 | # switch to evaluate mode 280 | model.eval() 281 | 282 | end = time.time() 283 | for i, (g, h, e, target) in enumerate(val_loader): 284 | 285 | # Prepare input data 286 | target = torch.squeeze(target).type(torch.LongTensor) 287 | if args.cuda: 288 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 289 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 290 | 291 | # Compute output 292 | output = model(g, h, e) 293 | 294 | # Logs 295 | test_loss = criterion(output, target) 296 | acc = Variable(evaluation(output.data, target.data, topk=(1,))[0]) 297 | 298 | losses.update(test_loss.data[0], g.size(0)) 299 | accuracies.update(acc.data[0], g.size(0)) 300 | 301 | print(' * Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}' 302 | .format(acc=accuracies, loss=losses)) 303 | 304 | if logger is not None: 305 | logger.log_value('test_epoch_loss', losses.avg) 306 | logger.log_value('test_epoch_accuracy', accuracies.avg) 307 | 308 | return accuracies.avg 309 | 310 | if __name__ == '__main__': 311 | main() 312 | -------------------------------------------------------------------------------- /demos/demo_letter_intnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in: 6 | 7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017) 8 | Neural Message Passing for Quantum Chemistry. 
9 | arXiv preprint arXiv:1704.01212 [cs.LG]
10 | """
11 | 
12 | # Torch
13 | import torch
14 | import torch.optim as optim
15 | import torch.nn as nn
16 | from torch.autograd import Variable
17 | 
18 | import time
19 | import argparse
20 | import os
21 | import sys
22 | 
23 | # Our Modules
24 | reader_folder = os.path.realpath(os.path.abspath('..'))
25 | if reader_folder not in sys.path:
26 |     sys.path.append(reader_folder)
27 | import datasets
28 | from datasets import utils
29 | from models.MPNN_IntNet import MpnnIntNet
30 | from LogMetric import AverageMeter, Logger
31 | from GraphReader.graph_reader import read_cxl
32 | 
33 | __author__ = "Pau Riba, Anjan Dutta"
34 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"
35 | 
36 | torch.multiprocessing.set_sharing_strategy('file_system')
37 | 
38 | 
39 | # Parser check
40 | def restricted_float(x, inter):
41 |     x = float(x)
42 |     if x < inter[0] or x > inter[1]:
43 |         raise argparse.ArgumentTypeError("%r not in range [%r, %r]" % (x, inter[0], inter[1]))
44 |     return x
45 | 
46 | # Argument parser
47 | parser = argparse.ArgumentParser(description='Neural message passing')
48 | 
49 | parser.add_argument('--dataset', default='Letter', help='letter')
50 | parser.add_argument('--datasetPath', default='../data/Letter/', help='dataset path')
51 | parser.add_argument('--subSet', default='LOW', help='sub dataset')
52 | parser.add_argument('--logPath', default='../log/letter/intnet/', help='log path')
53 | parser.add_argument('--plotLr', default=False, help='allow plotting the data')
54 | parser.add_argument('--plotPath', default='../plot/letter/intnet/', help='plot path')
55 | parser.add_argument('--resume', default='../checkpoint/letter/intnet/',
56 |                     help='path to latest checkpoint')
57 | # Optimization Options
58 | parser.add_argument('--batch-size', type=int, default=20, metavar='N',
59 |                     help='Input batch size for training (default: 20)')
60 | parser.add_argument('--no-cuda', action='store_true', default=False,
61 |                     help='Disables CUDA training')
62 | parser.add_argument('--epochs', type=int, default=360, metavar='N',
63 |                     help='Number of epochs to train (default: 360)')
64 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 0.5]), default=0.001, metavar='LR',
65 |                     help='Initial learning rate [1e-5, 0.5] (default: 1e-3)')
66 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY',
67 |                     help='Learning rate decay factor [.01, 1] (default: 0.6)')
68 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S',
69 |                     help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])')
70 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
71 |                     help='SGD momentum (default: 0.9)')
72 | # i/o
73 | parser.add_argument('--log-interval', type=int, default=100, metavar='N',
74 |                     help='How many batches to wait before logging training status')
75 | # Accelerating
76 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
77 | 
78 | best_acc1 = 0
79 | 
80 | 
81 | def main():
82 | 
83 |     global args, best_acc1
84 |     args = parser.parse_args()
85 | 
86 |     # Check if CUDA is enabled
87 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
88 | 
89 |     # Load data
90 |     root = args.datasetPath
91 |     subset = args.subSet
92 | 
93 |     print('Prepare files')
94 | 
95 |     train_classes, train_ids = read_cxl(os.path.join(root, subset, 'train.cxl'))
96 |     test_classes, test_ids = read_cxl(os.path.join(root, subset, 'test.cxl'))
97 |     valid_classes, valid_ids = read_cxl(os.path.join(root, subset, 'validation.cxl'))
98 | 
99 |     class_list = list(set(train_classes + test_classes))
100 |     num_classes = len(class_list)
101 |     data_train = datasets.LETTER(root, subset, train_ids, train_classes, class_list)
102 |     data_valid = datasets.LETTER(root, subset, valid_ids, valid_classes, class_list)
103 |     data_test = datasets.LETTER(root, subset, test_ids, test_classes, class_list)
104 | 
105 |     # Define model and optimizer
106 |     print('Define model')
107 |     # Select one graph
108 |     g_tuple, l = data_train[0]
109 |     g, h_t, e = g_tuple
110 | 
111 |     # TODO: Needs attention
112 |     print('\tStatistics')
113 |     stat_dict = {}
114 |     # stat_dict = datasets.utils.get_graph_stats(data_train, ['edge_labels'])
115 |     stat_dict['edge_labels'] = [1]
116 | 
117 | 
118 |     # Data Loader
119 |     train_loader = torch.utils.data.DataLoader(data_train, batch_size=args.batch_size, shuffle=True,
120 |                                                collate_fn=datasets.utils.collate_g, num_workers=args.prefetch,
121 |                                                pin_memory=True)
122 |     valid_loader = torch.utils.data.DataLoader(data_valid, batch_size=args.batch_size,
123 |                                                collate_fn=datasets.utils.collate_g, num_workers=args.prefetch,
124 |                                                pin_memory=True)
125 |     test_loader = torch.utils.data.DataLoader(data_test, batch_size=args.batch_size,
126 |                                               collate_fn=datasets.utils.collate_g, num_workers=args.prefetch,
127 |                                               pin_memory=True)
128 | 
129 |     print('\tCreate model')
130 |     model = MpnnIntNet([len(h_t[0]), len(list(e.values())[0])], [5, 15, 15], [10, 20, 20], num_classes,
131 |                        type='classification')
132 | 
133 |     print('Optimizer')
134 |     optimizer = optim.Adam(model.parameters(), lr=args.lr)
135 | 
136 |     criterion = nn.NLLLoss()
137 | 
138 |     evaluation = utils.accuracy
139 | 
140 |     print('Logger')
141 |     logger = Logger(args.logPath)
142 | 
143 |     lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0])
144 | 
145 |     # get the best checkpoint if available without training
146 |     if args.resume:
147 |         checkpoint_dir = args.resume
148 |         best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
149 |         if not os.path.isdir(checkpoint_dir):
150 |             os.makedirs(checkpoint_dir)
151 |         if os.path.isfile(best_model_file):
152 |             print("=> loading best model '{}'".format(best_model_file))
153 |             checkpoint = torch.load(best_model_file)
154 |             args.start_epoch = checkpoint['epoch']
155 |             best_acc1 = checkpoint['best_acc1']
156 |             model.load_state_dict(checkpoint['state_dict'])
157 |             optimizer.load_state_dict(checkpoint['optimizer'])
158 |             print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'],
159 |                                                                              best_acc1))
160 |         else:
161 |             print("=> no best model found at '{}'".format(best_model_file))
162 | 
163 |     print('Check cuda')
164 |     if args.cuda:
165 |         print('\t* Cuda')
166 |         model = model.cuda()
167 |         criterion = criterion.cuda()
168 | 
169 |     # Epoch for loop
170 |     for epoch in range(0, args.epochs):
171 | 
172 |         if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]:
173 |             args.lr -= lr_step
174 |             for param_group in optimizer.param_groups:
175 |                 param_group['lr'] = args.lr
176 | 
177 |         # train for one epoch
178 |         train(train_loader, model, criterion, optimizer, epoch, evaluation, logger)
179 | 
180 |         # evaluate on validation set
181 |         acc1 = validate(valid_loader, model, criterion, evaluation, logger)
182 | 
183 |         is_best = acc1 > best_acc1
184 |         best_acc1 = max(acc1, best_acc1)
185 |         utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1,
186 |                                'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume)
187 | 
188 |         # Logger step
189 |         logger.log_value('learning_rate', args.lr).step()
190 | 
191 |     # get the best checkpoint and test it with test set
192 |     if args.resume:
193 |         checkpoint_dir = args.resume
194 |         best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
195 |         if not os.path.isdir(checkpoint_dir):
196 |             os.makedirs(checkpoint_dir)
197 |         if os.path.isfile(best_model_file):
198 |             print("=> loading best model '{}'".format(best_model_file))
199 |             checkpoint = torch.load(best_model_file)
200 |             args.start_epoch = checkpoint['epoch']
201 |             best_acc1 = checkpoint['best_acc1']
202 |             model.load_state_dict(checkpoint['state_dict'])
203 |             optimizer.load_state_dict(checkpoint['optimizer'])
204 |             print("=> loaded best model '{}' (epoch {}; accuracy {})".format(best_model_file, checkpoint['epoch'],
205 |                                                                              best_acc1))
206 |         else:
207 |             print("=> no best model found at '{}'".format(best_model_file))
208 | 
209 |     # For testing
210 |     validate(test_loader, model, criterion, evaluation)
211 | 
212 | 
213 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger):
214 |     batch_time = AverageMeter()
215 |     data_time = AverageMeter()
216 |     losses = AverageMeter()
217 |     accuracies = AverageMeter()
218 | 
219 |     # switch to train mode
220 |     model.train()
221 | 
222 |     end = time.time()
223 |     for i, (g, h, e, target) in enumerate(train_loader):
224 | 
225 |         # Prepare input data
226 |         target = torch.squeeze(target).type(torch.LongTensor)
227 |         if args.cuda:
228 |             g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
229 |         g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)
230 | 
231 |         # Measure data loading time
232 |         data_time.update(time.time() - end)
233 | 
234 |         def closure():
235 |             optimizer.zero_grad()
236 | 
237 |             # Compute output
238 |             output = model(g, h, e)
239 |             train_loss = criterion(output, target)
240 | 
241 |             acc = Variable(evaluation(output.data, target.data, topk=(1,))[0])
242 | 
243 |             # Logs
244 |             losses.update(train_loss.data[0], g.size(0))
245 |             accuracies.update(acc.data[0], g.size(0))
246 |             # compute gradient and do SGD step
247 |             train_loss.backward()
248 |             return train_loss
249 | 
250 |         optimizer.step(closure)
251 | 
252 |         # Measure elapsed time
253 |         batch_time.update(time.time() - end)
254 |         end = time.time()
255 | 
256 |         if i % args.log_interval == 0 and i > 0:
257 | 
258 |             print('Epoch: [{0}][{1}/{2}]\t'
259 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
260 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
261 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
262 |                   'Accuracy {acc.val:.4f} ({acc.avg:.4f})'
263 |                   .format(epoch, i, len(train_loader), batch_time=batch_time,
264 |                           data_time=data_time, loss=losses, acc=accuracies))
265 | 
266 |     logger.log_value('train_epoch_loss', losses.avg)
267 |     logger.log_value('train_epoch_accuracy', accuracies.avg)
268 | 
269 |     print('Epoch: [{0}] Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}'
270 |           .format(epoch, acc=accuracies, loss=losses, b_time=batch_time))
271 | 
272 | 
273 | def validate(val_loader, model, criterion, evaluation, logger=None):
274 |     losses = AverageMeter()
275 |     accuracies = AverageMeter()
276 | 
277 |     # switch to evaluate mode
278 |     model.eval()
279 | 
280 |     end = time.time()
281 |     for i, (g, h, e, target) in enumerate(val_loader):
282 | 
283 |         # Prepare input data
284 |         target = torch.squeeze(target).type(torch.LongTensor)
285 |         if args.cuda:
286 |             g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
287 |         g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)
288 | 
289 |         # Compute output
290 |         output = model(g, h, e)
291 | 
292 |         # Logs
293 |         test_loss = criterion(output, target)
294 |         acc = Variable(evaluation(output.data, target.data, topk=(1,))[0])
295 | 
296 |         losses.update(test_loss.data[0], g.size(0))
297 |         accuracies.update(acc.data[0], g.size(0))
298 | 
299 |     print(' * Average Accuracy {acc.avg:.3f}; Average Loss {loss.avg:.3f}'
300 |           .format(acc=accuracies, loss=losses))
301 | 
302 |     if logger is not None:
303 |         logger.log_value('test_epoch_loss', losses.avg)
304 |         logger.log_value('test_epoch_accuracy', accuracies.avg)
305 | 
306 |     return accuracies.avg
307 | 
308 | if __name__ == '__main__':
309 |     main()
310 | 
--------------------------------------------------------------------------------
/demos/demo_qm9_ggnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | 
4 | """
5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in:
6 | 
7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017)
8 | Neural Message Passing for Quantum Chemistry.
9 | arXiv preprint arXiv:1704.01212 [cs.LG]
10 | 
11 | """
12 | 
13 | # Torch
14 | import torch
15 | import torch.optim as optim
16 | import torch.nn as nn
17 | from torch.autograd import Variable
18 | 
19 | import time
20 | import argparse
21 | import os
22 | import sys
23 | import numpy as np
24 | 
25 | # Our Modules
26 | reader_folder = os.path.realpath(os.path.abspath('..'))
27 | if reader_folder not in sys.path:
28 |     sys.path.append(reader_folder)
29 | import datasets
30 | from datasets import utils
31 | from models.MPNN_GGNN import MpnnGGNN
32 | from LogMetric import AverageMeter, Logger
33 | 
34 | __author__ = "Pau Riba, Anjan Dutta"
35 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat"
36 | 
37 | 
38 | # Parser check
39 | def restricted_float(x, inter):
40 |     x = float(x)
41 |     if x < inter[0] or x > inter[1]:
42 |         raise argparse.ArgumentTypeError("%r not in range [%r, %r]" % (x, inter[0], inter[1]))
43 |     return x
44 | 
45 | # Argument parser
46 | parser = argparse.ArgumentParser(description='Neural message passing')
47 | 
48 | parser.add_argument('--dataset', default='qm9', help='QM9')
49 | parser.add_argument('--datasetPath', default='../data/qm9/dsgdb9nsd/', help='dataset path')
50 | parser.add_argument('--logPath', default='../log/qm9/ggnn/', help='log path')
51 | parser.add_argument('--plotLr', default=False, help='allow plotting the data')
52 | parser.add_argument('--plotPath', default='../plot/qm9/ggnn/', help='plot path')
53 | parser.add_argument('--resume', default='../checkpoint/qm9/ggnn/',
54 |                     help='path to latest checkpoint')
55 | # Optimization Options
56 | parser.add_argument('--batch-size', type=int, default=20, metavar='N',
57 |                     help='Input batch size for training (default: 20)')
58 | parser.add_argument('--no-cuda', action='store_true', default=False,
59 |                     help='Disables CUDA training')
60 | parser.add_argument('--epochs', type=int, default=360, metavar='N',
61 |                     help='Number of epochs to train (default: 360)')
62 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 1e-2]), default=1e-4, metavar='LR',
63 |                     help='Initial learning rate [1e-5, 1e-2] (default: 1e-4)')
64 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY',
65 |                     help='Learning
rate decay factor [.01, 1] (default: 0.6)') 66 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 67 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 68 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 69 | help='SGD momentum (default: 0.9)') 70 | # i/o 71 | parser.add_argument('--log-interval', type=int, default=20, metavar='N', 72 | help='How many batches to wait before logging training status') 73 | # Accelerating 74 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 75 | 76 | best_acc1 = 0 77 | 78 | 79 | def main(): 80 | 81 | global args, best_acc1 82 | args = parser.parse_args() 83 | 84 | # Check if CUDA is enabled 85 | args.cuda = not args.no_cuda and torch.cuda.is_available() 86 | 87 | # Load data 88 | root = args.datasetPath 89 | 90 | print('Prepare files') 91 | files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))] 92 | 93 | idx = np.random.permutation(len(files)) 94 | idx = idx.tolist() 95 | 96 | valid_ids = [files[i] for i in idx[0:10000]] 97 | test_ids = [files[i] for i in idx[10000:20000]] 98 | train_ids = [files[i] for i in idx[20000:]] 99 | 100 | data_train = datasets.Qm9(root, train_ids) 101 | data_valid = datasets.Qm9(root, valid_ids) 102 | data_test = datasets.Qm9(root, test_ids) 103 | 104 | # Define model and optimizer 105 | print('Define model') 106 | # Select one graph 107 | g_tuple, l = data_train[0] 108 | g, h_t, e = g_tuple 109 | 110 | print('\tStatistics') 111 | # stat_dict = datasets.utils.get_graph_stats(data_valid, ['degrees', 'target_mean', 'target_std', 'edge_labels']) 112 | 113 | stat_dict = {} 114 | 115 | stat_dict['degrees'] = [1, 2, 3, 4] 116 | stat_dict['target_mean'] = np.array([2.71802732e+00, 7.51685080e+01, -2.40259300e-01, 1.09503300e-02, 117 | 2.51209430e-01, 1.18997445e+03, 1.48493130e-01, -4.11609491e+02, 118 | -4.11601022e+02, -4.11600078e+02, -4.11642909e+02, 3.15894998e+01]) 119 | stat_dict['target_std'] = np.array([1.58422291e+00, 8.29443552e+00, 2.23854977e-02, 4.71030547e-02, 120 | 4.77156393e-02, 2.80754665e+02, 3.37238236e-02, 3.97717205e+01, 121 | 3.97715029e+01, 3.97715029e+01, 3.97722334e+01, 4.09458852e+00]) 122 | stat_dict['edge_labels'] = [1, 2, 3, 4] 123 | 124 | data_train.set_target_transform(lambda x: datasets.utils.normalize_data(x,stat_dict['target_mean'], 125 | stat_dict['target_std'])) 126 | data_valid.set_target_transform(lambda x: datasets.utils.normalize_data(x, stat_dict['target_mean'], 127 | stat_dict['target_std'])) 128 | data_test.set_target_transform(lambda x: datasets.utils.normalize_data(x, stat_dict['target_mean'], 129 | stat_dict['target_std'])) 130 | 131 | # Data Loader 132 | train_loader = torch.utils.data.DataLoader(data_train, 133 | batch_size=args.batch_size, shuffle=True, 134 | collate_fn=datasets.utils.collate_g, 135 | num_workers=args.prefetch, pin_memory=True) 136 | valid_loader = torch.utils.data.DataLoader(data_valid, 137 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 138 | num_workers=args.prefetch, pin_memory=True) 139 | test_loader = torch.utils.data.DataLoader(data_test, 140 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 141 | num_workers=args.prefetch, pin_memory=True) 142 | 143 | print('\tCreate model') 144 | model = MpnnGGNN(stat_dict['edge_labels'], 25, 15, 2, len(l), 145 | type='regression') 146 | 147 | print('Optimizer') 148 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 149 | 150 | 
criterion = nn.MSELoss()
151 |     # evaluation = nn.L1Loss()
152 |     evaluation = lambda output, target: torch.mean(torch.abs(output - target) / torch.abs(target))
153 | 
154 |     print('Logger')
155 |     logger = Logger(args.logPath)
156 | 
157 |     lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0])
158 | 
159 |     # get the best checkpoint if available without training
160 |     if args.resume:
161 |         checkpoint_dir = args.resume
162 |         best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
163 |         if not os.path.isdir(checkpoint_dir):
164 |             os.makedirs(checkpoint_dir)
165 |         if os.path.isfile(best_model_file):
166 |             print("=> loading best model '{}'".format(best_model_file))
167 |             checkpoint = torch.load(best_model_file)
168 |             args.start_epoch = checkpoint['epoch']
169 |             best_acc1 = checkpoint['best_acc1']
170 |             model.load_state_dict(checkpoint['state_dict'])
171 |             optimizer.load_state_dict(checkpoint['optimizer'])
172 |             print("=> loaded best model '{}' (epoch {})".format(best_model_file, checkpoint['epoch']))
173 |         else:
174 |             print("=> no best model found at '{}'".format(best_model_file))
175 | 
176 |     print('Check cuda')
177 |     if args.cuda:
178 |         print('\t* Cuda')
179 |         model = model.cuda()
180 |         criterion = criterion.cuda()
181 | 
182 |     # Epoch for loop
183 |     for epoch in range(0, args.epochs):
184 | 
185 |         if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]:
186 |             args.lr -= lr_step
187 |             for param_group in optimizer.param_groups:
188 |                 param_group['lr'] = args.lr
189 | 
190 |         # train for one epoch
191 |         train(train_loader, model, criterion, optimizer, epoch, evaluation, logger)
192 | 
193 |         # evaluate on validation set
194 |         acc1 = validate(valid_loader, model, criterion, evaluation, logger)
195 | 
196 |         is_best = acc1 < best_acc1 or best_acc1 == 0  # acc1 is an error ratio: lower is better
197 |         best_acc1 = acc1 if is_best else best_acc1
198 |         utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1,
199 |                                'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume)
200 | 
201 |         # Logger step
202 |         logger.log_value('learning_rate', args.lr).step()
203 | 
204 |     # get the best checkpoint and test it with test set
205 |     if args.resume:
206 |         checkpoint_dir = args.resume
207 |         best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
208 |         if not os.path.isdir(checkpoint_dir):
209 |             os.makedirs(checkpoint_dir)
210 |         if os.path.isfile(best_model_file):
211 |             print("=> loading best model '{}'".format(best_model_file))
212 |             checkpoint = torch.load(best_model_file)
213 |             args.start_epoch = checkpoint['epoch']
214 |             best_acc1 = checkpoint['best_acc1']
215 |             model.load_state_dict(checkpoint['state_dict'])
216 |             optimizer.load_state_dict(checkpoint['optimizer'])
217 |             print("=> loaded best model '{}' (epoch {})".format(best_model_file, checkpoint['epoch']))
218 |         else:
219 |             print("=> no best model found at '{}'".format(best_model_file))
220 | 
221 |     # For testing
222 |     validate(test_loader, model, criterion, evaluation)
223 | 
224 | 
225 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger):
226 |     batch_time = AverageMeter()
227 |     data_time = AverageMeter()
228 |     losses = AverageMeter()
229 |     error_ratio = AverageMeter()
230 | 
231 |     # switch to train mode
232 |     model.train()
233 | 
234 |     end = time.time()
235 |     for i, (g, h, e, target) in enumerate(train_loader):
236 | 
237 |         # Prepare input data
238 |         if args.cuda:
239 |             g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
240 |         g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)
241 | 
242 |         # Measure data loading time
243 |         data_time.update(time.time() - end)
244 | 
245 |         optimizer.zero_grad()
246 | 
247 |         # Compute output
248 |         output = model(g, h, e)
249 |         train_loss = criterion(output, target)
250 | 
251 |         # Logs
252 |         losses.update(train_loss.data[0], g.size(0))
253 |         error_ratio.update(evaluation(output, target).data[0], g.size(0))
254 | 
255 |         # compute gradient and do SGD step
256 |         train_loss.backward()
257 |         optimizer.step()
258 | 
259 |         # Measure elapsed time
260 |         batch_time.update(time.time() - end)
261 |         end = time.time()
262 | 
263 |         if i % args.log_interval == 0 and i > 0:
264 | 
265 |             print('Epoch: [{0}][{1}/{2}]\t'
266 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
267 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
268 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
269 |                   'Error Ratio {err.val:.4f} ({err.avg:.4f})'
270 |                   .format(epoch, i, len(train_loader), batch_time=batch_time,
271 |                           data_time=data_time, loss=losses, err=error_ratio))
272 | 
273 |     logger.log_value('train_epoch_loss', losses.avg)
274 |     logger.log_value('train_epoch_error_ratio', error_ratio.avg)
275 | 
276 |     print('Epoch: [{0}] Avg Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}'
277 |           .format(epoch, err=error_ratio, loss=losses, b_time=batch_time))
278 | 
279 | 
280 | def validate(val_loader, model, criterion, evaluation, logger=None):
281 |     batch_time = AverageMeter()
282 |     losses = AverageMeter()
283 |     error_ratio = AverageMeter()
284 | 
285 |     # switch to evaluate mode
286 |     model.eval()
287 | 
288 |     end = time.time()
289 |     for i, (g, h, e, target) in enumerate(val_loader):
290 | 
291 |         # Prepare input data
292 |         if args.cuda:
293 |             g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
294 |         g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)
295 | 
296 |         # Compute output
297 |         output = model(g, h, e)
298 | 
299 |         # Logs
300 |         losses.update(criterion(output, target).data[0], g.size(0))
301 |         error_ratio.update(evaluation(output, target).data[0], g.size(0))
302 | 
303 |         # measure elapsed time
304 |         batch_time.update(time.time() - end)
305 |         end = time.time()
306 | 
307 |         if i % args.log_interval == 0 and i > 0:
308 | 
309 |             print('Test: [{0}/{1}]\t'
310 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
311 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
312 |                   'Error Ratio {err.val:.4f} ({err.avg:.4f})'
313 |                   .format(i, len(val_loader), batch_time=batch_time,
314 |                           loss=losses, err=error_ratio))
315 | 
316 |     print(' * Average Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}'
317 |           .format(err=error_ratio, loss=losses))
318 | 
319 |     if logger is not None:
320 |         logger.log_value('test_epoch_loss', losses.avg)
321 |         logger.log_value('test_epoch_error_ratio', error_ratio.avg)
322 |     return error_ratio.avg
323 | 
324 | if __name__ == '__main__':
325 |     main()
326 | 
--------------------------------------------------------------------------------
/demos/demo_qm9_intnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | 
4 | """
5 | Trains a Neural Message Passing Model on various datasets. Methodology defined in:
6 | 
7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017)
8 | Neural Message Passing for Quantum Chemistry.
9 | arXiv preprint arXiv:1704.01212 [cs.LG] 10 | 11 | """ 12 | 13 | # Torch 14 | import torch 15 | import torch.optim as optim 16 | import torch.nn as nn 17 | from torch.autograd import Variable 18 | 19 | import time 20 | import argparse 21 | import os 22 | import sys 23 | import numpy as np 24 | 25 | # Our Modules 26 | reader_folder = os.path.realpath(os.path.abspath('..')) 27 | if reader_folder not in sys.path: 28 | sys.path.append(reader_folder) 29 | import datasets 30 | from datasets import utils 31 | from models.MPNN_IntNet import MpnnIntNet 32 | from LogMetric import AverageMeter, Logger 33 | 34 | __author__ = "Pau Riba, Anjan Dutta" 35 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 36 | 37 | 38 | # Parser check 39 | def restricted_float(x, inter): 40 | x = float(x) 41 | if x < inter[0] or x > inter[1]: 42 | raise argparse.ArgumentTypeError("%r not in range [1e-5, 1e-4]"%(x,)) 43 | return x 44 | 45 | # Argument parser 46 | parser = argparse.ArgumentParser(description='Neural message passing') 47 | 48 | parser.add_argument('--dataset', default='qm9', help='QM9') 49 | parser.add_argument('--datasetPath', default='../data/qm9/dsgdb9nsd/', help='dataset path') 50 | parser.add_argument('--logPath', default='../log/qm9/intnet/', help='log path') 51 | parser.add_argument('--plotLr', default=False, help='allow plotting the data') 52 | parser.add_argument('--plotPath', default='../plot/qm9/intnet/', help='plot path') 53 | parser.add_argument('--resume', default='../checkpoint/qm9/intnet/', 54 | help='path to latest checkpoint') 55 | # Optimization Options 56 | parser.add_argument('--batch-size', type=int, default=20, metavar='N', 57 | help='Input batch size for training (default: 20)') 58 | parser.add_argument('--no-cuda', action='store_true', default=False, 59 | help='Enables CUDA training') 60 | parser.add_argument('--epochs', type=int, default=360, metavar='N', 61 | help='Number of epochs to train (default: 360)') 62 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 5e-4]), default=1e-4, metavar='LR', 63 | help='Initial learning rate [1e-5, 5e-4] (default: 1e-4)') 64 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY', 65 | help='Learning rate decay factor [.01, 1] (default: 0.6)') 66 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 67 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 68 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 69 | help='SGD momentum (default: 0.9)') 70 | # i/o 71 | parser.add_argument('--log-interval', type=int, default=7, metavar='N', 72 | help='How many batches to wait before logging training status') 73 | # Accelerating 74 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 75 | 76 | best_acc1 = 0 77 | 78 | 79 | def main(): 80 | 81 | global args, best_acc1 82 | args = parser.parse_args() 83 | 84 | # Check if CUDA is enabled 85 | args.cuda = not args.no_cuda and torch.cuda.is_available() 86 | 87 | # Load data 88 | root = args.datasetPath 89 | 90 | print('Prepare files') 91 | files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))] 92 | 93 | idx = np.random.permutation(len(files)) 94 | idx = idx.tolist() 95 | 96 | valid_ids = [files[i] for i in idx[0:10000]] 97 | test_ids = [files[i] for i in idx[10000:20000]] 98 | train_ids = [files[i] for i in idx[20000:]] 99 | 100 | data_train = datasets.Qm9(root, train_ids) 101 | 
data_valid = datasets.Qm9(root, valid_ids) 102 | data_test = datasets.Qm9(root, test_ids) 103 | 104 | # Define model and optimizer 105 | print('Define model') 106 | # Select one graph 107 | g_tuple, l = data_train[0] 108 | g, h_t, e = g_tuple 109 | 110 | print('\tStatistics') 111 | # stat_dict = datasets.utils.get_graph_stats(data_valid, ['degrees', 'target_mean', 'target_std', 'edge_labels']) 112 | 113 | stat_dict = {} 114 | 115 | stat_dict['degrees'] = [1, 2, 3, 4] 116 | stat_dict['target_mean'] = np.array([2.71802732e+00, 7.51685080e+01, -2.40259300e-01, 1.09503300e-02, 117 | 2.51209430e-01, 1.18997445e+03, 1.48493130e-01, -4.11609491e+02, 118 | -4.11601022e+02, -4.11600078e+02, -4.11642909e+02, 3.15894998e+01]) 119 | stat_dict['target_std'] = np.array([1.58422291e+00, 8.29443552e+00, 2.23854977e-02, 4.71030547e-02, 120 | 4.77156393e-02, 2.80754665e+02, 3.37238236e-02, 3.97717205e+01, 121 | 3.97715029e+01, 3.97715029e+01, 3.97722334e+01, 4.09458852e+00]) 122 | stat_dict['edge_labels'] = [1, 2, 3, 4] 123 | 124 | data_train.set_target_transform(lambda x: datasets.utils.normalize_data(x,stat_dict['target_mean'], 125 | stat_dict['target_std'])) 126 | data_valid.set_target_transform(lambda x: datasets.utils.normalize_data(x, stat_dict['target_mean'], 127 | stat_dict['target_std'])) 128 | data_test.set_target_transform(lambda x: datasets.utils.normalize_data(x, stat_dict['target_mean'], 129 | stat_dict['target_std'])) 130 | 131 | # Data Loader 132 | train_loader = torch.utils.data.DataLoader(data_train, 133 | batch_size=args.batch_size, shuffle=True, 134 | collate_fn=datasets.utils.collate_g, 135 | num_workers=args.prefetch, pin_memory=True) 136 | valid_loader = torch.utils.data.DataLoader(data_valid, 137 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 138 | num_workers=args.prefetch, pin_memory=True) 139 | test_loader = torch.utils.data.DataLoader(data_test, 140 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 141 | num_workers=args.prefetch, pin_memory=True) 142 | 143 | print('\tCreate model') 144 | model = MpnnIntNet([len(h_t[0]), len(list(e.values())[0])], [5, 15, 15], [10, 20, 20], len(l), type='regression') 145 | 146 | print('Optimizer') 147 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 148 | criterion = nn.MSELoss() 149 | # evaluation = nn.L1Loss() 150 | evaluation = lambda output, target: torch.mean(torch.abs(output - target) / torch.abs(target)) 151 | 152 | print('Logger') 153 | logger = Logger(args.logPath) 154 | 155 | lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0]) 156 | 157 | # get the best checkpoint if available without training 158 | if args.resume: 159 | checkpoint_dir = args.resume 160 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 161 | if not os.path.isdir(checkpoint_dir): 162 | os.makedirs(checkpoint_dir) 163 | if os.path.isfile(best_model_file): 164 | print("=> loading best model '{}'".format(best_model_file)) 165 | checkpoint = torch.load(best_model_file) 166 | args.start_epoch = checkpoint['epoch'] 167 | best_acc1 = checkpoint['best_acc1'] 168 | model.load_state_dict(checkpoint['state_dict']) 169 | optimizer.load_state_dict(checkpoint['optimizer']) 170 | print("=> loaded best model '{}' (epoch {})".format(best_model_file, checkpoint['epoch'])) 171 | else: 172 | print("=> no best model found at '{}'".format(best_model_file)) 173 | 174 | print('Check cuda') 175 | if args.cuda: 176 | print('\t* Cuda') 177 | model = model.cuda() 178 | criterion = 
criterion.cuda() 179 | 180 | # Epoch for loop 181 | for epoch in range(0, args.epochs): 182 | 183 | if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]: 184 | args.lr -= lr_step 185 | for param_group in optimizer.param_groups: 186 | param_group['lr'] = args.lr 187 | 188 | # train for one epoch 189 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 190 | 191 | # evaluate on test set 192 | acc1 = validate(valid_loader, model, criterion, evaluation, logger) 193 | 194 | is_best = acc1 > best_acc1 195 | best_acc1 = max(acc1, best_acc1) 196 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 197 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 198 | 199 | # Logger step 200 | logger.log_value('learning_rate', args.lr).step() 201 | 202 | # get the best checkpoint and test it with test set 203 | if args.resume: 204 | checkpoint_dir = args.resume 205 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 206 | if not os.path.isdir(checkpoint_dir): 207 | os.makedirs(checkpoint_dir) 208 | if os.path.isfile(best_model_file): 209 | print("=> loading best model '{}'".format(best_model_file)) 210 | checkpoint = torch.load(best_model_file) 211 | args.start_epoch = checkpoint['epoch'] 212 | best_acc1 = checkpoint['best_acc1'] 213 | model.load_state_dict(checkpoint['state_dict']) 214 | optimizer.load_state_dict(checkpoint['optimizer']) 215 | print("=> loaded best model '{}' (epoch {})".format(best_model_file, checkpoint['epoch'])) 216 | else: 217 | print("=> no best model found at '{}'".format(best_model_file)) 218 | 219 | # For testing 220 | validate(test_loader, model, criterion, evaluation) 221 | 222 | 223 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 224 | batch_time = AverageMeter() 225 | data_time = AverageMeter() 226 | losses = AverageMeter() 227 | error_ratio = AverageMeter() 228 | 229 | # switch to train mode 230 | model.train() 231 | 232 | end = time.time() 233 | for i, (g, h, e, target) in enumerate(train_loader): 234 | 235 | # Prepare input data 236 | if args.cuda: 237 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 238 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 239 | 240 | # Measure data loading time 241 | data_time.update(time.time() - end) 242 | 243 | optimizer.zero_grad() 244 | 245 | # Compute output 246 | output = model(g, h, e) 247 | train_loss = criterion(output, target) 248 | 249 | # Logs 250 | losses.update(train_loss.data[0], g.size(0)) 251 | error_ratio.update(evaluation(output, target).data[0], g.size(0)) 252 | 253 | # compute gradient and do SGD step 254 | train_loss.backward() 255 | optimizer.step() 256 | 257 | # Measure elapsed time 258 | batch_time.update(time.time() - end) 259 | end = time.time() 260 | 261 | if i % args.log_interval == 0 and i > 0: 262 | 263 | print('Epoch: [{0}][{1}/{2}]\t' 264 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 265 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 266 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 267 | 'Error Ratio {err.val:.4f} ({err.avg:.4f})' 268 | .format(epoch, i, len(train_loader), batch_time=batch_time, 269 | data_time=data_time, loss=losses, err=error_ratio)) 270 | 271 | logger.log_value('train_epoch_loss', losses.avg) 272 | logger.log_value('train_epoch_error_ratio', error_ratio.avg) 273 | 274 | print('Epoch: [{0}] Avg Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}; Avg 
Time x Batch {b_time.avg:.3f}' 275 | .format(epoch, err=error_ratio, loss=losses, b_time=batch_time)) 276 | 277 | 278 | def validate(val_loader, model, criterion, evaluation, logger=None): 279 | batch_time = AverageMeter() 280 | losses = AverageMeter() 281 | error_ratio = AverageMeter() 282 | 283 | # switch to evaluate mode 284 | model.eval() 285 | 286 | end = time.time() 287 | for i, (g, h, e, target) in enumerate(val_loader): 288 | 289 | # Prepare input data 290 | if args.cuda: 291 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 292 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 293 | 294 | # Compute output 295 | output = model(g, h, e) 296 | 297 | # Logs 298 | losses.update(criterion(output, target).data[0], g.size(0)) 299 | error_ratio.update(evaluation(output, target).data[0], g.size(0)) 300 | 301 | # measure elapsed time 302 | batch_time.update(time.time() - end) 303 | end = time.time() 304 | 305 | if i % args.log_interval == 0 and i > 0: 306 | 307 | print('Test: [{0}/{1}]\t' 308 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 309 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 310 | 'Error Ratio {err.val:.4f} ({err.avg:.4f})' 311 | .format(i, len(val_loader), batch_time=batch_time, 312 | loss=losses, err=error_ratio)) 313 | 314 | print(' * Average Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}' 315 | .format(err=error_ratio, loss=losses)) 316 | 317 | if logger is not None: 318 | logger.log_value('test_epoch_loss', losses.avg) 319 | logger.log_value('test_epoch_error_ratio', error_ratio.avg) 320 | 321 | 322 | if __name__ == '__main__': 323 | main() 324 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Trains a Neural Message Passing Model on various datasets. Methodologi defined in: 6 | 7 | Gilmer, J., Schoenholz S.S., Riley, P.F., Vinyals, O., Dahl, G.E. (2017) 8 | Neural Message Passing for Quantum Chemistry. 
9 | arXiv preprint arXiv:1704.01212 [cs.LG] 10 | 11 | """ 12 | 13 | # Torch 14 | import torch 15 | import torch.optim as optim 16 | import torch.nn as nn 17 | from torch.autograd import Variable 18 | 19 | import time 20 | import argparse 21 | import os 22 | import numpy as np 23 | 24 | # Our Modules 25 | import datasets 26 | from datasets import utils 27 | from models.MPNN import MPNN 28 | from LogMetric import AverageMeter, Logger 29 | 30 | __author__ = "Pau Riba, Anjan Dutta" 31 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 32 | 33 | 34 | # Parser check 35 | def restricted_float(x, inter): 36 | x = float(x) 37 | if x < inter[0] or x > inter[1]: 38 | raise argparse.ArgumentTypeError("%r not in range [1e-5, 1e-4]"%(x,)) 39 | return x 40 | 41 | # Argument parser 42 | parser = argparse.ArgumentParser(description='Neural message passing') 43 | 44 | parser.add_argument('--dataset', default='qm9', help='QM9') 45 | parser.add_argument('--datasetPath', default='./data/qm9/dsgdb9nsd/', help='dataset path') 46 | parser.add_argument('--logPath', default='./log/qm9/mpnn/', help='log path') 47 | parser.add_argument('--plotLr', default=False, help='allow plotting the data') 48 | parser.add_argument('--plotPath', default='./plot/qm9/mpnn/', help='plot path') 49 | parser.add_argument('--resume', default='./checkpoint/qm9/mpnn/', 50 | help='path to latest checkpoint') 51 | # Optimization Options 52 | parser.add_argument('--batch-size', type=int, default=100, metavar='N', 53 | help='Input batch size for training (default: 20)') 54 | parser.add_argument('--no-cuda', action='store_true', default=False, 55 | help='Enables CUDA training') 56 | parser.add_argument('--epochs', type=int, default=360, metavar='N', 57 | help='Number of epochs to train (default: 360)') 58 | parser.add_argument('--lr', type=lambda x: restricted_float(x, [1e-5, 1e-2]), default=1e-3, metavar='LR', 59 | help='Initial learning rate [1e-5, 5e-4] (default: 1e-4)') 60 | parser.add_argument('--lr-decay', type=lambda x: restricted_float(x, [.01, 1]), default=0.6, metavar='LR-DECAY', 61 | help='Learning rate decay factor [.01, 1] (default: 0.6)') 62 | parser.add_argument('--schedule', type=list, default=[0.1, 0.9], metavar='S', 63 | help='Percentage of epochs to start the learning rate decay [0, 1] (default: [0.1, 0.9])') 64 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 65 | help='SGD momentum (default: 0.9)') 66 | # i/o 67 | parser.add_argument('--log-interval', type=int, default=20, metavar='N', 68 | help='How many batches to wait before logging training status') 69 | # Accelerating 70 | parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.') 71 | 72 | best_er1 = 0 73 | 74 | 75 | def main(): 76 | 77 | global args, best_er1 78 | args = parser.parse_args() 79 | 80 | # Check if CUDA is enabled 81 | args.cuda = not args.no_cuda and torch.cuda.is_available() 82 | 83 | # Load data 84 | root = args.datasetPath 85 | 86 | print('Prepare files') 87 | files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))] 88 | 89 | idx = np.random.permutation(len(files)) 90 | idx = idx.tolist() 91 | 92 | valid_ids = [files[i] for i in idx[0:10000]] 93 | test_ids = [files[i] for i in idx[10000:20000]] 94 | train_ids = [files[i] for i in idx[20000:]] 95 | 96 | data_train = datasets.Qm9(root, train_ids, edge_transform=utils.qm9_edges, e_representation='raw_distance') 97 | data_valid = datasets.Qm9(root, valid_ids, edge_transform=utils.qm9_edges, e_representation='raw_distance') 98 | 
100 | # Define model and optimizer 101 | print('Define model') 102 | # Select one graph 103 | g_tuple, l = data_train[0] 104 | g, h_t, e = g_tuple 105 | 106 | print('\tStatistics') 107 | stat_dict = datasets.utils.get_graph_stats(data_valid, ['target_mean', 'target_std']) 108 | 109 | data_train.set_target_transform(lambda x: datasets.utils.normalize_data(x, stat_dict['target_mean'], 110 | stat_dict['target_std'])) 111 | data_valid.set_target_transform(lambda x: datasets.utils.normalize_data(x, stat_dict['target_mean'], 112 | stat_dict['target_std'])) 113 | data_test.set_target_transform(lambda x: datasets.utils.normalize_data(x, stat_dict['target_mean'], 114 | stat_dict['target_std'])) 115 | 116 | # Data Loader 117 | train_loader = torch.utils.data.DataLoader(data_train, 118 | batch_size=args.batch_size, shuffle=True, 119 | collate_fn=datasets.utils.collate_g, 120 | num_workers=args.prefetch, pin_memory=True) 121 | valid_loader = torch.utils.data.DataLoader(data_valid, 122 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 123 | num_workers=args.prefetch, pin_memory=True) 124 | test_loader = torch.utils.data.DataLoader(data_test, 125 | batch_size=args.batch_size, collate_fn=datasets.utils.collate_g, 126 | num_workers=args.prefetch, pin_memory=True) 127 | 128 | print('\tCreate model') 129 | in_n = [len(h_t[0]), len(list(e.values())[0])] 130 | hidden_state_size = 73 131 | message_size = 73 132 | n_layers = 3 133 | l_target = len(l) 134 | type = 'regression' 135 | model = MPNN(in_n, hidden_state_size, message_size, n_layers, l_target, type=type) 136 | del in_n, hidden_state_size, message_size, n_layers, l_target, type 137 | 138 | print('Optimizer') 139 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 140 | 141 | criterion = nn.MSELoss() 142 | 143 | evaluation = lambda output, target: torch.mean(torch.abs(output - target) / torch.abs(target))  # mean relative error 144 | 145 | print('Logger') 146 | logger = Logger(args.logPath) 147 | 148 | lr_step = (args.lr-args.lr*args.lr_decay)/(args.epochs*args.schedule[1] - args.epochs*args.schedule[0])  # linear decay step; with the defaults: (1e-3 - 0.6*1e-3) / (360*0.9 - 360*0.1), roughly 1.4e-6 per epoch 149 | 150 | # get the best checkpoint if available without training 151 | if args.resume: 152 | checkpoint_dir = args.resume 153 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 154 | if not os.path.isdir(checkpoint_dir): 155 | os.makedirs(checkpoint_dir) 156 | if os.path.isfile(best_model_file): 157 | print("=> loading best model '{}'".format(best_model_file)) 158 | checkpoint = torch.load(best_model_file) 159 | args.start_epoch = checkpoint['epoch'] 160 | best_er1 = checkpoint['best_er1'] 161 | model.load_state_dict(checkpoint['state_dict']) 162 | optimizer.load_state_dict(checkpoint['optimizer']) 163 | print("=> loaded best model '{}' (epoch {})".format(best_model_file, checkpoint['epoch'])) 164 | else: 165 | print("=> no best model found at '{}'".format(best_model_file)) 166 | 167 | print('Check cuda') 168 | if args.cuda: 169 | print('\t* Cuda') 170 | model = model.cuda() 171 | criterion = criterion.cuda() 172 | 173 | # Epoch for loop 174 | for epoch in range(0, args.epochs): 175 | 176 | if epoch > args.epochs * args.schedule[0] and epoch < args.epochs * args.schedule[1]: 177 | args.lr -= lr_step 178 | for param_group in optimizer.param_groups: 179 | param_group['lr'] = args.lr 180 | 181 | # train for one epoch 182 | train(train_loader, model, criterion, optimizer, epoch, evaluation, logger) 183 | 184 | # evaluate on validation set
185 | er1 = validate(valid_loader, model, criterion, evaluation, logger) 186 | 187 | is_best = er1 < best_er1  # lower error ratio is better 188 | best_er1 = min(er1, best_er1) 189 | utils.save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_er1': best_er1, 190 | 'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume) 191 | 192 | # Logger step 193 | logger.log_value('learning_rate', args.lr).step() 194 | 195 | # get the best checkpoint and test it with test set 196 | if args.resume: 197 | checkpoint_dir = args.resume 198 | best_model_file = os.path.join(checkpoint_dir, 'model_best.pth') 199 | if not os.path.isdir(checkpoint_dir): 200 | os.makedirs(checkpoint_dir) 201 | if os.path.isfile(best_model_file): 202 | print("=> loading best model '{}'".format(best_model_file)) 203 | checkpoint = torch.load(best_model_file) 204 | args.start_epoch = checkpoint['epoch'] 205 | best_er1 = checkpoint['best_er1'] 206 | model.load_state_dict(checkpoint['state_dict']) 207 | if args.cuda: 208 | model.cuda() 209 | optimizer.load_state_dict(checkpoint['optimizer']) 210 | print("=> loaded best model '{}' (epoch {})".format(best_model_file, checkpoint['epoch'])) 211 | else: 212 | print("=> no best model found at '{}'".format(best_model_file)) 213 | 214 | # For testing 215 | validate(test_loader, model, criterion, evaluation) 216 | 217 | 218 | def train(train_loader, model, criterion, optimizer, epoch, evaluation, logger): 219 | batch_time = AverageMeter() 220 | data_time = AverageMeter() 221 | losses = AverageMeter() 222 | error_ratio = AverageMeter() 223 | 224 | # switch to train mode 225 | model.train() 226 | 227 | end = time.time() 228 | for i, (g, h, e, target) in enumerate(train_loader): 229 | 230 | # Prepare input data 231 | if args.cuda: 232 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 233 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 234 | 235 | # Measure data loading time 236 | data_time.update(time.time() - end) 237 | 238 | optimizer.zero_grad() 239 | 240 | # Compute output 241 | output = model(g, h, e) 242 | train_loss = criterion(output, target) 243 | 244 | # Logs 245 | losses.update(train_loss.data[0], g.size(0)) 246 | error_ratio.update(evaluation(output, target).data[0], g.size(0)) 247 | 248 | # compute gradient and do SGD step 249 | train_loss.backward() 250 | optimizer.step() 251 | 252 | # Measure elapsed time 253 | batch_time.update(time.time() - end) 254 | end = time.time() 255 | 256 | if i % args.log_interval == 0 and i > 0: 257 | 258 | print('Epoch: [{0}][{1}/{2}]\t' 259 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 260 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 261 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 262 | 'Error Ratio {err.val:.4f} ({err.avg:.4f})' 263 | .format(epoch, i, len(train_loader), batch_time=batch_time, 264 | data_time=data_time, loss=losses, err=error_ratio)) 265 | 266 | logger.log_value('train_epoch_loss', losses.avg) 267 | logger.log_value('train_epoch_error_ratio', error_ratio.avg) 268 | 269 | print('Epoch: [{0}] Avg Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}; Avg Time x Batch {b_time.avg:.3f}' 270 | .format(epoch, err=error_ratio, loss=losses, b_time=batch_time)) 271 | 272 | 273 | def validate(val_loader, model, criterion, evaluation, logger=None): 274 | batch_time = AverageMeter() 275 | losses = AverageMeter() 276 | error_ratio = AverageMeter() 277 | 278 | # switch to evaluate mode 279 | model.eval() 280 | 281 | end = time.time() 282 | for i, (g, h, e, target) in
enumerate(val_loader): 283 | 284 | # Prepare input data 285 | if args.cuda: 286 | g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda() 287 | g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target) 288 | 289 | # Compute output 290 | output = model(g, h, e) 291 | 292 | # Logs 293 | losses.update(criterion(output, target).data[0], g.size(0)) 294 | error_ratio.update(evaluation(output, target).data[0], g.size(0)) 295 | 296 | # measure elapsed time 297 | batch_time.update(time.time() - end) 298 | end = time.time() 299 | 300 | if i % args.log_interval == 0 and i > 0: 301 | 302 | print('Test: [{0}/{1}]\t' 303 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 304 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 305 | 'Error Ratio {err.val:.4f} ({err.avg:.4f})' 306 | .format(i, len(val_loader), batch_time=batch_time, 307 | loss=losses, err=error_ratio)) 308 | 309 | print(' * Average Error Ratio {err.avg:.3f}; Average Loss {loss.avg:.3f}' 310 | .format(err=error_ratio, loss=losses)) 311 | 312 | if logger is not None: 313 | logger.log_value('test_epoch_loss', losses.avg) 314 | logger.log_value('test_epoch_error_ratio', error_ratio.avg) 315 | 316 | return error_ratio.avg 317 | 318 | 319 | if __name__ == '__main__': 320 | main() 321 | -------------------------------------------------------------------------------- /models/MPNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | from MessageFunction import MessageFunction 5 | from UpdateFunction import UpdateFunction 6 | from ReadoutFunction import ReadoutFunction 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | __author__ = "Pau Riba, Anjan Dutta" 13 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 14 | 15 | 16 | class MPNN(nn.Module): 17 | """ 18 | MPNN as proposed by Gilmer et al.. 19 | 20 | This class implements the whole Gilmer et al. model following the functions Message, Update and Readout. 21 | 22 | Parameters 23 | ---------- 24 | in_n : int list 25 | Sizes for the node and edge features. 26 | hidden_state_size : int 27 | Size of the hidden states (the input will be padded with 0's to this size). 28 | message_size : int 29 | Message function output vector size. 30 | n_layers : int 31 | Number of iterations Message+Update (weight tying). 32 | l_target : int 33 | Size of the output. 34 | type : str (Optional) 35 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 
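Example
-------
Illustrative construction and call; all sizes are assumptions (main.py derives the real ones from the data):

>>> model = MPNN(in_n=[15, 5], hidden_state_size=73, message_size=73, n_layers=3, l_target=12)
>>> res = model(g, h, e)  # g: (batch, n, n) adjacency, h: (batch, n, 15) node features, e: (batch, n, n, 5) edge features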
36 | """ 37 | 38 | def __init__(self, in_n, hidden_state_size, message_size, n_layers, l_target, type='regression'): 39 | super(MPNN, self).__init__() 40 | 41 | # Define message 42 | self.m = nn.ModuleList( 43 | [MessageFunction('mpnn', args={'edge_feat': in_n[1], 'in': hidden_state_size, 'out': message_size})]) 44 | 45 | # Define Update 46 | self.u = nn.ModuleList([UpdateFunction('mpnn', 47 | args={'in_m': message_size, 48 | 'out': hidden_state_size})]) 49 | 50 | # Define Readout 51 | self.r = ReadoutFunction('mpnn', 52 | args={'in': hidden_state_size, 53 | 'target': l_target}) 54 | 55 | self.type = type 56 | 57 | self.args = {} 58 | self.args['out'] = hidden_state_size 59 | 60 | self.n_layers = n_layers 61 | 62 | def forward(self, g, h_in, e): 63 | 64 | h = [] 65 | 66 | # Padding to some larger dimension d 67 | h_t = torch.cat([h_in, Variable( 68 | torch.zeros(h_in.size(0), h_in.size(1), self.args['out'] - h_in.size(2)).type_as(h_in.data))], 2) 69 | 70 | h.append(h_t.clone()) 71 | 72 | # Layer 73 | for t in range(0, self.n_layers): 74 | e_aux = e.view(-1, e.size(3)) 75 | 76 | h_aux = h[t].view(-1, h[t].size(2)) 77 | 78 | m = self.m[0].forward(h[t], h_aux, e_aux) 79 | m = m.view(h[0].size(0), h[0].size(1), -1, m.size(1)) 80 | 81 | # Nodes without edge set message to 0 82 | m = torch.unsqueeze(g, 3).expand_as(m) * m 83 | 84 | m = torch.squeeze(torch.sum(m, 1)) 85 | 86 | h_t = self.u[0].forward(h[t], m) 87 | 88 | # Delete virtual nodes 89 | h_t = (torch.sum(h_in, 2).expand_as(h_t) > 0).type_as(h_t) * h_t 90 | h.append(h_t) 91 | 92 | # Readout 93 | res = self.r.forward(h) 94 | 95 | if self.type == 'classification': 96 | res = nn.LogSoftmax()(res) 97 | return res -------------------------------------------------------------------------------- /models/MPNN_Duvenaud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | from MessageFunction import MessageFunction 5 | from UpdateFunction import UpdateFunction 6 | from ReadoutFunction import ReadoutFunction 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | __author__ = "Pau Riba, Anjan Dutta" 13 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 14 | 15 | 16 | class MpnnDuvenaud(nn.Module): 17 | """ 18 | MPNN as proposed by Duvenaud et al.. 19 | 20 | This class implements the whole Duvenaud et al. model following the functions proposed by Gilmer et al. as 21 | Message, Update and Readout. 22 | 23 | Parameters 24 | ---------- 25 | d : int list. 26 | Possible degrees for the input graph. 27 | in_n : int list 28 | Sizes for the node and edge features. 29 | out_update : int list 30 | Output sizes for the different Update functions. 31 | hidden_state_readout : int 32 | Input size for the neural net used inside the readout function. 33 | l_target : int 34 | Size of the output. 35 | type : str (Optional) 36 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 
37 | """ 38 | 39 | def __init__(self, d, in_n, out_update, hidden_state_readout, l_target, type='regression'): 40 | super(MpnnDuvenaud, self).__init__() 41 | 42 | n_layers = len(out_update) 43 | 44 | # Define message 1 & 2 45 | self.m = nn.ModuleList([MessageFunction('duvenaud') for _ in range(n_layers)]) 46 | 47 | # Define Update 1 & 2 48 | self.u = nn.ModuleList([UpdateFunction('duvenaud', args={'deg': d, 'in': self.m[i].get_out_size(in_n[0], in_n[1]), 'out': out_update[0]}) if i == 0 else 49 | UpdateFunction('duvenaud', args={'deg': d, 'in': self.m[i].get_out_size(out_update[i-1], in_n[1]), 'out': out_update[i]}) for i in range(n_layers)]) 50 | 51 | # Define Readout 52 | self.r = ReadoutFunction('duvenaud', 53 | args={'layers': len(self.m) + 1, 54 | 'in': [in_n[0] if i == 0 else out_update[i-1] for i in range(n_layers+1)], 55 | 'out': hidden_state_readout, 56 | 'target': l_target}) 57 | 58 | self.type = type 59 | 60 | def forward(self, g, h_in, e, plotter=None): 61 | 62 | h = [] 63 | h.append(h_in) 64 | 65 | # Layer 66 | for t in range(0, len(self.m)): 67 | 68 | u_args = self.u[t].get_args() 69 | 70 | h_t = Variable(torch.zeros(h_in.size(0), h_in.size(1), u_args['out']).type_as(h[t].data)) 71 | 72 | # Apply one layer pass (Message + Update) 73 | for v in range(0, h_in.size(1)): 74 | 75 | m = self.m[t].forward(h[t][:, v, :], h[t], e[:, v, :]) 76 | 77 | # Nodes without edge set message to 0 78 | m = g[:, v, :, None].expand_as(m) * m 79 | 80 | m = torch.sum(m, 1) 81 | 82 | # Duvenaud 83 | deg = torch.sum(g[:, v, :].data, 1) 84 | 85 | # Separate degrees 86 | for i in range(len(u_args['deg'])): 87 | ind = deg == u_args['deg'][i] 88 | ind = Variable(torch.squeeze(torch.nonzero(torch.squeeze(ind))), volatile=True) 89 | 90 | opt = {'deg': i} 91 | 92 | # Update 93 | if len(ind) != 0: 94 | aux = self.u[t].forward(torch.index_select(h[t], 0, ind)[:, v, :], torch.index_select(m, 0, ind), opt) 95 | 96 | ind = ind.data.cpu().numpy() 97 | for j in range(len(ind)): 98 | h_t[ind[j], v, :] = aux[j, :] 99 | 100 | if plotter is not None: 101 | num_feat = h_t.size(2) 102 | color = h_t[0,:,:].data.cpu().numpy() 103 | for i in range(num_feat): 104 | plotter(color[:, i], 'layer_' + str(t) + '_element_' + str(i) + '.png') 105 | 106 | h.append(h_t.clone()) 107 | # Readout 108 | res = self.r.forward(h) 109 | if self.type == 'classification': 110 | res = nn.LogSoftmax()(res) 111 | return res 112 | -------------------------------------------------------------------------------- /models/MPNN_GGNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | from MessageFunction import MessageFunction 5 | from UpdateFunction import UpdateFunction 6 | from ReadoutFunction import ReadoutFunction 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | __author__ = "Pau Riba, Anjan Dutta" 13 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 14 | 15 | 16 | class MpnnGGNN(nn.Module): 17 | """ 18 | MPNN as proposed by Li et al.. 19 | 20 | This class implements the whole Li et al. model following the functions proposed by Gilmer et al. as 21 | Message, Update and Readout. 22 | 23 | Parameters 24 | ---------- 25 | e : int list. 26 | Possible edge labels for the input graph. 27 | hidden_state_size : int 28 | Size of the hidden states (the input will be padded with 0's to this size). 29 | message_size : int 30 | Message function output vector size. 
31 | n_layers : int 32 | Number of iterations Message+Update (weight tying). 33 | l_target : int 34 | Size of the output. 35 | type : str (Optional) 36 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 37 | """ 38 | 39 | def __init__(self, e, hidden_state_size, message_size, n_layers, l_target, type='regression'): 40 | super(MpnnGGNN, self).__init__() 41 | 42 | # Define message 43 | self.m = nn.ModuleList([MessageFunction('ggnn', args={'e_label': e, 'in': hidden_state_size, 'out': message_size})]) 44 | 45 | # Define Update 46 | self.u = nn.ModuleList([UpdateFunction('ggnn', 47 | args={'in_m': message_size, 48 | 'out': hidden_state_size})]) 49 | 50 | # Define Readout 51 | self.r = ReadoutFunction('ggnn', 52 | args={'in': hidden_state_size, 53 | 'target': l_target}) 54 | 55 | self.type = type 56 | 57 | self.args = {} 58 | self.args['out'] = hidden_state_size 59 | 60 | self.n_layers = n_layers 61 | 62 | def forward(self, g, h_in, e): 63 | 64 | h = [] 65 | 66 | # Padding to some larger dimension d 67 | h_t = torch.cat([h_in, Variable(torch.Tensor(h_in.size(0), h_in.size(1), self.args['out'] - h_in.size(2)).type_as(h_in.data).zero_())], 2) 68 | 69 | h.append(h_t.clone()) 70 | 71 | # Layer 72 | for t in range(0, self.n_layers): 73 | 74 | h_t = Variable(torch.zeros(h[0].size(0), h[0].size(1), h[0].size(2)).type_as(h_in.data)) 75 | 76 | # Apply one layer pass (Message + Update) 77 | for v in range(0, h_in.size(1)): 78 | 79 | m = self.m[0].forward(h[t][:, v, :], h[t], e[:, v, :]) 80 | 81 | # Nodes without edge set message to 0 82 | m = g[:, v, :, None].expand_as(m) * m 83 | 84 | m = torch.sum(m, 1) 85 | 86 | # Update 87 | h_t[:, v, :] = self.u[0].forward(h[t][:, v, :], m) 88 | 89 | # Delete virtual nodes 90 | h_t = (torch.sum(torch.abs(h_in), 2).expand_as(h_t) > 0).type_as(h_t)*h_t 91 | h.append(h_t.clone()) 92 | 93 | # Readout 94 | res = self.r.forward(h) 95 | if self.type == 'classification': 96 | res = nn.LogSoftmax()(res) 97 | return res -------------------------------------------------------------------------------- /models/MPNN_IntNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | from MessageFunction import MessageFunction 5 | from UpdateFunction import UpdateFunction 6 | from ReadoutFunction import ReadoutFunction 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | __author__ = "Pau Riba, Anjan Dutta" 13 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 14 | 15 | 16 | class MpnnIntNet(nn.Module): 17 | """ 18 | MPNN as proposed by Battaglia et al.. 19 | 20 | This class implements the whole Battaglia et al. model following the functions proposed by Gilmer et al. as 21 | Message, Update and Readout. 22 | 23 | Parameters 24 | ---------- 25 | in_n : int list 26 | Sizes for the node and edge features. 27 | out_message : int list 28 | Output sizes for the different Message functions. 29 | out_update : int list 30 | Output sizes for the different Update functions. 31 | l_target : int 32 | Size of the output. 33 | type : str (Optional) 34 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 
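Example
-------
Illustrative construction and call; all sizes are assumptions:

>>> model = MpnnIntNet(in_n=[15, 5], out_message=[20, 20], out_update=[30, 30], l_target=12)
>>> res = model(g, h_in, e)  # g: (batch, n, n) adjacency, h_in: (batch, n, 15) node features, e: (batch, n, n, 5) edge features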
35 | """ 36 | 37 | def __init__(self, in_n, out_message, out_update, l_target, type='regression'): 38 | super(MpnnIntNet, self).__init__() 39 | 40 | n_layers = len(out_update) 41 | 42 | # Define message 1 & 2 43 | self.m = nn.ModuleList([MessageFunction('intnet', args={'in': 2*in_n[0] + in_n[1], 'out': out_message[i]}) 44 | if i == 0 else 45 | MessageFunction('intnet', args={'in': 2*out_update[i-1] + in_n[1], 'out': out_message[i]}) 46 | for i in range(n_layers)]) 47 | 48 | # Define Update 1 & 2 49 | self.u = nn.ModuleList([UpdateFunction('intnet', args={'in': in_n[0]+out_message[i], 'out': out_update[i]}) 50 | if i == 0 else 51 | UpdateFunction('intnet', args={'in': out_update[i-1]+out_message[i], 'out': out_update[i]}) 52 | for i in range(n_layers)]) 53 | 54 | # Define Readout 55 | self.r = ReadoutFunction('intnet', args={'in': out_update[-1], 'target': l_target}) 56 | 57 | self.type = type 58 | 59 | def forward(self, g, h_in, e): 60 | 61 | h = [] 62 | h.append(h_in) 63 | 64 | # Layer 65 | for t in range(0, len(self.m)): 66 | 67 | u_args = self.u[t].get_args() 68 | h_t = Variable(torch.zeros(h_in.size(0), h_in.size(1), u_args['out']).type_as(h[t].data)) 69 | 70 | # Apply one layer pass (Message + Update) 71 | for v in range(0, h_in.size(1)): 72 | 73 | m = self.m[t].forward(h[t][:, v, :], h[t], e[:, v, :, :]) 74 | 75 | # Nodes without edge set message to 0 76 | m = g[:, v, :,None].expand_as(m) * m 77 | 78 | m = torch.sum(m, 1) 79 | 80 | # Interaction Net 81 | opt = {} 82 | opt['x_v'] = Variable(torch.Tensor([]).type_as(m.data)) 83 | 84 | h_t[:, v, :] = self.u[t].forward(h[t][:, v, :], m, opt) 85 | 86 | h.append(h_t.clone()) 87 | 88 | # Readout 89 | res = self.r.forward(h) 90 | if self.type == 'classification': 91 | res = nn.LogSoftmax()(res) 92 | return res 93 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | #Available MPNN models 2 | 3 | Some of the models available in the literature have been implemented as Message, Update and Readout functions. 4 | 5 | ## MpnnDuvenaud 6 | 7 | This class implements the whole Duvenaud et al. model following the functions proposed by Gilmer et al. as Message, Update and Readout. 8 | 9 | ``` 10 | Parameters 11 | ---------- 12 | d : int list. 13 | Possible degrees for the input graph. 14 | in_n : int list 15 | Sizes for the node and edge features. 16 | out_update : int list 17 | Output sizes for the different Update funtion. 18 | hidden_state_readout : int 19 | Input size for the neural net used inside the readout function. 20 | l_target : int 21 | Size of the output. 22 | type : str (Optional) 23 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 24 | ``` 25 | 26 | Definition: 27 | 28 | ``` 29 | model = MpnnDuvenaud(d, in_n, out_update, hidden_state_readout, l_target') 30 | ``` 31 | 32 | ## MpnnGGNN 33 | 34 | 35 | This class implements the whole Li et al. model following the functions proposed by Gilmer et al. as Message, Update and Readout. 36 | 37 | ``` 38 | Parameters 39 | ---------- 40 | e : int list. 41 | Possible edge labels for the input graph. 42 | hidden_state_size : int 43 | Size of the hidden states (the input will be padded with 0's to this size). 44 | message_size : int 45 | Message function output vector size. 46 | n_layers : int 47 | Number of iterations Message+Update (weight tying). 48 | l_target : int 49 | Size of the output. 
50 | type : str (Optional) 51 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 52 | ``` 53 | 54 | Definition: 55 | 56 | ``` 57 | model = MpnnGGNN(e, hidden_state_size, message_size, n_layers, l_target) 58 | ``` 59 |
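A minimal usage sketch (the edge-label list and all sizes below are illustrative assumptions, not values taken from the code):

```
edge_labels = [1, 2, 3, 4]  # possible edge labels in the dataset
model = MpnnGGNN(edge_labels, hidden_state_size=30, message_size=20, n_layers=2, l_target=1)
output = model(g, h_in, e)  # g: (batch, n, n) adjacency, h_in: (batch, n, d_node) node features, e: per-pair edge labels
```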
60 | ## MpnnIntNet 61 | 62 | This class implements the whole Battaglia et al. model following the functions proposed by Gilmer et al. as Message, Update and Readout. 63 | 64 | ``` 65 | Parameters 66 | ---------- 67 | in_n : int list 68 | Sizes for the node and edge features. 69 | out_message : int list 70 | Output sizes for the different Message functions. 71 | out_update : int list 72 | Output sizes for the different Update functions. 73 | l_target : int 74 | Size of the output. 75 | type : str (Optional) 76 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 77 | ``` 78 | 79 | Definition: 80 | 81 | ``` 82 | model = MpnnIntNet(in_n, out_message, out_update, l_target) 83 | ``` 84 | 85 | ## MPNN as proposed by Gilmer et al. 86 | 87 | This class implements the whole Gilmer et al. model following the functions Message, Update and Readout. 88 | 89 | In progress: 90 | 91 | * [x] Edge Network 92 | * [ ] Virtual Graph Elements 93 | * [ ] set2set Readout function 94 | * [ ] Multiple Towers 95 | 96 | ``` 97 | Parameters 98 | ---------- 99 | in_n : int list 100 | Sizes for the node and edge features. 101 | hidden_state_size : int 102 | Size of the hidden states (the input will be padded with 0's to this size). 103 | message_size : int 104 | Message function output vector size. 105 | n_layers : int 106 | Number of iterations Message+Update (weight tying). 107 | l_target : int 108 | Size of the output. 109 | type : str (Optional) 110 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 111 | ``` 112 | Definition: 113 | 114 | ``` 115 | model = MPNN(in_n, hidden_state_size, message_size, n_layers, l_target, type='regression') 116 | ``` 117 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/priba/nmp_qc/177db7ea738a7a91f1262ce954f9c7a4a2b98849/models/__init__.py -------------------------------------------------------------------------------- /models/nnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __author__ = "Pau Riba, Anjan Dutta" 8 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 9 | 10 | 11 | # class NNet(nn.Module): 12 | # 13 | # def __init__(self, n_in, n_out): 14 | # super(NNet, self).__init__() 15 | # 16 | # self.fc1 = nn.Linear(n_in, 120) 17 | # self.fc2 = nn.Linear(120, 84) 18 | # self.fc3 = nn.Linear(84, n_out) 19 | # 20 | # def forward(self, x): 21 | # 22 | # x = x.view(-1, self.num_flat_features(x)) 23 | # x = F.relu(self.fc1(x)) 24 | # x = F.relu(self.fc2(x)) 25 | # x = self.fc3(x) 26 | # return x 27 | # 28 | # def num_flat_features(self, x): 29 | # size = x.size()[1:] # all dimensions except the batch dimension 30 | # num_features = 1 31 | # for s in size: 32 | # num_features *= s 33 | # return num_features 34 | 35 | # small neural network with fully connected layers 36 | 37 | class NNet(nn.Module): 38 | 39 | def __init__(self, n_in, n_out, hlayers=(128, 256, 128)): 40 | super(NNet, self).__init__() 41 | self.n_hlayers = len(hlayers) 42 | self.fcs = nn.ModuleList([nn.Linear(n_in, hlayers[i]) if i == 0 else 43 | nn.Linear(hlayers[i-1], n_out) if i == self.n_hlayers else 44 | nn.Linear(hlayers[i-1], hlayers[i]) for i in range(self.n_hlayers+1)]) 45 | 46 | def forward(self, x): 47 | x = x.contiguous().view(-1, self.num_flat_features(x)) 48 | for i in range(self.n_hlayers): 49 | x = F.relu(self.fcs[i](x)) 50 | x = self.fcs[-1](x) 51 | return x 52 | 53 | def num_flat_features(self, x): 54 | size = x.size()[1:] # all dimensions except the batch dimension 55 | num_features = 1 56 | for s in size: 57 | num_features *= s 58 | return num_features 59 | 60 | # class NNetM(nn.Module): 61 | # 62 | # def __init__(self, n_in, n_out): 63 | # super(NNetM, self).__init__() 64 | # 65 | # self.fc1 = nn.Linear(n_in, 120) 66 | # self.fc2 = nn.Linear(120, 84) 67 | # self.fc3 = nn.Linear(84, n_out[0]*n_out[1]) 68 | # 69 | # def forward(self, x): 70 | # 71 | # x = F.relu(self.fc1(x)) 72 | # x = F.relu(self.fc2(x)) 73 | # x = self.fc3(x) 74 | # return x 75 | 76 | 77 | def main(): 78 | net = NNet(n_in=100, n_out=20) 79 | print(net) 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | wget 5 | networkx 6 | joblib 7 | tensorboard 8 | -------------------------------------------------------------------------------- /visualization/Plotter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Plotter.py: Plots graphs with the NetworkX library and saves the figures to disk.
6 | 7 | Usage: 8 | 9 | """ 10 | 11 | from __future__ import print_function 12 | 13 | import networkx as nx 14 | import matplotlib 15 | matplotlib.use("Agg") 16 | import matplotlib.pyplot as plt 17 | import matplotlib.colors as mcol 18 | import matplotlib.cm as cm 19 | import os 20 | import warnings 21 | 22 | 23 | __author__ = "Pau Riba, Anjan Dutta" 24 | __email__ = "priba@cvc.uab.cat, adutta@cvc.uab.cat" 25 | 26 | 27 | """ 28 | Plots a Graph with the library networkx 29 | """ 30 | 31 | class Plotter(): 32 | # Constructor 33 | def __init__(self, plot_dir = './'): 34 | self.plotdir = plot_dir 35 | 36 | if os.path.isdir(plot_dir): 37 | # clean previous logged data under the same directory name 38 | self._remove(plot_dir) 39 | 40 | os.makedirs(plot_dir) 41 | 42 | 43 | @staticmethod 44 | def _remove(path): 45 | """ param could either be relative or absolute. """ 46 | if os.path.isfile(path): 47 | os.remove(path) # remove the file 48 | elif os.path.isdir(path): 49 | import shutil 50 | shutil.rmtree(path) # remove dir and all contains 51 | 52 | def plot_graph(self, am, position=None, cls=None, fig_name='graph.png'): 53 | 54 | with warnings.catch_warnings(): 55 | warnings.filterwarnings("ignore") 56 | 57 | g = nx.from_numpy_matrix(am) 58 | 59 | if position is None: 60 | position=nx.drawing.circular_layout(g) 61 | 62 | fig = plt.figure() 63 | 64 | if cls is None: 65 | cls='r' 66 | else: 67 | # Make a user-defined colormap. 68 | cm1 = mcol.LinearSegmentedColormap.from_list("MyCmapName", ["r", "b"]) 69 | 70 | # Make a normalizer that will map the time values from 71 | # [start_time,end_time+1] -> [0,1]. 72 | cnorm = mcol.Normalize(vmin=0, vmax=1) 73 | 74 | # Turn these into an object that can be used to map time values to colors and 75 | # can be passed to plt.colorbar(). 76 | cpick = cm.ScalarMappable(norm=cnorm, cmap=cm1) 77 | cpick.set_array([]) 78 | cls = cpick.to_rgba(cls) 79 | plt.colorbar(cpick, ax=fig.add_subplot(111)) 80 | 81 | 82 | nx.draw(g, pos=position, node_color=cls, ax=fig.add_subplot(111)) 83 | 84 | fig.savefig(os.path.join(self.plotdir, fig_name)) 85 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/priba/nmp_qc/177db7ea738a7a91f1262ce954f9c7a4a2b98849/visualization/__init__.py --------------------------------------------------------------------------------
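A minimal usage sketch for the Plotter class above (the adjacency matrix, colour values and output directory are illustrative assumptions):

```
import numpy as np
from visualization.Plotter import Plotter

am = np.array([[0., 1., 0.],
               [1., 0., 1.],
               [0., 1., 0.]])  # adjacency matrix of a 3-node path graph

plotter = Plotter(plot_dir='./plot/example/')
plotter.plot_graph(am, fig_name='path.png')  # plain red nodes
plotter.plot_graph(am, cls=np.array([0.0, 0.5, 1.0]), fig_name='path_colored.png')  # red-to-blue colour map
```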