├── .gitignore ├── ENVIRONMENT_VARIABLES.md ├── LICENSE ├── README.md ├── data ├── __init__.py ├── protein-folding │ ├── 1j52_adjacency-matrix.pickle │ ├── 1j52_features.pickle │ ├── 1j52_labels.pickle │ ├── 1j53_adjacency-matrix.pickle │ ├── 1j53_features.pickle │ ├── 1j53_labels.pickle │ ├── 1j54_adjacency-matrix.pickle │ ├── 1j54_features.pickle │ ├── 1j54_labels.pickle │ ├── 1j55_adjacency-matrix.pickle │ ├── 1j55_features.pickle │ ├── 1j55_labels.pickle │ ├── 1j56_adjacency-matrix.pickle │ ├── 1j56_features.pickle │ ├── 1j56_labels.pickle │ ├── 1j57_adjacency-matrix.pickle │ ├── 1j57_features.pickle │ ├── 1j57_labels.pickle │ ├── 1j58_adjacency-matrix.pickle │ ├── 1j58_features.pickle │ ├── 1j58_labels.pickle │ ├── 1j59_adjacency-matrix.pickle │ ├── 1j59_features.pickle │ ├── 1j59_labels.pickle │ ├── 1j5a_adjacency-matrix.pickle │ ├── 1j5a_features.pickle │ ├── 1j5a_labels.pickle │ ├── 1j5d_adjacency-matrix.pickle │ ├── 1j5d_features.pickle │ ├── 1j5d_labels.pickle │ ├── 1j5j_adjacency-matrix.pickle │ ├── 1j5j_features.pickle │ ├── 1j5j_labels.pickle │ ├── 1j5k_adjacency-matrix.pickle │ ├── 1j5k_features.pickle │ ├── 1j5k_labels.pickle │ ├── 1j5l_adjacency-matrix.pickle │ ├── 1j5l_features.pickle │ ├── 1j5l_labels.pickle │ ├── 1j5m_adjacency-matrix.pickle │ ├── 1j5m_features.pickle │ ├── 1j5m_labels.pickle │ ├── 1j5p_adjacency-matrix.pickle │ ├── 1j5p_features.pickle │ ├── 1j5p_labels.pickle │ ├── 1j5t_adjacency-matrix.pickle │ ├── 1j5t_features.pickle │ ├── 1j5t_labels.pickle │ ├── 1j5u_adjacency-matrix.pickle │ ├── 1j5u_features.pickle │ ├── 1j5u_labels.pickle │ ├── 1j5x_adjacency-matrix.pickle │ ├── 1j5x_features.pickle │ ├── 1j5x_labels.pickle │ ├── 1j5y_adjacency-matrix.pickle │ ├── 1j5y_features.pickle │ ├── 1j5y_labels.pickle │ ├── 2j51_adjacency-matrix.pickle │ ├── 2j51_features.pickle │ ├── 2j51_labels.pickle │ ├── 2j52_adjacency-matrix.pickle │ ├── 2j52_features.pickle │ ├── 2j52_labels.pickle │ ├── 2j53_adjacency-matrix.pickle │ ├── 
2j53_features.pickle │ ├── 2j53_labels.pickle │ ├── 2j5a_adjacency-matrix.pickle │ ├── 2j5a_features.pickle │ ├── 2j5a_labels.pickle │ ├── 2j5e_adjacency-matrix.pickle │ ├── 2j5e_features.pickle │ ├── 2j5e_labels.pickle │ ├── 2j5f_adjacency-matrix.pickle │ ├── 2j5f_features.pickle │ ├── 2j5f_labels.pickle │ ├── 2j5h_adjacency-matrix.pickle │ ├── 2j5h_features.pickle │ ├── 2j5h_labels.pickle │ ├── 2j5m_adjacency-matrix.pickle │ ├── 2j5m_features.pickle │ ├── 2j5m_labels.pickle │ ├── 2j5u_adjacency-matrix.pickle │ ├── 2j5u_features.pickle │ ├── 2j5u_labels.pickle │ ├── 2j5x_adjacency-matrix.pickle │ ├── 2j5x_features.pickle │ ├── 2j5x_labels.pickle │ ├── 2j5y_adjacency-matrix.pickle │ ├── 2j5y_features.pickle │ ├── 2j5y_labels.pickle │ ├── 3j5v_adjacency-matrix.pickle │ ├── 3j5v_features.pickle │ └── 3j5v_labels.pickle ├── sample-dataset │ ├── 0_training_adjacency-matrix.pickle │ ├── 0_training_features.pickle │ ├── 0_training_labels.pickle │ ├── 1_training_adjacency-matrix.pickle │ ├── 1_training_features.pickle │ ├── 1_training_labels.pickle │ ├── 2_training_adjacency-matrix.pickle │ ├── 2_training_features.pickle │ ├── 2_training_labels.pickle │ ├── 3_training_adjacency-matrix.pickle │ ├── 3_training_features.pickle │ ├── 3_training_labels.pickle │ ├── 4_training_adjacency-matrix.pickle │ ├── 4_training_features.pickle │ ├── 4_training_labels.pickle │ ├── 5_training_adjacency-matrix.pickle │ ├── 5_training_features.pickle │ ├── 5_training_labels.pickle │ └── __init__.py └── test_training_dataset.py ├── environment.yml ├── example_notebooks ├── Grid search.ipynb └── Train a single configuration.ipynb ├── forward_pass_equations.png ├── grid-search.sh ├── inference.sh ├── linux_build.sh ├── linux_build_cpu.sh ├── macos_build.sh ├── macos_build_cpu.sh ├── message_passing_nn ├── __init__.py ├── cli.py ├── create_message_passing_nn.py ├── data │ ├── __init__.py │ ├── data_preprocessor.py │ └── preprocessor.py ├── fixtures │ ├── __init__.py │ ├── characters.py │ └── 
filenames.py ├── graph │ ├── __init__.py │ ├── rnn_encoder.cpp │ ├── rnn_encoder.py │ ├── rnn_encoder_cuda.cpp │ └── rnn_encoder_cuda_kernel.cu ├── infrastructure │ ├── __init__.py │ ├── file_system_repository.py │ └── graph_dataset.py ├── model │ ├── __init__.py │ ├── inferencer.py │ ├── loader.py │ └── trainer.py ├── usecase │ ├── __init__.py │ ├── grid_search.py │ ├── inference.py │ └── usecase.py └── utils │ ├── __init__.py │ ├── derivatives.cpp │ ├── derivatives.h │ ├── grid_search_parameters_parser.py │ ├── loss_function_selector.py │ ├── loss_functions.py │ ├── messages.cpp │ ├── messages.h │ ├── model_selector.py │ ├── models.py │ ├── optimizer_selector.py │ ├── optimizers.py │ └── saver.py ├── parameters ├── __init__.py ├── grid-search-parameters.sh └── inference-parameters.sh ├── requirements.txt ├── setup.cfg ├── setup.py ├── setup_cpu.py └── tests ├── __init__.py ├── data ├── __init__.py └── test_data_preprocessor.py ├── fixtures ├── __init__.py ├── loss_functions.py ├── matrices_and_vectors.py └── optimizers.py ├── graph ├── __init__.py └── test_rnn_encoder.py ├── infrastructure ├── __init__.py └── test_file_system_repository.py ├── model ├── __init__.py ├── test_inferencer.py ├── test_loader.py └── test_trainer.py ├── test_data ├── __init__.py ├── model-checkpoints-test │ ├── __init__.py │ └── configuration&id__model&RNN__epochs&10__loss_function&MSE__optimizer&Adagrad__batch_size&100__validation_split&0.2__test_split&0.1__time_steps&1__validation_period&5 │ │ └── Epoch_5_model_state_dictionary.pth ├── repo-test-data │ └── __init__.py └── training-test-data │ └── __init__.py ├── usecase ├── __init__.py ├── test_grid_search.py └── test_inference.py └── utils ├── __init__.py └── test_grid_search_parameters_parser.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | /.tox/ 4 | /graph_to_graph.egg-info/ 5 | */__pycache__/ 6 | */*/__pycache__/ 7 | SUCCESS 8 | /dist/ 9 | *.egg-info/ 10 
| .pickle -------------------------------------------------------------------------------- /ENVIRONMENT_VARIABLES.md: -------------------------------------------------------------------------------- 1 | **GENERAL PARAMETERS** 2 | 3 | - Your dataset folder is defined by: 4 | 5 | DATASET_NAME='sample-dataset' 6 | 7 | - Your dataset directory is defined by: 8 | 9 | DATA_DIRECTORY='data/' 10 | 11 | - The directory to save the model checkpoints is defined by: 12 | 13 | MODEL_DIRECTORY='model_checkpoints' 14 | 15 | - The directory to save the grid search results per configuration is defined by: 16 | 17 | RESULTS_DIRECTORY='grid_search_results' 18 | 19 | - The option to run the model on 'cpu' or 'cuda' can be controlled by: 20 | 21 | DEVICE='cpu' 22 | 23 | **USED FOR GRID SEARCH** 24 | 25 | To define a range for the grid search please pass the values in the following format: 26 | 1. For numeric ranges: ENVVAR='min_value&max_value&number_of_values' (e.g. '10&15&2') 27 | 2. For string ranges: ENVVAR='selection_1&selection_2' (e.g. 
'SGD&Adam') 28 | 29 | - The model to use (only 'RNN' is available in this version) is defined by: 30 | 31 | MODEL='RNN' 32 | 33 | - The total number of epochs can be controlled by: 34 | 35 | EPOCHS='10' 36 | 37 | - The choice of the loss function can be controlled by (see message_passing_nn/utils/loss_functions.py for a full list): 38 | 39 | LOSS_FUNCTION='MSE' 40 | 41 | - The choice of the optimizer can be controlled by (see message_passing_nn/utils/optimizers.py for a full list): 42 | 43 | OPTIMIZER='SGD' 44 | 45 | - The batch size can be controlled by: 46 | 47 | BATCH_SIZE='1' 48 | 49 | - The validation split can be controlled by: 50 | 51 | VALIDATION_SPLIT='0.2' 52 | 53 | - The test split can be controlled by: 54 | 55 | TEST_SPLIT='0.1' 56 | 57 | - The number of message passing time steps can be controlled by: 58 | 59 | TIME_STEPS='5' 60 | 61 | - The interval, in epochs, between evaluations of the model on the validation set can be controlled by: 62 | 63 | VALIDATION_PERIOD='5' 64 | 65 | **USED FOR INFERENCE** 66 | 67 | - The model to load (only 'RNN' is available in this version) is defined by: 68 | 69 | MODEL='RNN' -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Michail Kovanis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Table of contents 2 | - [1. Description](#1-description) 3 | - [2. Model architecture](#2-model-architecture) 4 | - [3. Build and use](#3-build-and-use) 5 | - [4. Import as package](#4-import-as-package) 6 | - [5. Examples](#5-examples) 7 | - [6. Requirements](#6-requirements) 8 | - [7. Environment](#7-environment) 9 | - [8. Dataset](#8-dataset) 10 | - [9. Environment variables](#9-environment-variables) 11 | - [10. Execute a grid search](#10-execute-a-grid-search) 12 | - [11. Execute an inference](#11-execute-an-inference) 13 | 14 | 15 | ### 1. Description 16 | 17 | This repository contains: 18 | 1. A PyTorch C++ implementation of a message passing neural network with RNN units (inspired by https://arxiv.org/abs/1812.01070). 19 | 2. A Python wrapper around the model to perform a grid search and save model checkpoints at each validation step. 20 | 3. A script to perform an inference on a dataset based on a specific model checkpoint. 21 | 4. A custom CUDA kernel. 22 | 23 | ### 2. Model Architecture 24 | 25 | ![Model Architecture](./forward_pass_equations.png) 26 | 27 | ### 3. Build and use 28 | 29 | To use the current version (master or tag >= 1.5.0) you first need to build the project.
Please clone the repository and then run the build script that matches your OS and whether you have a CUDA-enabled GPU available. To build for GPU you first need to set the CUDA_HOME variable in the respective .sh file (here it defaults to /usr/local/cuda). 30 | 31 | Linux (CPU & GPU) 32 | ``` 33 | . linux_build.sh 34 | ``` 35 | 36 | Linux (CPU only) 37 | ``` 38 | . linux_build_cpu.sh 39 | ``` 40 | 41 | macOS (CPU & GPU) 42 | ``` 43 | . macos_build.sh 44 | ``` 45 | 46 | macOS (CPU only) 47 | ``` 48 | . macos_build_cpu.sh 49 | ``` 50 | 51 | Then you can use the code as in the [examples](#5-examples) or perform a [grid search](#10-execute-a-grid-search). 52 | 53 | ### 4. Import as package 54 | If you can't build the project, you can install the pure Python version of the project (version 1.4.2) using pip: 55 | 56 | ``` 57 | pip install message-passing-nn 58 | ``` 59 | 60 | ### 5. Examples 61 | 62 | The code can be used either to train a single configuration of the message passing neural network or to perform a grid search. For usage examples (v1.4.2) please look in the example_notebooks/ directory or at the [colab notebook](https://colab.research.google.com/drive/1jFJ7l7jIv22BhvvzlmXOWFtgBE15ea2X). 63 | 64 | ### 6. Requirements 65 | 66 | Python 3.7.6 67 | 68 | Run 69 | ``` 70 | click 71 | torch==1.5.0 72 | numpy==1.17.4 73 | pandas==1.0.3 74 | tqdm 75 | ``` 76 | 77 | Tests 78 | ``` 79 | numpy==1.17.4 80 | torch==1.5.0 81 | pandas==1.0.3 82 | ``` 83 | 84 | ### 7. Environment 85 | To create the "message-passing-neural-network" conda environment please run: 86 | 87 | ``` 88 | conda env create -f environment.yml 89 | ``` 90 | 91 | ### 8.
Dataset 92 | 93 | The repository expects the data to be in the following format: 94 | 95 | - filenames: something_features.pickle, something_adjacency-matrix.pickle & something_labels.pickle 96 | - features: torch.Size([M,N]) 97 | - adjacency-matrix: torch.Size([M,M]) 98 | - labels: torch.Size([L]) 99 | 100 | **All features and labels should be preprocessed to be of the same size** 101 | 102 | For example, in the protein-folding dataset: 103 | 104 | - M: represents the number of amino acids 105 | - N: represents the number of protein features 106 | - L: represents the number of values to predict 107 | 108 | This repository contains two dataset folders with example data for running the code: 109 | 110 | - sample-dataset: Contains a few small sets of features/adjacency-matrix/labels (0_training to 5_training) with some default values. This data lets you run the code in demo mode. 111 | - protein-folding: Contains sets of features/adjacency-matrix/labels for various proteins (prepared using https://github.com/simonholmes001/structure_prediction). The features represent protein characteristics, and the labels the distances between all amino acids. 112 | 113 | ### 9. Environment variables 114 | 115 | The model and grid search can be set up using a set of environment variables contained in parameters/grid-search-parameters.sh. Please refer to ENVIRONMENT_VARIABLES.md for the full list of available environment variables and how to use them. 116 | 117 | ### 10. Execute a grid search 118 | 119 | Before executing a grid search, please edit grid-search.sh to add your PYTHONPATH=path/to/message-passing-nn/. 120 | 121 | The grid search can be executed by running a shell script: 122 | ``` 123 | . grid-search.sh 124 | ``` 125 | 126 | This script will: 127 | 128 | 1. Create the conda environment from the environment.yml (if not created already) 129 | 2. Activate it 130 | 3. Export the PYTHONPATH=path/to/message-passing-nn/ (line needs to be uncommented first) 131 | 4.
Export the environment variables to be used for the Grid Search 132 | 5. Run the grid search 133 | 6. Save model checkpoints at each validation step and a CSV file containing all calculated losses 134 | 135 | ### 11. Execute an inference 136 | 137 | Before executing an inference, please edit inference.sh to add your PYTHONPATH=path/to/message-passing-nn/. Please also make sure that the dataset used for inference has the same dimensions (M, N, L) as the one used to train the model. 138 | 139 | The inference can be executed by running a shell script: 140 | ``` 141 | . inference.sh 142 | ``` 143 | 144 | This script will: 145 | 146 | 1. Create the conda environment from the environment.yml (if not created already) 147 | 2. Activate it 148 | 3. Export the PYTHONPATH=path/to/message-passing-nn/ (line needs to be uncommented first) 149 | 4. Export the environment variables to be used for the Inference 150 | 5. Run the inference 151 | 6. Save results as a list of (output, label, tag) for each input 152 | 153 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/__init__.py -------------------------------------------------------------------------------- /data/protein-folding/1j52_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j52_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j52_features.pickle: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j52_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j52_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j52_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j53_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j53_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j53_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j53_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j53_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j53_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j54_adjacency-matrix.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j54_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j54_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j54_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j54_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j54_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j55_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j55_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j55_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j55_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j55_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j55_labels.pickle 
-------------------------------------------------------------------------------- /data/protein-folding/1j56_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j56_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j56_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j56_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j56_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j56_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j57_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j57_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j57_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j57_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j57_labels.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j57_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j58_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j58_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j58_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j58_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j58_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j58_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j59_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j59_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j59_features.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j59_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j59_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j59_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5a_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5a_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5a_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5a_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5a_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5a_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5d_adjacency-matrix.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5d_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5d_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5d_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5d_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5d_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5j_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5j_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5j_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5j_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5j_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5j_labels.pickle 
-------------------------------------------------------------------------------- /data/protein-folding/1j5k_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5k_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5k_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5k_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5k_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5k_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5l_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5l_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5l_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5l_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5l_labels.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5l_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5m_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5m_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5m_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5m_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5m_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5m_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5p_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5p_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5p_features.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5p_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5p_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5p_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5t_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5t_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5t_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5t_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5t_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5t_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5u_adjacency-matrix.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5u_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5u_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5u_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5u_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5u_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5x_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5x_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5x_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5x_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5x_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5x_labels.pickle 
-------------------------------------------------------------------------------- /data/protein-folding/1j5y_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5y_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5y_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5y_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/1j5y_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/1j5y_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j51_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j51_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j51_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j51_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j51_labels.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j51_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j52_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j52_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j52_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j52_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j52_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j52_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j53_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j53_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j53_features.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j53_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j53_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j53_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5a_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5a_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5a_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5a_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5a_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5a_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5e_adjacency-matrix.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5e_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5e_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5e_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5e_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5e_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5f_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5f_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5f_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5f_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5f_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5f_labels.pickle 
-------------------------------------------------------------------------------- /data/protein-folding/2j5h_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5h_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5h_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5h_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5h_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5h_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5m_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5m_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5m_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5m_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5m_labels.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5m_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5u_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5u_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5u_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5u_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5u_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5u_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5x_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5x_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5x_features.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5x_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5x_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5x_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5y_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5y_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5y_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5y_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/2j5y_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/2j5y_labels.pickle -------------------------------------------------------------------------------- /data/protein-folding/3j5v_adjacency-matrix.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/3j5v_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/protein-folding/3j5v_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/3j5v_features.pickle -------------------------------------------------------------------------------- /data/protein-folding/3j5v_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/protein-folding/3j5v_labels.pickle -------------------------------------------------------------------------------- /data/sample-dataset/0_training_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/0_training_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/sample-dataset/0_training_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/0_training_features.pickle -------------------------------------------------------------------------------- /data/sample-dataset/0_training_labels.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/0_training_labels.pickle -------------------------------------------------------------------------------- /data/sample-dataset/1_training_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/1_training_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/sample-dataset/1_training_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/1_training_features.pickle -------------------------------------------------------------------------------- /data/sample-dataset/1_training_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/1_training_labels.pickle -------------------------------------------------------------------------------- /data/sample-dataset/2_training_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/2_training_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/sample-dataset/2_training_features.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/2_training_features.pickle -------------------------------------------------------------------------------- /data/sample-dataset/2_training_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/2_training_labels.pickle -------------------------------------------------------------------------------- /data/sample-dataset/3_training_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/3_training_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/sample-dataset/3_training_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/3_training_features.pickle -------------------------------------------------------------------------------- /data/sample-dataset/3_training_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/3_training_labels.pickle -------------------------------------------------------------------------------- /data/sample-dataset/4_training_adjacency-matrix.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/4_training_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/sample-dataset/4_training_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/4_training_features.pickle -------------------------------------------------------------------------------- /data/sample-dataset/4_training_labels.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/4_training_labels.pickle -------------------------------------------------------------------------------- /data/sample-dataset/5_training_adjacency-matrix.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/5_training_adjacency-matrix.pickle -------------------------------------------------------------------------------- /data/sample-dataset/5_training_features.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/5_training_features.pickle -------------------------------------------------------------------------------- /data/sample-dataset/5_training_labels.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/5_training_labels.pickle -------------------------------------------------------------------------------- /data/sample-dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/data/sample-dataset/__init__.py -------------------------------------------------------------------------------- /data/test_training_dataset.py: -------------------------------------------------------------------------------- 1 | import torch as to 2 | 3 | BASE_GRAPH = to.tensor([[0, 1, 1, 0], 4 | [1, 0, 1, 0], 5 | [1, 1, 0, 1], 6 | [0, 0, 1, 0]]) 7 | BASE_GRAPH_NODE_FEATURES = to.tensor([[1, 2], [1, 1], [2, 0.5], [0.5, 0.5]]) 8 | BASE_GRAPH_EDGE_FEATURES = to.tensor([[[0.0, 0.0], [1.0, 2.0], [2.0, 0.5], [0.0, 0.0]], 9 | [[1.0, 2.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0]], 10 | [[2.0, 0.5], [1.0, 1.0], [0.0, 0.0], [0.5, 0.5]], 11 | [[0.0, 0.0], [0.0, 0.0], [0.5, 0.5], [0.0, 0.0]]]) 12 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: message-passing-neural-network 2 | channels: 3 | - pytorch 4 | - defaults 5 | - conda-forge 6 | dependencies: 7 | - python=3.7.6 8 | - tox=3.14.3 9 | - numpy=1.17.4 10 | - pandas=1.0.3 11 | - pytorch=1.5.0 12 | - setuptools=46.0.0 13 | - click 14 | - tqdm -------------------------------------------------------------------------------- /example_notebooks/Grid search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Install requirements\n", 8 | "Uncomment and install the requirements" 9 | ] 10 | }, 
11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "# !pip install -U message-passing-nn" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Clone the repository to get the data folders" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# !git clone https://github.com/kovanostra/message-passing-nn/" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Imports" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "from message_passing_nn.create_message_passing_nn import create" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Initialize a grid search" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "message_passing_nn = create(dataset_name='the-name-of-the-directory-containing-your-data', #e.g. 'sample-dataset/'\n", 66 | " data_directory='the-path-to-the-directory-containing-all-your-datasets', #e.g. 
'~/message-passing-nn/data/'\n", 67 | " model_directory='model_checkpoints',\n", 68 | " results_directory='grid_search_results',\n", 69 | " model='RNN',\n", 70 | " device='cpu',\n", 71 | " epochs='10&15&2', # This will create a linspace from 10 to 15 with 2 values\n", 72 | " loss_function_selection='MSE',\n", 73 | " optimizer_selection='SGD',\n", 74 | " batch_size='1',\n", 75 | " maximum_number_of_features='-1',\n", 76 | " maximum_number_of_nodes='-1',\n", 77 | " validation_split='0.2&0.3&2', # This will create a linspace from 0.2 to 0.3 with 2 values\n", 78 | " test_split='0.1',\n", 79 | " time_steps='2&5&2',\n", 80 | " validation_period='5&15&3')" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Start a grid search" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "message_passing_nn.start()" 97 | ] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "Python 3", 103 | "language": "python", 104 | "name": "python3" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.7.4" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 2 121 | } 122 | -------------------------------------------------------------------------------- /example_notebooks/Train a single configuration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Install requirements\n", 8 | "Uncomment and install the requirements" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": 
[ 17 | "# !pip install -U message-passing-nn" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Clone the repository to get the data folders" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# !git clone https://github.com/kovanostra/message-passing-nn/" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Imports" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import torch\n", 50 | "import datetime\n", 51 | "from message_passing_nn.model.model_trainer import ModelTrainer\n", 52 | "from message_passing_nn.graph.graph_rnn_encoder import GraphRNNEncoder\n", 53 | "from message_passing_nn.graph.graph_gru_encoder import GraphGRUEncoder\n", 54 | "from message_passing_nn.data.data_preprocessor import DataPreprocessor\n", 55 | "from message_passing_nn.infrastructure.file_system_repository import FileSystemRepository" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Set up the variables " 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "device = \"cpu\" # You can use \"cuda\" for RNNEncoder, but it is currently advised to use \"cpu\" for the GRUEncoder\n", 72 | "epochs = 10\n", 73 | "model = 'RNN'\n", 74 | "loss_function = 'MSE'\n", 75 | "optimizer = 'SGD'\n", 76 | "batch_size = 5\n", 77 | "maximum_number_of_nodes = 250 # Some of the adjacency matrices in our dataset are too big, this variable controls the maximum size of the matrices to load. 
To load the whole dataset set this value to -1.\n", 78 | "maximum_number_of_features = 10 # Similarly for the number of features\n", 79 | "validation_split = 0.2\n", 80 | "test_split = 0.1\n", 81 | "time_steps = 1 # The time steps of the message passing algorithm\n", 82 | "validation_period = 20\n", 83 | "\n", 84 | "configuration_dictionary = {'time_steps': time_steps,\n", 85 | " 'model': model,\n", 86 | " 'loss_function': loss_function,\n", 87 | " 'optimizer': optimizer}" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## Preprocess the dataset\n", 95 | "We load the protein-folding dataset in which each graph contains three pickle files:\n", 96 | " 1. The features of each node (as torch.tensor.Size([M,N]))\n", 97 | " 2. The adjacency matrix (as torch.tensor.Size([M,M]))\n", 98 | " 3. The labels to predict (as torch.tensor.Size([L]))\n", 99 | "\n", 100 | "where M is the number of graph nodes, N the number of features per node, and L the number of values to predict.\n", 101 | "\n", 102 | "The dataset contains features and labels from 31 proteins from (https://www.rcsb.org). We apply a limit to the size of the proteins (so as not to crash the runtime), so we end up with 17 proteins which we equalize in size and split into training, validation and test datasets." 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "dataset_name = 'protein-folding'\n", 112 | "data_directory = 'message-passing-nn/data/'\n", 113 | "file_system_repository = FileSystemRepository(data_directory, dataset_name)\n", 114 | "raw_dataset = file_system_repository.get_all_data()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "Please uncomment the following block to see examples of the data used as input to the model." 
122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# node_features_example, adjacency_matrix_example, labels_example = raw_dataset[0]\n", 131 | "# print(node_features_example.size(), adjacency_matrix_example.size(), labels_example.size())" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "### Next we equalize the tensor sizes and split to train, validation and test sets" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "data_preprocessor = DataPreprocessor()\n", 148 | "equalized_dataset = data_preprocessor.equalize_dataset_dimensions(raw_dataset,\n", 149 | " maximum_number_of_nodes,\n", 150 | " maximum_number_of_features)\n", 151 | "training_data, validation_data, test_data = data_preprocessor.train_validation_test_split(equalized_dataset, \n", 152 | " batch_size, \n", 153 | " validation_split, \n", 154 | " test_split)\n", 155 | "data_dimensions = data_preprocessor.extract_data_dimensions(equalized_dataset)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "## Instantiate the model and the trainer\n", 163 | "\n", 164 | "The Trainer is responsible for the instantiation, training and evaluation of the model. It also controls whether a mini-batch normalization over the node features and labels should be applied. The ModelTrainer can use either the RnnEncoder or the GRUEncoder." 
165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "configuration_dictionary = {'time_steps': time_steps,\n", 174 | " 'model': model,\n", 175 | " 'loss_function': loss_function,\n", 176 | " 'optimizer': optimizer}\n", 177 | "model_trainer = ModelTrainer(data_preprocessor, device)\n", 178 | "model_trainer.instantiate_attributes(data_dimensions, configuration_dictionary)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "## Train the model\n", 186 | "This block will train the model and output the training, validation and test losses. Our use case contains fully connected graphs and therefore training takes significantly longer than for sparsely connected graphs." 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "for epoch in range(epochs):\n", 196 | " training_loss = model_trainer.do_train(training_data, epoch)\n", 197 | " print(\"Epoch\", epoch, \"Training loss:\", training_loss)\n", 198 | " if epoch % validation_period == 0:\n", 199 | " validation_loss = model_trainer.do_evaluate(validation_data, epoch)\n", 200 | " print(\"Epoch\", epoch, \"Validation loss:\", validation_loss)\n", 201 | "test_loss = model_trainer.do_evaluate(test_data)\n", 202 | "print(\"Test loss:\", test_loss)" 203 | ] 204 | } 205 | ], 206 | "metadata": { 207 | "kernelspec": { 208 | "display_name": "Python 3", 209 | "language": "python", 210 | "name": "python3" 211 | }, 212 | "language_info": { 213 | "codemirror_mode": { 214 | "name": "ipython", 215 | "version": 3 216 | }, 217 | "file_extension": ".py", 218 | "mimetype": "text/x-python", 219 | "name": "python", 220 | "nbconvert_exporter": "python", 221 | "pygments_lexer": "ipython3", 222 | "version": "3.7.4" 223 | } 224 | }, 225 | "nbformat": 4, 226 |
"nbformat_minor": 2 227 | } 228 | -------------------------------------------------------------------------------- /forward_pass_equations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/forward_pass_equations.png -------------------------------------------------------------------------------- /grid-search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda env create -f environment.yml 3 | eval "$(conda shell.bash hook)" && conda activate message-passing-neural-network 4 | #export PYTHONPATH=path/to/message-passing-neural-network/ 5 | . parameters/grid-search-parameters.sh 6 | python message_passing_nn/cli.py grid-search -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda env create -f environment.yml 3 | eval "$(conda shell.bash hook)" && conda activate message-passing-neural-network 4 | #export PYTHONPATH=path/to/message-passing-neural-network/ 5 | .
parameters/inference-parameters.sh 6 | python message_passing_nn/cli.py inference -------------------------------------------------------------------------------- /linux_build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_HOME=/usr/local/cuda 4 | export CC=gcc 5 | export CXX=g++ 6 | python setup.py install -------------------------------------------------------------------------------- /linux_build_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CC=gcc 4 | export CXX=g++ 5 | python setup_cpu.py install -------------------------------------------------------------------------------- /macos_build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_HOME=/usr/local/cuda 4 | export MACOSX_DEPLOYMENT_TARGET=10.11 5 | export CC=clang 6 | export CXX=clang++ 7 | python setup.py install 8 | -------------------------------------------------------------------------------- /macos_build_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export MACOSX_DEPLOYMENT_TARGET=10.11 4 | export CC=clang 5 | export CXX=clang++ 6 | python setup_cpu.py install 7 | -------------------------------------------------------------------------------- /message_passing_nn/__init__.py: -------------------------------------------------------------------------------- 1 | from message_passing_nn.create_message_passing_nn import create_grid_search 2 | from message_passing_nn.create_message_passing_nn import create_inference 3 | from message_passing_nn.cli import main 4 | -------------------------------------------------------------------------------- /message_passing_nn/cli.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import click 4 | import sys 5 | 6 | from 
message_passing_nn.create_message_passing_nn import create_grid_search, create_inference 7 | 8 | 9 | @click.group("message-passing-nn") 10 | @click.option('--debug', default=False, help='Set the logs to debug level', show_default=True, is_flag=True) 11 | def main(debug): 12 | log_level = logging.DEBUG if debug else logging.INFO 13 | setup_logging(log_level) 14 | 15 | 16 | @click.command('grid-search', help='Starts the grid search') 17 | @click.argument('dataset_name', envvar='DATASET_NAME', type=str) 18 | @click.argument('data_directory', envvar='DATA_DIRECTORY', type=str) 19 | @click.argument('model_directory', envvar='MODEL_DIRECTORY', type=str) 20 | @click.argument('results_directory', envvar='RESULTS_DIRECTORY', type=str) 21 | @click.argument('model', envvar='MODEL', type=str) 22 | @click.argument('device', envvar='DEVICE', type=str) 23 | @click.argument('epochs', envvar='EPOCHS', type=str) 24 | @click.argument('loss_function', envvar='LOSS_FUNCTION', type=str) 25 | @click.argument('optimizer', envvar='OPTIMIZER', type=str) 26 | @click.argument('batch_size', envvar='BATCH_SIZE', type=str) 27 | @click.argument('validation_split', envvar='VALIDATION_SPLIT', type=str) 28 | @click.argument('test_split', envvar='TEST_SPLIT', type=str) 29 | @click.argument('time_steps', envvar='TIME_STEPS', type=str) 30 | @click.argument('validation_period', envvar='VALIDATION_PERIOD', type=str) 31 | def start_training(dataset_name: str, 32 | data_directory: str, 33 | model_directory: str, 34 | results_directory: str, 35 | model: str, 36 | device: str, 37 | epochs: str, 38 | loss_function: str, 39 | optimizer: str, 40 | batch_size: str, 41 | validation_split: str, 42 | test_split: str, 43 | time_steps: str, 44 | validation_period: str) -> None: 45 | message_passing_nn = create_grid_search(dataset_name, 46 | data_directory, 47 | model_directory, 48 | results_directory, 49 | model, 50 | device, 51 | epochs, 52 | loss_function, 53 | optimizer, 54 | batch_size, 55 | validation_split, 56 
| test_split, 57 | time_steps, 58 | validation_period) 59 | message_passing_nn.start() 60 | 61 | 62 | @click.command('inference', help='Starts the inference') 63 | @click.argument('dataset_name', envvar='DATASET_NAME', type=str) 64 | @click.argument('data_directory', envvar='DATA_DIRECTORY', type=str) 65 | @click.argument('model_directory', envvar='MODEL_DIRECTORY', type=str) 66 | @click.argument('results_directory', envvar='RESULTS_DIRECTORY', type=str) 67 | @click.argument('model', envvar='MODEL', type=str) 68 | @click.argument('device', envvar='DEVICE', type=str) 69 | def start_inference(dataset_name: str, 70 | data_directory: str, 71 | model_directory: str, 72 | results_directory: str, 73 | model: str, 74 | device: str) -> None: 75 | get_logger().info("Starting inference") 76 | message_passing_nn = create_inference(dataset_name, 77 | data_directory, 78 | model_directory, 79 | results_directory, 80 | model, 81 | device) 82 | message_passing_nn.start() 83 | 84 | 85 | def setup_logging(log_level): 86 | get_logger().setLevel(log_level) 87 | 88 | logOutputFormatter = logging.Formatter( 89 | '%(asctime)s %(levelname)s - %(message)s [%(filename)s:%(lineno)s] [%(relativeCreated)d]') 90 | 91 | stdoutStreamHandler = logging.StreamHandler(sys.stdout) 92 | stdoutStreamHandler.setLevel(log_level) 93 | stdoutStreamHandler.setFormatter(logOutputFormatter) 94 | 95 | get_logger().addHandler(stdoutStreamHandler) 96 | 97 | stderrStreamHandler = logging.StreamHandler(sys.stdout) 98 | stderrStreamHandler.setLevel(logging.WARNING) 99 | stderrStreamHandler.setFormatter(logOutputFormatter) 100 | 101 | get_logger().addHandler(stderrStreamHandler) 102 | 103 | 104 | def get_logger() -> logging.Logger: 105 | return logging.getLogger('message_passing_nn') 106 | 107 | 108 | main.add_command(start_training) 109 | main.add_command(start_inference) 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- 
/message_passing_nn/create_message_passing_nn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from message_passing_nn.data.data_preprocessor import DataPreprocessor 4 | from message_passing_nn.model.inferencer import Inferencer 5 | from message_passing_nn.model.loader import Loader 6 | from message_passing_nn.model.trainer import Trainer 7 | from message_passing_nn.usecase import Usecase 8 | from message_passing_nn.usecase.grid_search import GridSearch 9 | from message_passing_nn.usecase.inference import Inference 10 | from message_passing_nn.utils.grid_search_parameters_parser import GridSearchParametersParser 11 | from message_passing_nn.utils.saver import Saver 12 | 13 | 14 | class MessagePassingNN: 15 | def __init__(self, usecase: Usecase) -> None: 16 | self.usecase = usecase 17 | 18 | def start(self): 19 | try: 20 | self.usecase.start() 21 | except Exception: 22 | get_logger().exception("message") 23 | 24 | 25 | def create_grid_search(dataset_name: str, 26 | data_directory: str, 27 | model_directory: str, 28 | results_directory: str, 29 | model: str, 30 | device: str, 31 | epochs: str, 32 | loss_function_selection: str, 33 | optimizer_selection: str, 34 | batch_size: str, 35 | validation_split: str, 36 | test_split: str, 37 | time_steps: str, 38 | validation_period: str) -> MessagePassingNN: 39 | grid_search_dictionary = GridSearchParametersParser().get_grid_search_dictionary(model, 40 | epochs, 41 | loss_function_selection, 42 | optimizer_selection, 43 | batch_size, 44 | validation_split, 45 | test_split, 46 | time_steps, 47 | validation_period) 48 | data_path = _get_data_path(data_directory, dataset_name) 49 | data_preprocessor = DataPreprocessor() 50 | trainer = Trainer(data_preprocessor, device) 51 | saver = Saver(model_directory, results_directory) 52 | grid_search = GridSearch(data_path, 53 | data_preprocessor, 54 | trainer, 55 | grid_search_dictionary, 56 | saver) 57 | return 
MessagePassingNN(grid_search) 58 | 59 | 60 | def create_inference(dataset_name: str, 61 | data_directory: str, 62 | model_directory: str, 63 | results_directory: str, 64 | model: str, 65 | device: str) -> MessagePassingNN: 66 | data_path = data_directory + dataset_name + "/" 67 | data_preprocessor = DataPreprocessor() 68 | model_loader = Loader(model) 69 | model_inferencer = Inferencer(data_preprocessor, device) 70 | saver = Saver(model_directory, results_directory) 71 | inference = Inference(data_path, data_preprocessor, model_loader, model_inferencer, saver) 72 | return MessagePassingNN(inference) 73 | 74 | 75 | def _get_data_path(data_directory: str, dataset_name: str) -> str: 76 | return data_directory + dataset_name + "/" 77 | 78 | 79 | def get_logger() -> logging.Logger: 80 | return logging.getLogger('message_passing_nn') 81 | -------------------------------------------------------------------------------- /message_passing_nn/data/__init__.py: -------------------------------------------------------------------------------- 1 | from message_passing_nn.data.data_preprocessor import DataPreprocessor 2 | from message_passing_nn.data.preprocessor import Preprocessor 3 | -------------------------------------------------------------------------------- /message_passing_nn/data/data_preprocessor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, List 3 | 4 | import torch as to 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.sampler import SubsetRandomSampler 8 | from tqdm import tqdm 9 | 10 | from message_passing_nn.infrastructure.graph_dataset import GraphDataset 11 | from message_passing_nn.data.preprocessor import Preprocessor 12 | 13 | 14 | class DataPreprocessor(Preprocessor): 15 | def __init__(self): 16 | super().__init__() 17 | self.test_mode = False 18 | 19 | def train_validation_test_split(self, 20 | dataset: GraphDataset, 21 | 
batch_size: int, 22 | validation_split: float = 0.2, 23 | test_split: float = 0.1) -> Tuple[DataLoader, DataLoader, DataLoader]: 24 | test_index, validation_index = self._get_validation_and_test_indexes(dataset, 25 | validation_split, 26 | test_split) 27 | train_sampler = SubsetRandomSampler(list(range(validation_index))) 28 | validation_sampler = SubsetRandomSampler(list(range(validation_index, test_index))) 29 | test_sampler = SubsetRandomSampler(list(range(test_index, len(dataset.dataset)))) 30 | 31 | training_data = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler) 32 | if validation_split: 33 | validation_data = DataLoader(dataset, batch_size=batch_size, sampler=validation_sampler) 34 | else: 35 | validation_data = DataLoader(GraphDataset([])) 36 | if test_split: 37 | test_data = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler) 38 | else: 39 | test_data = DataLoader(GraphDataset([])) 40 | self.get_logger().info("Train/validation/test split: " + "/".join([str(len(training_data)), 41 | str(len(validation_data)), 42 | str(len(test_data))]) 43 | + " batches of " + str(batch_size)) 44 | return training_data, validation_data, test_data 45 | 46 | @staticmethod 47 | def get_dataloader(dataset: GraphDataset, batch_size: int = 1) -> DataLoader: 48 | return DataLoader(dataset, batch_size) 49 | 50 | def find_all_node_neighbors(self, dataset: List[Tuple[to.Tensor, to.Tensor, to.Tensor]]) -> List[ 51 | Tuple[to.Tensor, to.Tensor, to.Tensor]]: 52 | dataset_with_neighbors = [] 53 | disable_progress_bar = self.test_mode 54 | for index in tqdm(range(len(dataset)), disable=disable_progress_bar): 55 | features, adjacency_matrix, labels = dataset[index] 56 | number_of_nodes = features.shape[0] 57 | all_neighbors = to.zeros(number_of_nodes, number_of_nodes) - to.ones(number_of_nodes, number_of_nodes) 58 | all_neighbors_list = [to.nonzero(adjacency_matrix[node_id], as_tuple=True)[0].tolist() for node_id in 59 | range(adjacency_matrix.shape[0])] 60 | 
for node_id in range(number_of_nodes): 61 | all_neighbors[node_id, :len(all_neighbors_list[node_id])] = to.tensor(all_neighbors_list[node_id]) 62 | dataset_with_neighbors.append((features, all_neighbors, labels)) 63 | return dataset_with_neighbors 64 | 65 | @staticmethod 66 | def extract_data_dimensions(dataset: GraphDataset) -> Tuple: 67 | node_features_size = dataset[0][0].size() 68 | labels_size = dataset[0][2].size() 69 | return node_features_size, labels_size 70 | 71 | @staticmethod 72 | def flatten(tensors: to.Tensor, desired_size: int = 0) -> to.Tensor: 73 | flattened_tensor = tensors.view(-1) 74 | if 0 < desired_size != len(flattened_tensor): 75 | flattened_tensor = DataPreprocessor._pad_zeros(flattened_tensor, desired_size) 76 | return flattened_tensor 77 | 78 | @staticmethod 79 | def normalize(tensor: to.Tensor, device: str) -> to.Tensor: 80 | if tensor.size()[0] > 1: 81 | normalizer = nn.BatchNorm1d(tensor.size()[1], affine=False).to(device) 82 | return normalizer(tensor) 83 | else: 84 | return tensor 85 | 86 | @staticmethod 87 | def _pad_zeros(flattened_tensor: to.Tensor, desired_size: int) -> to.Tensor: 88 | size_difference = abs(len(flattened_tensor) - desired_size) 89 | flattened_tensor = to.cat((flattened_tensor, to.zeros(size_difference))) 90 | return flattened_tensor 91 | 92 | @staticmethod 93 | def _get_validation_and_test_indexes(dataset: GraphDataset, 94 | validation_split: float, 95 | test_split: float) -> Tuple[int, int]: 96 | validation_index = int((1 - validation_split - test_split) * len(dataset)) 97 | test_index = int((1 - test_split) * len(dataset)) 98 | return test_index, validation_index 99 | 100 | def enable_test_mode(self) -> None: 101 | self.test_mode = True 102 | 103 | @staticmethod 104 | def get_logger() -> logging.Logger: 105 | return logging.getLogger('message_passing_nn') 106 | -------------------------------------------------------------------------------- /message_passing_nn/data/preprocessor.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Any, Tuple, List 3 | 4 | import torch as to 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class Preprocessor(metaclass=ABCMeta): 9 | def __init__(self): 10 | pass 11 | 12 | @staticmethod 13 | @abstractmethod 14 | def train_validation_test_split(dataset: Any, 15 | batch_size: int, 16 | validation_split: float, 17 | test_split: float) -> Tuple[DataLoader, DataLoader, DataLoader]: 18 | pass 19 | 20 | @staticmethod 21 | @abstractmethod 22 | def extract_data_dimensions(dataset: Any) -> Tuple: 23 | pass 24 | 25 | @staticmethod 26 | @abstractmethod 27 | def flatten(tensors: to.Tensor, desired_size: Any = 0) -> to.Tensor: 28 | pass 29 | 30 | @staticmethod 31 | @abstractmethod 32 | def normalize(tensors: to.Tensor, device: str) -> to.Tensor: 33 | pass 34 | -------------------------------------------------------------------------------- /message_passing_nn/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/message_passing_nn/fixtures/__init__.py -------------------------------------------------------------------------------- /message_passing_nn/fixtures/characters.py: -------------------------------------------------------------------------------- 1 | GRID_SEARCH_SEPARATION_CHARACTER = '&' 2 | -------------------------------------------------------------------------------- /message_passing_nn/fixtures/filenames.py: -------------------------------------------------------------------------------- 1 | RESULTS_CSV = 'results.csv' 2 | DISTANCE_MAPS = 'distance_maps.pickle' 3 | MODEL_STATE_DICTIONARY = 'model_state_dictionary.pth' 4 | EPOCH = "Epoch" 5 | -------------------------------------------------------------------------------- 
/message_passing_nn/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from message_passing_nn.graph.rnn_encoder import RNNEncoder 2 | -------------------------------------------------------------------------------- /message_passing_nn/graph/rnn_encoder.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../utils/messages.h" 3 | #include "../utils/derivatives.h" 4 | #include 5 | 6 | std::vector forward_cpp( 7 | const at::Tensor& time_steps, 8 | const at::Tensor& number_of_nodes, 9 | const at::Tensor& number_of_node_features, 10 | const at::Tensor& fully_connected_layer_output_size, 11 | const at::Tensor& batch_size, 12 | const at::Tensor& node_features, 13 | const at::Tensor& all_neighbors, 14 | const at::Tensor& w_graph_node_features, 15 | const at::Tensor& w_graph_neighbor_messages, 16 | const at::Tensor& u_graph_node_features, 17 | const at::Tensor& u_graph_neighbor_messages, 18 | const at::Tensor& linear_weight, 19 | const at::Tensor& linear_bias) { 20 | 21 | auto time_steps_int = time_steps.item(); 22 | auto number_of_nodes_int = number_of_nodes.item(); 23 | auto number_of_node_features_int = number_of_node_features.item(); 24 | auto fully_connected_layer_output_size_int = fully_connected_layer_output_size.item(); 25 | auto batch_size_int = batch_size.item(); 26 | auto outputs = at::zeros({batch_size_int, fully_connected_layer_output_size_int}); 27 | auto linear_outputs = at::zeros({batch_size_int, fully_connected_layer_output_size_int}); 28 | auto messages = at::zeros({batch_size_int, number_of_nodes_int, number_of_nodes_int, number_of_node_features_int}); 29 | auto messages_previous_step = at::zeros({batch_size_int, number_of_nodes_int, number_of_nodes_int, number_of_node_features_int}); 30 | auto node_encoding_messages = at::zeros({batch_size_int, number_of_nodes_int, number_of_node_features_int}); 31 | auto encodings = 
at::zeros({batch_size_int, number_of_nodes_int*number_of_node_features_int}); 32 | auto base_messages = at::matmul(w_graph_node_features, node_features); 33 | 34 | for (int batch = 0; batch backward_cpp( 60 | const at::Tensor& grad_output, 61 | const at::Tensor& outputs, 62 | const at::Tensor& linear_outputs, 63 | const at::Tensor& encodings, 64 | const at::Tensor& messages_summed, 65 | const at::Tensor& messages_previous_step_summed, 66 | const at::Tensor& messages, 67 | const at::Tensor& node_features, 68 | const at::Tensor& batch_size, 69 | const at::Tensor& number_of_nodes, 70 | const at::Tensor& number_of_node_features, 71 | const at::Tensor& u_graph_neighbor_messages_summed, 72 | const at::Tensor& linear_weight, 73 | const at::Tensor& linear_bias) { 74 | 75 | 76 | auto delta_1 = grad_output*d_sigmoid(linear_outputs); 77 | auto d_linear_bias = delta_1; 78 | auto d_linear_weight = at::matmul(delta_1.transpose(0, 1), encodings); 79 | 80 | auto delta_2 = at::matmul(delta_1, linear_weight).reshape({batch_size.item(), number_of_nodes.item(), number_of_node_features.item()})*(d_relu_2d(encodings).reshape({batch_size.item(), number_of_nodes.item(), number_of_node_features.item()})); 81 | auto d_u_graph_node_features = at::matmul(delta_2, node_features.transpose(1, 2)); 82 | auto d_u_graph_neighbor_messages = at::matmul(delta_2.transpose(1, 2), messages_summed); 83 | 84 | auto delta_3 = at::matmul(delta_2.transpose(1, 2), at::matmul(u_graph_neighbor_messages_summed, d_relu_4d(messages).transpose(2, 3))); 85 | auto d_w_graph_node_features = at::matmul(delta_3.transpose(1, 2), node_features.transpose(1, 2)); 86 | auto d_w_graph_neighbor_messages = at::matmul(delta_3.transpose(1, 2), messages_previous_step_summed.transpose(1, 2)); 87 | 88 | return {d_w_graph_node_features, 89 | d_w_graph_neighbor_messages, 90 | d_u_graph_node_features, 91 | d_u_graph_neighbor_messages, 92 | d_linear_weight, 93 | d_linear_bias}; 94 | } 95 | 96 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 
97 | m.def("forward", &forward_cpp, "RNN encoder forward pass (CPU)"); 98 | m.def("backward", &backward_cpp, "RNN encoder backward pass (CPU)"); 99 | m.def("compose_messages", &compose_messages, "RNN compose messages (CPU)"); 100 | m.def("encode_messages", &encode_messages, "RNN encode messages (CPU)"); 101 | } -------------------------------------------------------------------------------- /message_passing_nn/graph/rnn_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import rnn_encoder_cpp as rnn_encoder_cpp 4 | try: 5 | import rnn_encoder_cuda_cpp as rnn_encoder_cuda_cpp 6 | except: 7 | pass 8 | import math 9 | import torch as to 10 | import torch.nn as nn 11 | from torch.nn import init 12 | 13 | 14 | class RNNEncoderFunction(to.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, 17 | time_steps: int, 18 | number_of_nodes: int, 19 | number_of_node_features: int, 20 | fully_connected_layer_output_size: int, 21 | batch_size: int, 22 | device: str, 23 | node_features: to.Tensor, 24 | all_neighbors: to.Tensor, 25 | w_graph_node_features: to.Tensor, 26 | w_graph_neighbor_messages: to.Tensor, 27 | u_graph_node_features: to.Tensor, 28 | u_graph_neighbor_messages: to.Tensor, 29 | linear_weight: to.Tensor, 30 | linear_bias: to.Tensor) -> to.Tensor: 31 | if device == "cuda": 32 | cpp_extension = rnn_encoder_cuda_cpp 33 | else: 34 | cpp_extension = rnn_encoder_cpp 35 | outputs, linear_outputs, encodings, messages, messages_previous_step = cpp_extension.forward( 36 | to.tensor(time_steps, device=device), 37 | to.tensor(number_of_nodes, device=device), 38 | to.tensor(number_of_node_features, device=device), 39 | to.tensor(fully_connected_layer_output_size, device=device), 40 | to.tensor(batch_size, device=device), 41 | node_features, 42 | all_neighbors, 43 | w_graph_node_features, 44 | w_graph_neighbor_messages, 45 | u_graph_node_features, 46 | u_graph_neighbor_messages, 47 | 
linear_weight, 48 | linear_bias) 49 | variables = [outputs, 50 | linear_outputs, 51 | encodings.view(batch_size, number_of_nodes * number_of_node_features), 52 | to.sum(to.relu(messages), dim=2), 53 | to.sum(to.relu(messages_previous_step), dim=2), 54 | messages, 55 | node_features, 56 | to.tensor([batch_size]), 57 | to.tensor([number_of_nodes]), 58 | to.tensor([number_of_node_features]), 59 | to.sum(u_graph_neighbor_messages, dim=0), 60 | linear_weight, 61 | linear_bias] 62 | ctx.save_for_backward(*variables) 63 | return outputs 64 | 65 | @staticmethod 66 | def backward(ctx, grad_outputs: to.Tensor) -> Tuple[None, 67 | None, 68 | None, 69 | None, 70 | None, 71 | None, 72 | None, 73 | None, 74 | to.Tensor, 75 | to.Tensor, 76 | to.Tensor, 77 | to.Tensor, 78 | to.Tensor, 79 | to.Tensor]: 80 | if grad_outputs.device == "cuda": 81 | cpp_extension = rnn_encoder_cuda_cpp 82 | else: 83 | cpp_extension = rnn_encoder_cpp 84 | backward_outputs = cpp_extension.backward(grad_outputs.contiguous(), *ctx.saved_tensors) 85 | d_w_graph_node_features, d_w_graph_neighbor_messages, d_u_graph_neighbor_messages, d_u_graph_node_features, d_linear_weight, d_linear_bias = backward_outputs 86 | return None, \ 87 | None, \ 88 | None, \ 89 | None, \ 90 | None, \ 91 | None, \ 92 | None, \ 93 | None, \ 94 | d_w_graph_node_features, \ 95 | d_w_graph_neighbor_messages, \ 96 | d_u_graph_neighbor_messages, \ 97 | d_u_graph_node_features, \ 98 | d_linear_weight, \ 99 | d_linear_bias 100 | 101 | 102 | class RNNEncoder(nn.Module): 103 | def __init__(self, 104 | time_steps: int, 105 | number_of_nodes: int, 106 | number_of_node_features: int, 107 | fully_connected_layer_input_size: int, 108 | fully_connected_layer_output_size: int, 109 | device: str = "cpu") -> None: 110 | super(RNNEncoder, self).__init__() 111 | 112 | self.time_steps = time_steps 113 | self.number_of_nodes = number_of_nodes 114 | self.number_of_node_features = number_of_node_features 115 | self.fully_connected_layer_input_size = 
fully_connected_layer_input_size 116 | self.fully_connected_layer_output_size = fully_connected_layer_output_size 117 | self.device = device 118 | 119 | self.w_graph_node_features = nn.Parameter( 120 | to.empty([number_of_nodes, number_of_nodes], 121 | device=self.device), 122 | requires_grad=True) 123 | self.w_graph_neighbor_messages = nn.Parameter( 124 | to.empty([number_of_nodes, number_of_nodes], 125 | device=self.device), 126 | requires_grad=True) 127 | self.u_graph_node_features = nn.Parameter( 128 | to.empty([number_of_nodes, number_of_nodes], 129 | device=self.device), 130 | requires_grad=True) 131 | self.u_graph_neighbor_messages = nn.Parameter( 132 | to.empty([number_of_node_features, number_of_node_features], 133 | device=self.device), 134 | requires_grad=True) 135 | self.linear_weight = nn.Parameter( 136 | to.empty([self.fully_connected_layer_output_size, self.fully_connected_layer_input_size], 137 | device=self.device), 138 | requires_grad=True) 139 | self.linear_bias = nn.Parameter( 140 | to.empty(self.fully_connected_layer_output_size, 141 | device=self.device), 142 | requires_grad=True) 143 | self.reset_parameters() 144 | 145 | def reset_parameters(self): 146 | nn.init.kaiming_normal_(self.w_graph_node_features) 147 | nn.init.kaiming_normal_(self.w_graph_neighbor_messages) 148 | nn.init.kaiming_normal_(self.u_graph_node_features) 149 | nn.init.kaiming_normal_(self.u_graph_neighbor_messages) 150 | nn.init.kaiming_uniform_(self.linear_weight, a=math.sqrt(5)) 151 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.linear_weight) 152 | nn.init.uniform_(self.linear_bias, -1 / math.sqrt(fan_in), 1 / math.sqrt(fan_in)) 153 | 154 | def forward(self, 155 | node_features: to.Tensor, 156 | all_neighbors: to.Tensor, 157 | batch_size: int) -> to.Tensor: 158 | return RNNEncoderFunction.apply(self.time_steps, 159 | self.number_of_nodes, 160 | self.number_of_node_features, 161 | self.fully_connected_layer_output_size, 162 | batch_size, 163 | self.device, 164 | 
node_features, 165 | all_neighbors, 166 | self.w_graph_node_features, 167 | self.w_graph_neighbor_messages, 168 | self.u_graph_node_features, 169 | self.u_graph_neighbor_messages, 170 | self.linear_weight, 171 | self.linear_bias) 172 | 173 | def get_model_size(self) -> str: 174 | return str(int((self.w_graph_node_features.element_size() * self.w_graph_node_features.nelement() + 175 | self.w_graph_neighbor_messages.element_size() * self.w_graph_neighbor_messages.nelement() + 176 | self.u_graph_node_features.element_size() * self.u_graph_node_features.nelement() + 177 | self.u_graph_neighbor_messages.element_size() * self.u_graph_neighbor_messages.nelement() + 178 | self.linear_weight.element_size() * self.linear_weight.nelement() + 179 | self.linear_bias.element_size() * self.linear_bias.nelement()) * 0.000001)) 180 | -------------------------------------------------------------------------------- /message_passing_nn/graph/rnn_encoder_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | std::vector forward_cuda_cpp( 6 | const at::Tensor& time_steps, 7 | const at::Tensor& number_of_nodes, 8 | const at::Tensor& number_of_node_features, 9 | const at::Tensor& fully_connected_layer_output_size, 10 | const at::Tensor& batch_size, 11 | const at::Tensor& node_features, 12 | const at::Tensor& all_neighbors, 13 | const at::Tensor& w_graph_node_features, 14 | const at::Tensor& w_graph_neighbor_messages, 15 | const at::Tensor& u_graph_node_features, 16 | const at::Tensor& u_graph_neighbor_messages, 17 | const at::Tensor& linear_weight, 18 | const at::Tensor& linear_bias); 19 | 20 | std::vector backward_cuda_cpp( 21 | const at::Tensor& grad_output, 22 | const at::Tensor& outputs, 23 | const at::Tensor& linear_outputs, 24 | const at::Tensor& encodings, 25 | const at::Tensor& messages_summed, 26 | const at::Tensor& messages_previous_step_summed, 27 | const at::Tensor& messages, 28 | const at::Tensor& 
node_features, 29 | const at::Tensor& batch_size, 30 | const at::Tensor& number_of_nodes, 31 | const at::Tensor& number_of_node_features, 32 | const at::Tensor& u_graph_neighbor_messages_summed, 33 | const at::Tensor& linear_weight, 34 | const at::Tensor& linear_bias); 35 | 36 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 37 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 38 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 39 | 40 | std::vector forward_cpp( 41 | const at::Tensor& time_steps, 42 | const at::Tensor& number_of_nodes, 43 | const at::Tensor& number_of_node_features, 44 | const at::Tensor& fully_connected_layer_output_size, 45 | const at::Tensor& batch_size, 46 | const at::Tensor& node_features, 47 | const at::Tensor& all_neighbors, 48 | const at::Tensor& w_graph_node_features, 49 | const at::Tensor& w_graph_neighbor_messages, 50 | const at::Tensor& u_graph_node_features, 51 | const at::Tensor& u_graph_neighbor_messages, 52 | const at::Tensor& linear_weight, 53 | const at::Tensor& linear_bias) { 54 | CHECK_INPUT(time_steps); 55 | CHECK_INPUT(number_of_nodes); 56 | CHECK_INPUT(number_of_node_features); 57 | CHECK_INPUT(fully_connected_layer_output_size); 58 | CHECK_INPUT(batch_size); 59 | CHECK_INPUT(node_features); 60 | CHECK_INPUT(all_neighbors); 61 | CHECK_INPUT(w_graph_node_features); 62 | CHECK_INPUT(w_graph_neighbor_messages); 63 | CHECK_INPUT(u_graph_node_features); 64 | CHECK_INPUT(u_graph_neighbor_messages); 65 | CHECK_INPUT(linear_weight); 66 | CHECK_INPUT(linear_bias); 67 | 68 | return forward_cuda_cpp(time_steps, 69 | number_of_nodes, 70 | number_of_node_features, 71 | fully_connected_layer_output_size, 72 | batch_size, 73 | node_features, 74 | all_neighbors, 75 | w_graph_node_features, 76 | w_graph_neighbor_messages, 77 | u_graph_node_features, 78 | u_graph_neighbor_messages, 79 | linear_weight, 80 | linear_bias); 81 | } 82 | 83 | std::vector 
backward_cpp( 84 | const at::Tensor& grad_output, 85 | const at::Tensor& outputs, 86 | const at::Tensor& linear_outputs, 87 | const at::Tensor& encodings, 88 | const at::Tensor& messages_summed, 89 | const at::Tensor& messages_previous_step_summed, 90 | const at::Tensor& messages, 91 | const at::Tensor& node_features, 92 | const at::Tensor& batch_size, 93 | const at::Tensor& number_of_nodes, 94 | const at::Tensor& number_of_node_features, 95 | const at::Tensor& u_graph_neighbor_messages_summed, 96 | const at::Tensor& linear_weight, 97 | const at::Tensor& linear_bias) { 98 | CHECK_INPUT(grad_output); 99 | CHECK_INPUT(outputs); 100 | CHECK_INPUT(linear_outputs); 101 | CHECK_INPUT(encodings); 102 | CHECK_INPUT(messages_summed); 103 | CHECK_INPUT(messages_previous_step_summed); 104 | CHECK_INPUT(messages); 105 | CHECK_INPUT(node_features); 106 | CHECK_INPUT(batch_size); 107 | CHECK_INPUT(number_of_nodes); 108 | CHECK_INPUT(number_of_node_features); 109 | CHECK_INPUT(u_graph_neighbor_messages_summed); 110 | CHECK_INPUT(linear_weight); 111 | CHECK_INPUT(linear_bias); 112 | 113 | return backward_cuda_cpp( 114 | grad_output, 115 | outputs, 116 | linear_outputs, 117 | encodings, 118 | messages_summed, 119 | messages_previous_step_summed, 120 | messages, 121 | node_features, 122 | batch_size, 123 | number_of_nodes, 124 | number_of_node_features, 125 | u_graph_neighbor_messages_summed, 126 | linear_weight, 127 | linear_bias); 128 | } 129 | 130 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 131 | m.def("forward", &forward_cpp, "RNN encoder forward pass (CUDA)"); 132 | m.def("backward", &backward_cpp, "RNN encoder backward pass (CUDA)"); 133 | } -------------------------------------------------------------------------------- /message_passing_nn/graph/rnn_encoder_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../utils/derivatives.h" 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | 10 | template 11 | 
__global__ void compose_messages_kernel( 12 | torch::PackedTensorAccessor32 base_neighbor_messages, 13 | torch::PackedTensorAccessor32 w_graph_neighbor_messages, 14 | torch::PackedTensorAccessor32 base_messages, 15 | torch::PackedTensorAccessor32 all_neighbors, 16 | torch::PackedTensorAccessor32 new_messages) { 17 | 18 | const int index = threadIdx.x; 19 | const int stride = blockDim.x; 20 | 21 | for (int node_id = index; node_id < all_neighbors.size(0); node_id += stride) { 22 | for (int end_node_index = 0; end_node_index < all_neighbors.size(1); end_node_index++){ 23 | auto end_node_id = std::round(all_neighbors[node_id][end_node_index]); 24 | if (end_node_id >= 0) { 25 | for (int index_feature = 0; index_feature < new_messages.size(2); index_feature++) { 26 | new_messages[node_id][end_node_id][index_feature] += base_messages[node_id][index_feature]; 27 | } 28 | for (int neighbor_index = 0; neighbor_index < all_neighbors.size(1); neighbor_index++) { 29 | auto neighbor = std::round(all_neighbors[node_id][neighbor_index]); 30 | if (neighbor >= 0 && neighbor_index!=end_node_index) { 31 | for (int index_feature = 0; index_feature < new_messages.size(2); index_feature++) { 32 | new_messages[node_id][end_node_id][index_feature] += base_neighbor_messages[neighbor][node_id][index_feature]; 33 | } 34 | } 35 | } 36 | } 37 | } 38 | } 39 | } 40 | 41 | at::Tensor encode_messages( 42 | const int& number_of_nodes, 43 | const at::Tensor& node_encoding_messages, 44 | const at::Tensor& u_graph_node_features, 45 | const at::Tensor& u_graph_neighbor_messages, 46 | const at::Tensor& node_features, 47 | const at::Tensor& all_neighbors, 48 | const at::Tensor& messages) { 49 | 50 | for (int node_id = 0; node_id(); 53 | if (end_node_id >= 0) { 54 | node_encoding_messages[node_id] += at::matmul(u_graph_neighbor_messages, at::relu(messages[end_node_id][node_id])); 55 | } 56 | } 57 | } 58 | return at::relu(at::add(at::matmul(u_graph_node_features, node_features), node_encoding_messages)); 
59 | } 60 | 61 | std::vector forward_cuda_cpp( 62 | const at::Tensor& time_steps, 63 | const at::Tensor& number_of_nodes, 64 | const at::Tensor& number_of_node_features, 65 | const at::Tensor& fully_connected_layer_output_size, 66 | const at::Tensor& batch_size, 67 | const at::Tensor& node_features, 68 | const at::Tensor& all_neighbors, 69 | const at::Tensor& w_graph_node_features, 70 | const at::Tensor& w_graph_neighbor_messages, 71 | const at::Tensor& u_graph_node_features, 72 | const at::Tensor& u_graph_neighbor_messages, 73 | const at::Tensor& linear_weight, 74 | const at::Tensor& linear_bias) { 75 | 76 | auto outputs = at::zeros({batch_size.item(), fully_connected_layer_output_size.item()}, at::kCUDA); 77 | auto linear_outputs = at::zeros({batch_size.item(), fully_connected_layer_output_size.item()}, at::kCUDA); 78 | auto messages = at::zeros({batch_size.item(), number_of_nodes.item(), number_of_nodes.item(), number_of_node_features.item()}, at::kCUDA); 79 | auto messages_previous_step = at::zeros({batch_size.item(), number_of_nodes.item(), number_of_nodes.item(), number_of_node_features.item()}, at::kCUDA); 80 | auto node_encoding_messages = at::zeros({batch_size.item(), number_of_nodes.item(), number_of_node_features.item()}, at::kCUDA); 81 | auto encodings = at::zeros({batch_size.item(), number_of_nodes.item()*number_of_node_features.item()}, at::kCUDA); 82 | 83 | const int threads = 1024; 84 | const dim3 blocks(std::floor(number_of_nodes.item()/threads) + 1); 85 | 86 | auto base_messages = at::matmul(w_graph_node_features, node_features); 87 | 88 | for (int batch = 0; batch(); batch++) { 89 | auto new_messages = at::zeros_like({messages[batch]}, at::kCUDA); 90 | auto previous_messages = at::zeros_like({messages[batch]}, at::kCUDA); 91 | const auto number_of_nodes = all_neighbors[batch].size(0); 92 | const auto max_neighbors = all_neighbors[batch].size(1); 93 | 94 | for (int time_step = 0; time_step(); time_step++) { 95 | auto base_neighbor_messages = 
at::matmul(w_graph_neighbor_messages, at::relu(new_messages)); 96 | std::swap(previous_messages, new_messages); 97 | auto base_messages_of_batch = base_messages[batch]; 98 | auto neighbors_of_batch = all_neighbors[batch]; 99 | AT_DISPATCH_FLOATING_TYPES(new_messages.type(), "forward_cpp_cuda", ([&] { 100 | compose_messages_kernel<<>>(base_neighbor_messages.packed_accessor32(), 101 | w_graph_neighbor_messages.packed_accessor32(), 102 | base_messages_of_batch.packed_accessor32(), 103 | neighbors_of_batch.packed_accessor32(), 104 | new_messages.packed_accessor32()); 105 | })); 106 | } 107 | 108 | messages[batch] = new_messages; 109 | messages_previous_step[batch] = previous_messages; 110 | encodings[batch] = encode_messages(number_of_nodes, 111 | node_encoding_messages[batch], 112 | u_graph_node_features, 113 | u_graph_neighbor_messages, 114 | node_features[batch], 115 | all_neighbors[batch], 116 | messages[batch]).view({-1}); 117 | linear_outputs[batch] = at::add(at::matmul(linear_weight, encodings[batch]), linear_bias); 118 | outputs[batch] = at::sigmoid(linear_outputs[batch]); 119 | } 120 | return {outputs, linear_outputs, encodings, messages, messages_previous_step}; 121 | } 122 | 123 | std::vector backward_cuda_cpp( 124 | const at::Tensor& grad_output, 125 | const at::Tensor& outputs, 126 | const at::Tensor& linear_outputs, 127 | const at::Tensor& encodings, 128 | const at::Tensor& messages_summed, 129 | const at::Tensor& messages_previous_step_summed, 130 | const at::Tensor& messages, 131 | const at::Tensor& node_features, 132 | const at::Tensor& batch_size, 133 | const at::Tensor& number_of_nodes, 134 | const at::Tensor& number_of_node_features, 135 | const at::Tensor& u_graph_neighbor_messages_summed, 136 | const at::Tensor& linear_weight, 137 | const at::Tensor& linear_bias) { 138 | 139 | auto delta_1 = grad_output*d_sigmoid(linear_outputs); 140 | auto d_linear_bias = delta_1; 141 | auto d_linear_weight = at::matmul(delta_1.transpose(0, 1), encodings); 142 | 
143 | auto delta_2 = at::matmul(delta_1, linear_weight).reshape({batch_size.item(), number_of_nodes.item(), number_of_node_features.item()})*(d_relu_2d(encodings).reshape({batch_size.item(), number_of_nodes.item(), number_of_node_features.item()})); 144 | auto d_u_graph_node_features = at::matmul(delta_2, node_features.transpose(1, 2)); 145 | auto d_u_graph_neighbor_messages = at::matmul(delta_2.transpose(1, 2), messages_summed); 146 | 147 | auto delta_3 = at::matmul(delta_2.transpose(1, 2), at::matmul(u_graph_neighbor_messages_summed, d_relu_4d(messages).transpose(2, 3))); 148 | auto d_w_graph_node_features = at::matmul(delta_3.transpose(1, 2), node_features.transpose(1, 2)); 149 | auto d_w_graph_neighbor_messages = at::matmul(delta_3.transpose(1, 2), messages_previous_step_summed.transpose(1, 2)); 150 | 151 | 152 | return {d_w_graph_node_features, 153 | d_w_graph_neighbor_messages, 154 | d_u_graph_node_features, 155 | d_u_graph_neighbor_messages, 156 | d_linear_weight, 157 | d_linear_bias}; 158 | } -------------------------------------------------------------------------------- /message_passing_nn/infrastructure/__init__.py: -------------------------------------------------------------------------------- 1 | from message_passing_nn.infrastructure.file_system_repository import FileSystemRepository 2 | -------------------------------------------------------------------------------- /message_passing_nn/infrastructure/file_system_repository.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | 4 | import torch as to 5 | 6 | 7 | class FileSystemRepository: 8 | def __init__(self, data_directory: str, dataset: str) -> None: 9 | super().__init__() 10 | self.data_directory = data_directory + dataset + '/' 11 | self.test_mode = False 12 | 13 | def save(self, filename: str, data_to_save: to.Tensor) -> None: 14 | with open(self.data_directory + filename, 'wb') as file: 15 | pickle.dump(data_to_save, 
file) 16 | 17 | @staticmethod 18 | def get_logger() -> logging.Logger: 19 | return logging.getLogger('message_passing_nn') 20 | -------------------------------------------------------------------------------- /message_passing_nn/infrastructure/graph_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | from typing import List, Tuple 5 | 6 | import torch as to 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | 10 | 11 | class GraphDataset(Dataset): 12 | def __init__(self, 13 | data_directory: str, 14 | test_mode: bool = False) -> None: 15 | self.data_directory = data_directory 16 | self.test_mode = test_mode 17 | self.dataset = self._load_data() if self.data_directory else [] 18 | 19 | def __len__(self) -> int: 20 | return len(self.dataset) 21 | 22 | def __getitem__(self, index: int) -> Tuple[to.Tensor, to.Tensor, to.Tensor, str]: 23 | return self.dataset[index][0], self.dataset[index][1], self.dataset[index][2], self.dataset[index][3] 24 | 25 | def _load_data(self) -> List[Tuple[to.Tensor, to.Tensor, to.Tensor, str]]: 26 | self.get_logger().info("Loading dataset") 27 | files_in_path = self._extract_name_prefixes_from_filenames() 28 | dataset = [] 29 | size = 0 30 | disable_progress_bar = self.test_mode 31 | for filename_index in tqdm(range(len(files_in_path)), disable=disable_progress_bar): 32 | filename = files_in_path[filename_index] 33 | try: 34 | dataset.append( 35 | (self._get_features(filename), self._get_all_neighbors(filename), self._get_labels(filename), 36 | filename)) 37 | except: 38 | self.get_logger().info("Skipped " + filename) 39 | size += self._get_size(dataset[-1]) 40 | self.get_logger().info( 41 | "Loaded " + str(len(dataset)) + " files. 
Size: " + str(int(size * 0.000001)) + " MB") 42 | return dataset 43 | 44 | @staticmethod 45 | def _to_list(dataset: List[Tuple[to.Tensor, to.Tensor, to.Tensor]]) -> List[Tuple[to.Tensor, to.Tensor]]: 46 | return [(dataset[index][0], dataset[index][1]) for index in range(len(dataset))] 47 | 48 | @staticmethod 49 | def _extract_labels(dataset: List[Tuple[to.Tensor, to.Tensor, to.Tensor]]) -> List[to.Tensor]: 50 | return [dataset[index][2] for index in range(len(dataset))] 51 | 52 | def _get_labels(self, filename: str) -> to.Tensor: 53 | with open(self.data_directory + filename + 'labels.pickle', 'rb') as labels_file: 54 | labels = pickle.load(labels_file).float() 55 | return labels 56 | 57 | def _get_features(self, filename: str) -> to.Tensor: 58 | with open(self.data_directory + filename + 'features.pickle', 'rb') as features_file: 59 | features = pickle.load(features_file).float() 60 | return features 61 | 62 | def _get_all_neighbors(self, filename: str) -> to.Tensor: 63 | with open(self.data_directory + filename + 'adjacency-matrix.pickle', 'rb') as adjacency_matrix_file: 64 | adjacency_matrix = pickle.load(adjacency_matrix_file).float() 65 | number_of_nodes = adjacency_matrix.shape[0] 66 | all_neighbors_list = [] 67 | max_number_of_neighbors = -1 68 | for node_id in range(adjacency_matrix.shape[0]): 69 | neighbors = to.nonzero(adjacency_matrix[node_id], as_tuple=True)[0].tolist() 70 | if len(neighbors) > max_number_of_neighbors: 71 | max_number_of_neighbors = len(neighbors) 72 | all_neighbors_list.append(neighbors) 73 | all_neighbors = self._get_minus_ones_tensor(number_of_nodes, max_number_of_neighbors) 74 | for node_id in range(number_of_nodes): 75 | all_neighbors[node_id, :len(all_neighbors_list[node_id])] = to.tensor(all_neighbors_list[node_id]) 76 | return all_neighbors 77 | 78 | @staticmethod 79 | def _get_minus_ones_tensor(number_of_nodes: int, max_number_of_neighbors: int) -> to.Tensor: 80 | return to.zeros(number_of_nodes, max_number_of_neighbors) - 
to.ones(number_of_nodes, max_number_of_neighbors) 81 | 82 | @staticmethod 83 | def _get_size(data: Tuple[to.Tensor, to.Tensor, to.Tensor, str]) -> int: 84 | return int(data[0].element_size() * data[0].nelement() + 85 | data[1].element_size() * data[1].nelement() + 86 | data[2].element_size() * data[2].nelement()) 87 | 88 | def _extract_name_prefixes_from_filenames(self) -> List[str]: 89 | return list(set([self._reconstruct_filename(file) for file in self._get_data_filenames()])) 90 | 91 | def _get_data_filenames(self) -> List[str]: 92 | return sorted([file for file in os.listdir(self.data_directory) if file.endswith(".pickle")]) 93 | 94 | @staticmethod 95 | def _reconstruct_filename(file: str) -> str: 96 | return "_".join(file.split("_")[:-1]) + "_" 97 | 98 | def enable_test_mode(self) -> None: 99 | self.test_mode = True 100 | 101 | @staticmethod 102 | def get_logger() -> logging.Logger: 103 | return logging.getLogger('message_passing_nn') 104 | -------------------------------------------------------------------------------- /message_passing_nn/model/__init__.py: -------------------------------------------------------------------------------- 1 | from message_passing_nn.model.trainer import Trainer 2 | from message_passing_nn.model.loader import Loader 3 | from message_passing_nn.model.inferencer import Inferencer 4 | -------------------------------------------------------------------------------- /message_passing_nn/model/inferencer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch as to 4 | from torch import nn 5 | from torch.utils.data import DataLoader 6 | 7 | from message_passing_nn.data import DataPreprocessor 8 | 9 | 10 | class Inferencer: 11 | def __init__(self, data_preprocessor: DataPreprocessor, device: str, normalize: bool = False) -> None: 12 | self.preprocessor = data_preprocessor 13 | self.device = device 14 | self.normalize = normalize 15 | 16 | def 
do_inference(self, model: nn.Module, inference_data: DataLoader) -> List[Tuple[to.Tensor, to.Tensor, str]]: 17 | outputs_labels_pairs = [] 18 | with to.no_grad(): 19 | for node_features, all_neighbors, labels, tag in inference_data: 20 | node_features, all_neighbors, labels = (node_features.to(self.device), 21 | all_neighbors.to(self.device), 22 | labels.to(self.device)) 23 | if self.normalize: 24 | node_features = self.preprocessor.normalize(node_features, self.device) 25 | labels = self.preprocessor.normalize(labels, self.device) 26 | outputs = model.forward(node_features, all_neighbors, batch_size=1) 27 | outputs_labels_pairs.append((outputs, labels, tag)) 28 | return outputs_labels_pairs 29 | -------------------------------------------------------------------------------- /message_passing_nn/model/loader.py: -------------------------------------------------------------------------------- 1 | import torch as to 2 | 3 | from torch import nn 4 | from typing import Dict, Tuple 5 | 6 | from message_passing_nn.utils import ModelSelector 7 | 8 | 9 | class Loader: 10 | def __init__(self, model: str) -> None: 11 | self.model = ModelSelector.load_model(model) 12 | 13 | def load_model(self, data_dimensions: Tuple, path_to_model: str) -> nn.Module: 14 | model_parameters = self._get_model_parameters_from_path(path_to_model) 15 | node_features_size, labels_size = data_dimensions 16 | number_of_nodes = node_features_size[0] 17 | number_of_node_features = node_features_size[1] 18 | fully_connected_layer_output_size = labels_size[0] 19 | self.model = self.model(time_steps=int(model_parameters['time_steps']), 20 | number_of_nodes=number_of_nodes, 21 | number_of_node_features=number_of_node_features, 22 | fully_connected_layer_input_size=number_of_nodes * number_of_node_features, 23 | fully_connected_layer_output_size=fully_connected_layer_output_size) 24 | self.model.load_state_dict(to.load(path_to_model)) 25 | self.model.eval() 26 | return self.model 27 | 28 | @staticmethod 29 
| def _get_model_parameters_from_path(path_to_model: str) -> Dict: 30 | model_configuration = path_to_model.split("/")[-2].split("__") 31 | model_parameters = {} 32 | for model_parameter in model_configuration: 33 | key, value = model_parameter.split("&")[0], model_parameter.split("&")[1] 34 | model_parameters.update({key: value}) 35 | return model_parameters 36 | -------------------------------------------------------------------------------- /message_passing_nn/model/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Any, Tuple 3 | 4 | import numpy as np 5 | import torch as to 6 | from torch.nn.modules.module import Module 7 | from torch.optim.optimizer import Optimizer 8 | from torch.utils.data.dataloader import DataLoader 9 | 10 | from message_passing_nn.data.preprocessor import Preprocessor 11 | from message_passing_nn.utils.loss_function_selector import LossFunctionSelector 12 | from message_passing_nn.utils.model_selector import ModelSelector 13 | from message_passing_nn.utils.optimizer_selector import OptimizerSelector 14 | 15 | 16 | class Trainer: 17 | def __init__(self, preprocessor: Preprocessor, device: str, normalize: bool = False) -> None: 18 | self.preprocessor = preprocessor 19 | self.device = device 20 | self.normalize = normalize 21 | self.model = None 22 | self.loss_function = None 23 | self.optimizer = None 24 | 25 | def instantiate_attributes(self, 26 | data_dimensions: Tuple, 27 | configuration_dictionary: Dict) -> None: 28 | node_features_size, labels_size = data_dimensions 29 | number_of_nodes = node_features_size[0] 30 | number_of_node_features = node_features_size[1] 31 | fully_connected_layer_output_size = labels_size[0] 32 | self.model = ModelSelector.load_model(configuration_dictionary['model']) 33 | self.model = self.model(time_steps=configuration_dictionary['time_steps'], 34 | number_of_nodes=number_of_nodes, 35 | 
number_of_node_features=number_of_node_features, 36 | fully_connected_layer_input_size=number_of_nodes * number_of_node_features, 37 | fully_connected_layer_output_size=fully_connected_layer_output_size, 38 | device=self.device) 39 | self.get_logger().info('Loaded the ' + configuration_dictionary['model'] + 40 | ' model. Model weights size: ' + self.model.get_model_size() + ' MB') 41 | self.model.to(self.device) 42 | self.loss_function = self._instantiate_the_loss_function( 43 | LossFunctionSelector.load_loss_function(configuration_dictionary['loss_function'])) 44 | self.get_logger().info('Loss function: ' + configuration_dictionary['loss_function']) 45 | self.optimizer = self._instantiate_the_optimizer( 46 | OptimizerSelector.load_optimizer(configuration_dictionary['optimizer'])) 47 | self.get_logger().info('Optimizer: ' + configuration_dictionary['optimizer']) 48 | 49 | def do_train(self, training_data: DataLoader, epoch: int) -> float: 50 | training_loss = np.average(list(map(self._do_train_batch, training_data))) 51 | self.get_logger().info('[Iteration %d] training loss: %.6f' % (epoch, training_loss)) 52 | return training_loss 53 | 54 | def _do_train_batch(self, training_data: DataLoader) -> float: 55 | node_features, all_neighbors, labels, _ = training_data 56 | node_features, all_neighbors, labels = (node_features.to(self.device), 57 | all_neighbors.to(self.device), 58 | labels.to(self.device)) 59 | current_batch_size = self._get_current_batch_size(labels) 60 | if self.normalize: 61 | node_features = self.preprocessor.normalize(node_features, self.device) 62 | labels = self.preprocessor.normalize(labels, self.device) 63 | self.optimizer.zero_grad() 64 | outputs = self.model(node_features, all_neighbors, batch_size=current_batch_size) 65 | loss = self.loss_function(outputs, labels) 66 | self._do_backpropagate(loss) 67 | return loss.item() 68 | 69 | def do_evaluate(self, evaluation_data: DataLoader, epoch: int = None) -> float: 70 | with to.no_grad(): 71 | 
evaluation_loss = [] 72 | if len(evaluation_data): 73 | for node_features, all_neighbors, labels_validation, _ in evaluation_data: 74 | node_features, all_neighbors, labels_validation = (node_features.to(self.device), 75 | all_neighbors.to(self.device), 76 | labels_validation.to(self.device)) 77 | if self.normalize: 78 | node_features = self.preprocessor.normalize(node_features, self.device) 79 | labels_validation = self.preprocessor.normalize(labels_validation, self.device) 80 | current_batch_size = self._get_current_batch_size(labels_validation) 81 | outputs = self.model(node_features, all_neighbors, current_batch_size) 82 | loss = self.loss_function(outputs, labels_validation) 83 | evaluation_loss.append(float(loss)) 84 | evaluation_loss = np.average(evaluation_loss) 85 | if epoch is not None: 86 | self.get_logger().info('[Iteration %d] validation loss: %.6f' % (epoch, evaluation_loss)) 87 | else: 88 | self.get_logger().info('Test loss: %.6f' % evaluation_loss) 89 | else: 90 | self.get_logger().warning('No evaluation data found!') 91 | return evaluation_loss 92 | 93 | def _do_backpropagate(self, loss: to.Tensor) -> None: 94 | loss.backward() 95 | self.optimizer.step() 96 | 97 | @staticmethod 98 | def _instantiate_the_loss_function(loss_function: Module) -> Module: 99 | return loss_function() 100 | 101 | def _instantiate_the_optimizer(self, optimizer: Any) -> Optimizer: 102 | model_parameters = list(self.model.parameters()) 103 | try: 104 | optimizer = optimizer(model_parameters, lr=0.001, momentum=0.9) 105 | except: 106 | optimizer = optimizer(model_parameters, lr=0.001) 107 | return optimizer 108 | 109 | @staticmethod 110 | def _get_current_batch_size(features: to.Tensor) -> int: 111 | return len(features) 112 | 113 | @staticmethod 114 | def get_logger() -> logging.Logger: 115 | return logging.getLogger('message_passing_nn') 116 | -------------------------------------------------------------------------------- /message_passing_nn/usecase/__init__.py: 
# --------------------------------------------------------------------------------
# Public entry points of the message_passing_nn.usecase package.
from message_passing_nn.usecase.usecase import Usecase
from message_passing_nn.usecase.grid_search import GridSearch
from message_passing_nn.usecase.inference import Inference

# --------------------------------------------------------------------------------
# /message_passing_nn/usecase/grid_search.py:
# --------------------------------------------------------------------------------
import logging
import os
from typing import Dict, List, Tuple

import itertools
import numpy as np
from torch.utils.data.dataloader import DataLoader

from message_passing_nn.data.data_preprocessor import DataPreprocessor
from message_passing_nn.infrastructure.graph_dataset import GraphDataset
from message_passing_nn.model.trainer import Trainer
from message_passing_nn.usecase import Usecase
from message_passing_nn.utils.saver import Saver


class GridSearch(Usecase):
    """Train and evaluate one model per combination of hyper-parameters.

    ``grid_search_dictionary`` maps each hyper-parameter name to the list of
    values to try; every element of the Cartesian product of those lists is
    trained as a separate configuration.
    """

    def __init__(self,
                 data_path: str,
                 data_preprocessor: DataPreprocessor,
                 trainer: Trainer,
                 grid_search_dictionary: Dict,
                 saver: Saver,
                 test_mode: bool = False) -> None:
        self.data_path = data_path
        self.data_preprocessor = data_preprocessor
        self.trainer = trainer
        self.grid_search_dictionary = grid_search_dictionary
        self.saver = saver
        self.test_mode = test_mode

    def start(self) -> Dict:
        """Run every configuration of the grid and return the collected losses.

        Returns:
            ``{'training_loss': {...}, 'validation_loss': {...}, 'test_loss': {...}}``
            where each inner dict is keyed by configuration id. The results are
            also persisted through ``Saver.save_results`` after each configuration.
        """
        losses = {'training_loss': {},
                  'validation_loss': {},
                  'test_loss': {}}
        for configuration in self._get_all_grid_search_configurations():
            configuration_id, configuration_dictionary = self._get_configuration_dictionary(configuration)
            losses = self._search_configuration(configuration_id, configuration_dictionary, losses)
            self.saver.save_results(configuration_id, losses)
        self.get_logger().info('Finished Training')
        return losses

    def _search_configuration(self, configuration_id: str, configuration_dictionary: Dict, losses: Dict) -> Dict:
        """Train a single configuration and record its losses into ``losses``.

        The model is checkpointed via ``Saver.save_model`` only when the
        validation loss improves on the best value seen so far for this
        configuration. After the final epoch the model is evaluated on the
        test split.
        """
        training_data, validation_data, test_data, data_dimensions = self._prepare_dataset(configuration_dictionary)
        self.trainer.instantiate_attributes(data_dimensions, configuration_dictionary)
        losses = self._update_losses_with_configuration_id(configuration_dictionary, losses)
        loss_key = configuration_dictionary["configuration_id"]
        best_validation_loss = np.inf
        self.get_logger().info('Started Training')
        for epoch in range(1, configuration_dictionary['epochs'] + 1):
            training_loss = self.trainer.do_train(training_data, epoch)
            losses['training_loss'][loss_key].update({epoch: training_loss})
            if epoch % configuration_dictionary["validation_period"] == 0:
                validation_loss = self.trainer.do_evaluate(validation_data, epoch)
                losses['validation_loss'][loss_key].update({epoch: validation_loss})
                if isinstance(validation_loss, list):
                    # Trainer.do_evaluate returns an empty list when the validation
                    # split is empty; there is no loss to compare against, so keep
                    # the latest model instead of failing on `list < float`.
                    self.saver.save_model(epoch, configuration_id, self.trainer.model)
                elif validation_loss < best_validation_loss:
                    # Remember the best loss so that only improving epochs are saved.
                    best_validation_loss = validation_loss
                    self.saver.save_model(epoch, configuration_id, self.trainer.model)
        test_loss = self.trainer.do_evaluate(test_data)
        losses['test_loss'][loss_key].update({"final_epoch": test_loss})
        return losses

    @staticmethod
    def _update_losses_with_configuration_id(configuration_dictionary: Dict, losses: Dict) -> Dict:
        """Create empty per-configuration slots in each of the three loss tables."""
        losses['training_loss'].update({configuration_dictionary["configuration_id"]: {}})
        losses['validation_loss'].update({configuration_dictionary["configuration_id"]: {}})
        losses['test_loss'].update({configuration_dictionary["configuration_id"]: {}})
        return losses

    @staticmethod
    def _get_configuration_dictionary(configuration: Tuple[Tuple]) -> Tuple[str, Dict]:
        """Turn ``((key, value), ...)`` pairs into a dict plus an identifying string.

        The id has the form ``configuration&id__key1&value1__key2&value2...``;
        the ``__`` / ``&`` separators are the ones
        ``Loader._get_model_parameters_from_path`` splits on. The id is also
        stored in the returned dict under ``"configuration_id"``.
        """
        configuration_dictionary = dict(configuration)
        configuration_id = 'configuration&id' + ''.join(
            "__" + "&".join([key, str(value)]) for key, value in configuration_dictionary.items())
        configuration_dictionary.update({"configuration_id": configuration_id})
        return configuration_id, configuration_dictionary

    def _prepare_dataset(self, configuration_dictionary: Dict) -> Tuple[DataLoader, DataLoader, DataLoader, Tuple]:
        """Load the graph dataset and split it per the configuration.

        Returns:
            The training, validation and test data loaders, followed by the
            data dimensions reported by the preprocessor.
        """
        # GraphDataset builds the neighbour lists while it loads, so log first.
        self.get_logger().info("Calculating all neighbors for each node")
        dataset = GraphDataset(self.data_path, test_mode=self.test_mode)
        # NOTE(review): loading has already finished here, so this only affects
        # later uses of `test_mode` (if any) — confirm it is still needed.
        dataset.enable_test_mode()
        training_data, validation_data, test_data = self.data_preprocessor \
            .train_validation_test_split(dataset,
                                         configuration_dictionary['batch_size'],
                                         configuration_dictionary['validation_split'],
                                         configuration_dictionary['test_split'])
        data_dimensions = self.data_preprocessor.extract_data_dimensions(dataset)
        return training_data, validation_data, test_data, data_dimensions

    def _get_all_grid_search_configurations(self) -> List[Tuple[Tuple]]:
        """Return the Cartesian product of all hyper-parameter choices.

        Each element is a tuple of ``(key, value)`` pairs, one pair per key of
        ``grid_search_dictionary`` in insertion order.
        """
        choices_per_key = [[(key, value) for value in values]
                           for key, values in self.grid_search_dictionary.items()]
        return list(itertools.product(*choices_per_key))

    @staticmethod
    def get_logger() -> logging.Logger:
        return logging.getLogger('message_passing_nn')

# --------------------------------------------------------------------------------
# /message_passing_nn/usecase/inference.py:
# --------------------------------------------------------------------------------
import logging
from typing import Tuple

from torch.utils.data.dataloader import DataLoader

from message_passing_nn.data import DataPreprocessor
from message_passing_nn.infrastructure.graph_dataset import GraphDataset
from message_passing_nn.model.inferencer import Inferencer
from message_passing_nn.model.loader import Loader
from message_passing_nn.usecase import Usecase
from message_passing_nn.utils import Saver


class Inference(Usecase):
    """Run a previously trained model over a dataset and save its outputs."""

    def __init__(self,
                 data_path: str,
                 data_preprocessor: DataPreprocessor,
                 loader: Loader,
                 inferencer: Inferencer,
                 saver: Saver,
                 test_mode: bool = False) -> None:
        self.data_path = data_path
        self.data_preprocessor = data_preprocessor
        self.loader = loader
        self.inferencer = inferencer
        self.saver = saver
        self.test_mode = test_mode

    def start(self) -> None:
        """Load the model from ``saver.model_directory``, infer, and save the outputs."""
        self.get_logger().info('Started Inference')
        # Inference has no grid configuration, so outputs are saved under an empty id.
        configuration_id = ''
        inference_dataset, data_dimensions = self._prepare_dataset()
        model = self.loader.load_model(data_dimensions, self.saver.model_directory)
        outputs_labels_pairs = self.inferencer.do_inference(model, inference_dataset)
        self.saver.save_distance_maps(configuration_id, outputs_labels_pairs)
        self.get_logger().info('Finished Inference')

    def _prepare_dataset(self) -> Tuple[DataLoader, Tuple]:
        """Load the dataset and return its data loader and data dimensions."""
        dataset = GraphDataset(self.data_path, test_mode=self.test_mode)
        inference_dataset = self.data_preprocessor.get_dataloader(dataset)
        data_dimensions = self.data_preprocessor.extract_data_dimensions(dataset)
        return inference_dataset, data_dimensions

    @staticmethod
    def get_logger() -> logging.Logger:
        return logging.getLogger('message_passing_nn')

# --------------------------------------------------------------------------------
# /message_passing_nn/usecase/usecase.py:
# --------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod


class Usecase(metaclass=ABCMeta):
    """Abstract base for top-level workflows; concrete use cases implement ``start``."""

    @abstractmethod
    def start(self) -> None:
        pass

# --------------------------------------------------------------------------------
# /message_passing_nn/utils/__init__.py:
# --------------------------------------------------------------------------------
1 | from message_passing_nn.utils.grid_search_parameters_parser import GridSearchParametersParser 2 | from message_passing_nn.utils.loss_function_selector import LossFunctionSelector 3 | from message_passing_nn.utils.optimizer_selector import OptimizerSelector 4 | from message_passing_nn.utils.model_selector import ModelSelector 5 | from message_passing_nn.utils.saver import Saver 6 | -------------------------------------------------------------------------------- /message_passing_nn/utils/derivatives.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../utils/derivatives.h" 3 | 4 | at::Tensor d_sigmoid(at::Tensor z) { 5 | auto s = at::sigmoid(z); 6 | return (1 - s) * s; 7 | } 8 | 9 | at::Tensor d_relu_2d(at::Tensor z) { 10 | auto output = at::zeros_like(z); 11 | for (int i = 0; i() > 0.0) { 14 | output[i][j] = 1; 15 | } 16 | } 17 | } 18 | return output; 19 | } 20 | 21 | at::Tensor d_relu_4d(at::Tensor z) { 22 | auto output = at::zeros_like(z); 23 | for (int i = 0; i() > 0.0) { 28 | output[i][j][k][l] = 1; 29 | } 30 | } 31 | } 32 | } 33 | } 34 | return output; 35 | } 36 | -------------------------------------------------------------------------------- /message_passing_nn/utils/derivatives.h: -------------------------------------------------------------------------------- 1 | #ifndef DERIVATIVES_H 2 | #define DERIVATIVES_H 3 | 4 | #include 5 | 6 | at::Tensor d_sigmoid(at::Tensor z); 7 | at::Tensor d_relu_2d(at::Tensor z); 8 | at::Tensor d_relu_4d(at::Tensor z); 9 | 10 | #endif -------------------------------------------------------------------------------- /message_passing_nn/utils/grid_search_parameters_parser.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict 3 | 4 | import numpy as np 5 | 6 | from message_passing_nn.fixtures.characters import GRID_SEARCH_SEPARATION_CHARACTER 7 | 8 | 9 | class 
GridSearchParametersParser: 10 | def __init__(self) -> None: 11 | pass 12 | 13 | def get_grid_search_dictionary(self, 14 | model_selection: str, 15 | epochs: str, 16 | loss_function_selection: str, 17 | optimizer_selection: str, 18 | batch_size: str, 19 | validation_split: str, 20 | test_split: str, 21 | time_steps: str, 22 | validation_period: str) -> Dict: 23 | return { 24 | 'model': self._parse_string_selections(model_selection), 25 | 'epochs': self._parse_integer_range(epochs), 26 | 'loss_function': self._parse_string_selections(loss_function_selection), 27 | 'optimizer': self._parse_string_selections(optimizer_selection), 28 | 'batch_size': self._parse_integer_range(batch_size), 29 | 'validation_split': self._parse_float_range(validation_split), 30 | 'test_split': self._parse_float_range(test_split), 31 | 'time_steps': self._parse_integer_range(time_steps), 32 | 'validation_period': self._parse_integer_range(validation_period) 33 | } 34 | 35 | def _parse_integer_range(self, field: str) -> List[int]: 36 | integer_range = field.split(GRID_SEARCH_SEPARATION_CHARACTER) 37 | if len(integer_range) == 1: 38 | return [int(integer_range[0])] 39 | elif len(integer_range) == 3: 40 | min_range, max_range, number_of_values = integer_range 41 | integer_range = np.linspace(int(min_range), int(max_range), int(number_of_values)) 42 | return [int(number) for number in integer_range] 43 | else: 44 | self.get_logger().info("Incorrect values for integer range. 
Please either provide " 45 | "a single integer or three integers separated by & (min&max&values)") 46 | raise Exception 47 | 48 | def _parse_float_range(self, field: str) -> List[float]: 49 | float_range = field.split(GRID_SEARCH_SEPARATION_CHARACTER) 50 | if len(float_range) == 1: 51 | return [float(float_range[0])] 52 | elif len(float_range) == 3: 53 | min_range, max_range, number_of_values = float_range 54 | float_range = np.linspace(float(min_range), float(max_range), int(number_of_values)) 55 | return [float(number) for number in float_range] 56 | else: 57 | self.get_logger().info("Incorrect values for float range. Please either provide " 58 | "a single float or two floats and an integer separated by & (min&max&values)") 59 | raise Exception 60 | 61 | @staticmethod 62 | def _parse_string_selections(string_selection: str) -> List[str]: 63 | return string_selection.split(GRID_SEARCH_SEPARATION_CHARACTER) 64 | 65 | @staticmethod 66 | def get_logger() -> logging.Logger: 67 | return logging.getLogger('message_passing_nn') 68 | -------------------------------------------------------------------------------- /message_passing_nn/utils/loss_function_selector.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.nn.modules.module import Module 4 | 5 | from message_passing_nn.utils.loss_functions import loss_functions 6 | 7 | 8 | class LossFunctionSelector: 9 | def __init__(self) -> None: 10 | pass 11 | 12 | @staticmethod 13 | def load_loss_function(loss_function_selection: str) -> Module: 14 | if loss_function_selection in loss_functions: 15 | return loss_functions[loss_function_selection] 16 | else: 17 | get_logger().info("The " + loss_function_selection + " is not available") 18 | raise Exception 19 | 20 | 21 | def get_logger() -> logging.Logger: 22 | return logging.getLogger('message_passing_nn') 23 | -------------------------------------------------------------------------------- 
/message_passing_nn/utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | loss_functions = { 4 | "MSE": nn.MSELoss, 5 | "L1": nn.L1Loss, 6 | "CrossEntropy": nn.CrossEntropyLoss, 7 | "CTC": nn.CTCLoss, 8 | "NLL": nn.NLLLoss, 9 | "PoissonNLL": nn.PoissonNLLLoss, 10 | "KLDiv": nn.KLDivLoss, 11 | "BCE": nn.BCELoss, 12 | "BCEWithLogits": nn.BCEWithLogitsLoss, 13 | "MarginRanking": nn.MarginRankingLoss, 14 | "HingeEmbedding": nn.HingeEmbeddingLoss, 15 | "MultiLabelMargin": nn.MultiLabelMarginLoss, 16 | "SmoothL1": nn.SmoothL1Loss, 17 | "SoftMargin": nn.SoftMarginLoss, 18 | "MultiLabelSoftMargin": nn.MultiLabelSoftMarginLoss, 19 | "CosineEmbedding": nn.CosineEmbeddingLoss, 20 | "MultiMargin": nn.MultiMarginLoss, 21 | "TripletMargin": nn.TripletMarginLoss 22 | } 23 | -------------------------------------------------------------------------------- /message_passing_nn/utils/messages.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../utils/messages.h" 3 | 4 | std::vector compose_messages( 5 | const int& time_steps, 6 | const int& number_of_nodes, 7 | const int& number_of_node_features, 8 | const at::Tensor& w_graph_node_features, 9 | const at::Tensor& w_graph_neighbor_messages, 10 | const at::Tensor& base_messages, 11 | const at::Tensor& all_neighbors, 12 | const at::Tensor& messages_init) { 13 | 14 | auto new_messages = at::zeros_like({messages_init}); 15 | auto previous_messages = at::zeros_like({messages_init}); 16 | auto new_messages_of_node = at::zeros({previous_messages.sizes()[1], previous_messages.sizes()[2]}); 17 | for (int time_step = 0; time_step(); 23 | if (end_node_id >= 0) { 24 | new_messages[node_id][end_node_id] += base_messages[node_id]; 25 | for (int neighbor_index = 0; neighbor_index < all_neighbors.size(1); neighbor_index++) { 26 | auto neighbor = all_neighbors[node_id][neighbor_index].item(); 27 | if 
(neighbor >= 0 && neighbor_index!=end_node_index) { 28 | new_messages[node_id][end_node_id] += base_neighbor_messages[neighbor][node_id]; 29 | } 30 | } 31 | } 32 | } 33 | } 34 | } 35 | return {new_messages, previous_messages}; 36 | } 37 | 38 | at::Tensor encode_messages( 39 | const int& number_of_nodes, 40 | const at::Tensor& node_encoding_messages, 41 | const at::Tensor& u_graph_node_features, 42 | const at::Tensor& u_graph_neighbor_messages, 43 | const at::Tensor& node_features, 44 | const at::Tensor& all_neighbors, 45 | const at::Tensor& messages) { 46 | 47 | for (int node_id = 0; node_id(); 50 | if (end_node_id >= 0) { 51 | node_encoding_messages[node_id] += at::matmul(u_graph_neighbor_messages, at::relu(messages[end_node_id][node_id])); 52 | } 53 | } 54 | } 55 | return at::relu(at::add(at::matmul(u_graph_node_features, node_features), node_encoding_messages)); 56 | } 57 | -------------------------------------------------------------------------------- /message_passing_nn/utils/messages.h: -------------------------------------------------------------------------------- 1 | #ifndef MESSAGES_H 2 | #define MESSAGES_H 3 | 4 | #include 5 | 6 | at::Tensor get_messages_to_all_end_nodes(const int& node_id, 7 | const at::Tensor& w_graph_neighbor_messages, 8 | const at::Tensor& w_graph_node_features, 9 | const at::Tensor& all_neighbors, 10 | const at::Tensor& features_of_specific_node, 11 | at::Tensor& messages_previous_step); 12 | 13 | std::vector compose_messages( 14 | const int& time_steps, 15 | const int& number_of_true_nodes, 16 | const int& number_of_node_features, 17 | const at::Tensor& w_graph_node_features, 18 | const at::Tensor& w_graph_neighbor_messages, 19 | const at::Tensor& base_messages, 20 | const at::Tensor& all_neighbors, 21 | const at::Tensor& messages_init); 22 | 23 | at::Tensor encode_messages( 24 | const int& number_of_true_nodes, 25 | const at::Tensor& node_encoding_messages, 26 | const at::Tensor& u_graph_node_features, 27 | const at::Tensor& 
u_graph_neighbor_messages, 28 | const at::Tensor& node_features, 29 | const at::Tensor& all_neighbors, 30 | const at::Tensor& messages); 31 | 32 | #endif -------------------------------------------------------------------------------- /message_passing_nn/utils/model_selector.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch import nn 4 | 5 | from message_passing_nn.utils.models import models 6 | 7 | 8 | class ModelSelector: 9 | def __init__(self) -> None: 10 | pass 11 | 12 | @staticmethod 13 | def load_model(model_selection: str) -> nn.Module: 14 | if model_selection in models: 15 | return models[model_selection] 16 | else: 17 | get_logger().info("The " + model_selection + " model is not available") 18 | raise Exception 19 | 20 | 21 | def get_logger() -> logging.Logger: 22 | return logging.getLogger('message_passing_nn') 23 | -------------------------------------------------------------------------------- /message_passing_nn/utils/models.py: -------------------------------------------------------------------------------- 1 | from message_passing_nn.graph import rnn_encoder 2 | 3 | models = { 4 | "RNN": rnn_encoder.RNNEncoder, 5 | } 6 | -------------------------------------------------------------------------------- /message_passing_nn/utils/optimizer_selector.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim.optimizer import Optimizer 4 | 5 | from message_passing_nn.utils.optimizers import optimizers 6 | 7 | 8 | class OptimizerSelector: 9 | def __init__(self) -> None: 10 | pass 11 | 12 | @staticmethod 13 | def load_optimizer(optimizer_selection: str) -> Optimizer: 14 | if optimizer_selection in optimizers: 15 | return optimizers[optimizer_selection] 16 | else: 17 | get_logger().info("The " + optimizer_selection + " is not available") 18 | raise Exception 19 | 20 | 21 | def get_logger() -> logging.Logger: 22 | return 
logging.getLogger('message_passing_nn') 23 | -------------------------------------------------------------------------------- /message_passing_nn/utils/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch as to 2 | 3 | optimizers = { 4 | "Adadelta": to.optim.Adadelta, 5 | "Adagrad": to.optim.Adagrad, 6 | "Adam": to.optim.Adam, 7 | "AdamW": to.optim.AdamW, 8 | "SparseAdam": to.optim.SparseAdam, 9 | "Adamax": to.optim.Adamax, 10 | "ASGD": to.optim.ASGD, 11 | "LBFGS": to.optim.LBFGS, 12 | "RMSprop": to.optim.RMSprop, 13 | "Rprop": to.optim.Rprop, 14 | "SGD": to.optim.SGD, 15 | } 16 | -------------------------------------------------------------------------------- /message_passing_nn/utils/saver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | from datetime import datetime 5 | from typing import Dict, List, Tuple 6 | 7 | import torch as to 8 | from pandas import pandas as pd 9 | 10 | from message_passing_nn.fixtures.filenames import * 11 | 12 | 13 | class Saver: 14 | def __init__(self, model_directory: str, results_directory: str) -> None: 15 | self.model_directory = model_directory 16 | self.results_directory = results_directory 17 | 18 | def save_model(self, epoch: int, configuration_id: str, model: to.nn.Module) -> None: 19 | current_folder = self._join_path([self.model_directory, configuration_id]) 20 | if not os.path.exists(current_folder): 21 | os.makedirs(current_folder) 22 | path_and_filename = self._join_path([current_folder, self._join_strings([EPOCH, 23 | str(epoch), 24 | MODEL_STATE_DICTIONARY])]) 25 | to.save(model.state_dict(), path_and_filename) 26 | self.get_logger().info("Saved model checkpoint in " + path_and_filename) 27 | 28 | def save_results(self, configuration_id: str, results: Dict) -> None: 29 | current_folder = self._join_path([self.results_directory, configuration_id]) 30 | if not 
os.path.exists(current_folder): 31 | os.makedirs(current_folder) 32 | results_dataframe = self._construct_dataframe_from_nested_dictionary(results) 33 | path_and_filename = self._join_path([current_folder, 34 | self._join_strings([datetime.now().strftime("%d-%b-%YT%H_%M"), 35 | RESULTS_CSV])]) 36 | results_dataframe.to_csv(path_and_filename) 37 | self.get_logger().info("Saved results in " + path_and_filename) 38 | 39 | def save_distance_maps(self, configuration_id: str, distance_maps: List[Tuple]): 40 | current_folder = self._join_path([self.results_directory, configuration_id]) 41 | if not os.path.exists(current_folder): 42 | os.makedirs(current_folder) 43 | path_and_filename = self._join_path([current_folder, 44 | self._join_strings([datetime.now().strftime("%d-%b-%YT%H_%M"), 45 | DISTANCE_MAPS])]) 46 | with open(path_and_filename, 'wb') as file: 47 | pickle.dump(distance_maps, file) 48 | self.get_logger().info("Saved inference outputs in " + path_and_filename) 49 | 50 | @staticmethod 51 | def _join_strings(fields: List) -> str: 52 | return "_".join(fields) 53 | 54 | @staticmethod 55 | def _construct_dataframe_from_nested_dictionary(results: Dict) -> pd.DataFrame: 56 | results_dataframe = pd.DataFrame.from_dict({(i, j): results[i][j] 57 | for i in results.keys() 58 | for j in results[i].keys()}, 59 | orient='index') 60 | return results_dataframe 61 | 62 | @staticmethod 63 | def _join_path(fields: List) -> str: 64 | return "/".join(fields) 65 | 66 | @staticmethod 67 | def get_logger() -> logging.Logger: 68 | return logging.getLogger('message_passing_nn') 69 | -------------------------------------------------------------------------------- /parameters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/parameters/__init__.py -------------------------------------------------------------------------------- 
/parameters/grid-search-parameters.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export DATASET_NAME='sample-dataset' 3 | export DATA_DIRECTORY='data/' 4 | export MODEL_DIRECTORY='model_checkpoints' 5 | export RESULTS_DIRECTORY='results_grid_search' 6 | export MODEL='RNN' 7 | export DEVICE='cpu' 8 | export EPOCHS='10' 9 | export LOSS_FUNCTION='MSE' 10 | export OPTIMIZER='Adagrad' 11 | export BATCH_SIZE='100' 12 | export VALIDATION_SPLIT='0.2' 13 | export TEST_SPLIT='0.1' 14 | export TIME_STEPS='1' 15 | export VALIDATION_PERIOD='5' 16 | -------------------------------------------------------------------------------- /parameters/inference-parameters.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export DATASET_NAME='direct_neighbour' 3 | export DATA_DIRECTORY='data/the_100/' 4 | export MODEL_DIRECTORY='model_checkpoints/configuration&id__model&RNN__epochs&1000__loss_function&MSE__optimizer&Adagrad__batch_size&100__validation_split&0.2__test_split&0.1__time_steps&1__validation_period&5/Epoch_625_model_state_dictionary.pth' 5 | export RESULTS_DIRECTORY='results_inference' 6 | export MODEL='RNN' 7 | export DEVICE='cpu' 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click 2 | numpy==1.22.0 3 | pandas==1.0.3 4 | torch==1.13.1 5 | tqdm -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages, Extension 2 | from 
torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 3 | 4 | setup( 5 | name='message-passing-nn', 6 | version='1.6.0', 7 | packages=find_packages(exclude=["tests"]), 8 | url='https://github.com/kovanostra/message-passing-nn', 9 | download_url='https://github.com/kovanostra/message-passing-nn/archive/1.6.0.tar.gz', 10 | keywords=['MESSAGE PASSING', 'NEURAL NETWORK', 'RNN', 'GRU'], 11 | license='MIT', 12 | author='Michail Kovanis', 13 | description='A message passing neural network with RNN or GRU units', 14 | install_requires=[ 15 | 'click', 16 | 'numpy==1.22.0', 17 | 'pandas==1.0.3', 18 | 'torch==1.13.1', 19 | 'tqdm' 20 | ], 21 | entry_points={ 22 | 'console_scripts': [ 23 | 'message-passing-nn = message_passing_nn.cli:main' 24 | ], 25 | }, 26 | ext_modules=[ 27 | CppExtension('rnn_encoder_cpp', 28 | sources=['message_passing_nn/graph/rnn_encoder.cpp', 29 | 'message_passing_nn/utils/messages.cpp', 30 | 'message_passing_nn/utils/derivatives.cpp']), 31 | CUDAExtension('rnn_encoder_cuda_cpp', 32 | sources=['message_passing_nn/graph/rnn_encoder_cuda.cpp', 33 | 'message_passing_nn/graph/rnn_encoder_cuda_kernel.cu', 34 | 'message_passing_nn/utils/derivatives.cpp'])], 35 | cmdclass={'build_ext': BuildExtension}, 36 | classifiers=[ 37 | 'Development Status :: 5 - Production/Stable', 38 | 'Intended Audience :: Science/Research', 39 | 'Intended Audience :: Developers', 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'Topic :: Software Development :: Libraries :: Python Modules', 42 | 'License :: OSI Approved :: MIT License', 43 | 'Programming Language :: Python :: 3.7', 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /setup_cpu.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages, Extension 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | setup( 5 | 
name='message-passing-nn', 6 | version='1.6.0', 7 | packages=find_packages(exclude=["tests"]), 8 | url='https://github.com/kovanostra/message-passing-nn', 9 | download_url='https://github.com/kovanostra/message-passing-nn/archive/1.6.0.tar.gz', 10 | keywords=['MESSAGE PASSING', 'NEURAL NETWORK', 'RNN', 'GRU'], 11 | license='MIT', 12 | author='Michail Kovanis', 13 | description='A message passing neural network with RNN or GRU units', 14 | install_requires=[ 15 | 'click', 16 | 'numpy==1.17.4', 17 | 'pandas==1.0.3', 18 | 'torch==1.5.0', 19 | 'tqdm' 20 | ], 21 | entry_points={ 22 | 'console_scripts': [ 23 | 'message-passing-nn = message_passing_nn.cli:main' 24 | ], 25 | }, 26 | ext_modules=[ 27 | CppExtension('rnn_encoder_cpp', 28 | sources=['message_passing_nn/graph/rnn_encoder.cpp', 29 | 'message_passing_nn/utils/messages.cpp', 30 | 'message_passing_nn/utils/derivatives.cpp'])], 31 | cmdclass={'build_ext': BuildExtension}, 32 | classifiers=[ 33 | 'Development Status :: 5 - Production/Stable', 34 | 'Intended Audience :: Science/Research', 35 | 'Intended Audience :: Developers', 36 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 37 | 'Topic :: Software Development :: Libraries :: Python Modules', 38 | 'License :: OSI Approved :: MIT License', 39 | 'Programming Language :: Python :: 3.7', 40 | ], 41 | ) 42 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/data/__init__.py 
-------------------------------------------------------------------------------- /tests/data/test_data_preprocessor.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch as to 4 | from message_passing_nn.infrastructure.graph_dataset import GraphDataset 5 | 6 | from message_passing_nn.data.data_preprocessor import DataPreprocessor 7 | from tests.fixtures.matrices_and_vectors import BASE_GRAPH, BASE_GRAPH_NODE_FEATURES 8 | 9 | 10 | class TestGraphPreprocessor(TestCase): 11 | def setUp(self) -> None: 12 | self.data_preprocessor = DataPreprocessor() 13 | 14 | def test_train_validation_test_split(self): 15 | # Given 16 | dataset_length = 10 17 | features = BASE_GRAPH_NODE_FEATURES 18 | all_neighbors = to.tensor([[1, 2, -1, -1], 19 | [0, 2, -1, -1], 20 | [0, 1, 3, -1], 21 | [2, -1, -1, -1]]) 22 | labels = BASE_GRAPH.view(-1) 23 | dataset = GraphDataset("") 24 | dataset.enable_test_mode() 25 | dataset.dataset = [(features, all_neighbors, labels, i) for i in range(dataset_length)] 26 | train_validation_test_split_expected = [7, 2, 1] 27 | 28 | # When 29 | train_validation_test_split = self.data_preprocessor.train_validation_test_split(dataset, 30 | batch_size=1, 31 | validation_split=0.2, 32 | test_split=0.1) 33 | train_validation_test_split = [len(dataset) for dataset in train_validation_test_split] 34 | 35 | # Then 36 | self.assertEqual(train_validation_test_split_expected, train_validation_test_split) 37 | 38 | def test_extract_data_dimensions(self): 39 | # Given 40 | dataset_length = 1 41 | features = BASE_GRAPH_NODE_FEATURES 42 | all_neighbors = to.tensor([[1, 2, -1, -1], 43 | [0, 2, -1, -1], 44 | [0, 1, 3, -1], 45 | [2, -1, -1, -1]]) 46 | labels = BASE_GRAPH.view(-1) 47 | dataset = GraphDataset("") 48 | dataset.enable_test_mode() 49 | dataset.dataset = [(features, all_neighbors, labels, i) for i in range(dataset_length)] 50 | data_dimensions_expected = (features.size(), 
labels.size()) 51 | 52 | # When 53 | data_dimensions = self.data_preprocessor.extract_data_dimensions(dataset) 54 | 55 | # Then 56 | self.assertEqual(data_dimensions_expected, data_dimensions) 57 | 58 | def test_flatten_when_sizes_match(self): 59 | # Given 60 | dataset_length = 2 61 | labels = BASE_GRAPH.view(-1) 62 | tensors = to.cat((labels, labels)) 63 | tensors_flattened_expected = tensors.view(-1) 64 | 65 | # When 66 | tensors_flattened = self.data_preprocessor.flatten(tensors, desired_size=dataset_length * len(labels)) 67 | 68 | # Then 69 | self.assertTrue(to.allclose(tensors_flattened_expected, tensors_flattened)) 70 | 71 | def test_flatten_when_sizes_do_not_match(self): 72 | # Given 73 | dataset_length = 3 74 | labels = BASE_GRAPH.view(-1) 75 | tensors = to.cat((labels, labels)) 76 | tensors_flattened_expected = to.cat((tensors.view(-1), to.zeros_like(labels))) 77 | 78 | # When 79 | tensors_flattened = self.data_preprocessor.flatten(tensors, desired_size=dataset_length * len(labels)) 80 | 81 | # Then 82 | self.assertTrue(to.allclose(tensors_flattened_expected, tensors_flattened)) 83 | -------------------------------------------------------------------------------- /tests/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/fixtures/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/loss_functions.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | loss_functions = { 4 | "MSE": nn.MSELoss(), 5 | "L1": nn.L1Loss(), 6 | "CrossEntropy": nn.CrossEntropyLoss(), 7 | "CTC": nn.CTCLoss(), 8 | "NLL": nn.NLLLoss(), 9 | "PoissonNLL": nn.PoissonNLLLoss(), 10 | "KLDiv": nn.KLDivLoss(), 11 | "BCE": nn.BCELoss(), 12 | "BCEWithLogits": nn.BCEWithLogitsLoss(), 13 | "MarginRanking": 
nn.MarginRankingLoss(), 14 | "HingeEmbedding": nn.HingeEmbeddingLoss(), 15 | "MultiLabelMargin": nn.MultiLabelMarginLoss(), 16 | "SmoothL1": nn.SmoothL1Loss(), 17 | "SoftMargin": nn.SoftMarginLoss(), 18 | "MultiLabelSoftMargin": nn.MultiLabelSoftMarginLoss(), 19 | "CosineEmbedding": nn.CosineEmbeddingLoss(), 20 | "MultiMargin": nn.MultiMarginLoss(), 21 | "TripletMargin": nn.TripletMarginLoss() 22 | } 23 | -------------------------------------------------------------------------------- /tests/fixtures/matrices_and_vectors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as to 3 | 4 | BASE_GRAPH = to.tensor([[0, 1, 1, 0], 5 | [1, 0, 1, 0], 6 | [1, 1, 0, 1], 7 | [0, 0, 1, 0]]).float() 8 | BASE_GRAPH_NODE_FEATURES = to.tensor([[1, 2], [1, 1], [2, 0.5], [0.5, 0.5]]).float() 9 | BASE_UNITY_MATRIX = np.ones((BASE_GRAPH_NODE_FEATURES.shape[1], BASE_GRAPH_NODE_FEATURES.shape[1])) 10 | BASE_UNITY_MATRIX_TENSOR = to.tensor(BASE_UNITY_MATRIX).float() 11 | BASE_UNITY_VECTOR = np.ones((BASE_GRAPH_NODE_FEATURES.shape[1])) 12 | BASE_ZEROS_MATRIX = np.zeros((BASE_GRAPH_NODE_FEATURES.shape[1], BASE_GRAPH_NODE_FEATURES.shape[1])) 13 | 14 | BASE_GRAPH_EDGE_FEATURES = to.tensor([[[0.0, 0.0], [1.0, 2.0], [2.0, 0.5], [0.0, 0.0]], 15 | [[1.0, 2.0], [0.0, 0.0], [1.0, 1.0], [0.0, 0.0]], 16 | [[2.0, 0.5], [1.0, 1.0], [0.0, 0.0], [0.5, 0.5]], 17 | [[0.0, 0.0], [0.0, 0.0], [0.5, 0.5], [0.0, 0.0]]]).float() 18 | BASE_W_MATRIX = to.tensor([[BASE_ZEROS_MATRIX, BASE_UNITY_MATRIX, BASE_UNITY_MATRIX, BASE_ZEROS_MATRIX], 19 | [BASE_UNITY_MATRIX, BASE_ZEROS_MATRIX, BASE_UNITY_MATRIX, BASE_ZEROS_MATRIX], 20 | [BASE_UNITY_MATRIX, BASE_UNITY_MATRIX, BASE_ZEROS_MATRIX, BASE_UNITY_MATRIX], 21 | [BASE_ZEROS_MATRIX, BASE_ZEROS_MATRIX, BASE_UNITY_MATRIX, BASE_ZEROS_MATRIX]]).float() 22 | BASE_U_MATRIX = to.tensor([BASE_UNITY_MATRIX, 23 | BASE_UNITY_MATRIX, 24 | BASE_UNITY_MATRIX, 25 | BASE_UNITY_MATRIX]).float() 26 | 27 | 
BASE_B_VECTOR = to.tensor(BASE_UNITY_VECTOR).float() 28 | 29 | MULTIPLICATION_FACTOR = 0.1 30 | -------------------------------------------------------------------------------- /tests/fixtures/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | optimizers = { 4 | "Adadelta": torch.optim.Adadelta, 5 | "Adagrad": torch.optim.Adagrad, 6 | "Adam": torch.optim.Adam, 7 | "AdamW": torch.optim.AdamW, 8 | "SparseAdam": torch.optim.SparseAdam, 9 | "Adamax": torch.optim.Adamax, 10 | "ASGD": torch.optim.ASGD, 11 | "LBFGS": torch.optim.LBFGS, 12 | "RMSprop": torch.optim.RMSprop, 13 | "Rprop": torch.optim.Rprop, 14 | "SGD": torch.optim.SGD, 15 | } 16 | -------------------------------------------------------------------------------- /tests/graph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/graph/__init__.py -------------------------------------------------------------------------------- /tests/graph/test_rnn_encoder.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch as to 4 | from torch import nn 5 | import numpy as np 6 | 7 | from message_passing_nn.graph.rnn_encoder import RNNEncoder 8 | from tests.fixtures.matrices_and_vectors import BASE_GRAPH, BASE_GRAPH_NODE_FEATURES, \ 9 | MULTIPLICATION_FACTOR 10 | import rnn_encoder_cpp as rnn_cpp 11 | 12 | 13 | class TestRNNEncoder(TestCase): 14 | def setUp(self) -> None: 15 | self.number_of_nodes = BASE_GRAPH.size()[0] 16 | self.number_of_node_features = BASE_GRAPH_NODE_FEATURES.size()[1] 17 | self.fully_connected_layer_input_size = self.number_of_nodes * self.number_of_node_features 18 | self.fully_connected_layer_output_size = self.number_of_nodes ** 2 19 | self.device = "cpu" 20 | self.time_steps = 2 21 | 
self.graph_encoder = RNNEncoder(time_steps=self.time_steps, 22 | number_of_nodes=self.number_of_nodes, 23 | number_of_node_features=self.number_of_node_features, 24 | fully_connected_layer_input_size=self.fully_connected_layer_input_size, 25 | fully_connected_layer_output_size=self.fully_connected_layer_output_size) 26 | self.graph_encoder.w_graph_node_features = nn.Parameter( 27 | MULTIPLICATION_FACTOR * (to.ones((self.number_of_nodes, self.number_of_nodes))), 28 | requires_grad=False) 29 | self.graph_encoder.w_graph_neighbor_messages = nn.Parameter( 30 | MULTIPLICATION_FACTOR * to.ones((self.number_of_nodes, self.number_of_nodes)), 31 | requires_grad=False) 32 | self.graph_encoder.u_graph_node_features = nn.Parameter( 33 | MULTIPLICATION_FACTOR * to.ones((self.number_of_nodes, self.number_of_nodes)), 34 | requires_grad=False) 35 | self.graph_encoder.u_graph_neighbor_messages = nn.Parameter( 36 | MULTIPLICATION_FACTOR * to.ones((self.number_of_node_features, self.number_of_node_features)), 37 | requires_grad=False) 38 | self.graph_encoder.linear_weight = to.nn.Parameter( 39 | MULTIPLICATION_FACTOR * to.ones(self.fully_connected_layer_output_size, 40 | self.fully_connected_layer_input_size), 41 | requires_grad=False).float() 42 | self.graph_encoder.linear_bias = to.nn.Parameter( 43 | MULTIPLICATION_FACTOR * to.tensor([i for i in range(self.fully_connected_layer_output_size)]), 44 | requires_grad=False).float() 45 | 46 | def test_encode_graph_returns_the_expected_encoding_for_a_node_after_one_time_step(self): 47 | # Give 48 | node = 0 49 | time_steps = 1 50 | batch_size = 1 51 | all_neighbors = to.tensor([[[1, 2, -1, -1], 52 | [0, 2, -1, -1], 53 | [0, 1, 3, -1], 54 | [2, -1, -1, -1]]]) 55 | node_encoding_expected = to.tensor([0.6200, 0.5700]) 56 | 57 | # When 58 | _, _, encodings, _, _ = rnn_cpp.forward(to.tensor(time_steps), 59 | to.tensor(self.number_of_nodes), 60 | to.tensor(self.number_of_node_features), 61 | to.tensor(self.fully_connected_layer_output_size), 62 
| to.tensor(batch_size), 63 | BASE_GRAPH_NODE_FEATURES.unsqueeze(0), 64 | all_neighbors, 65 | self.graph_encoder.w_graph_node_features, 66 | self.graph_encoder.w_graph_neighbor_messages, 67 | self.graph_encoder.u_graph_node_features, 68 | self.graph_encoder.u_graph_neighbor_messages, 69 | self.graph_encoder.linear_weight, 70 | self.graph_encoder.linear_bias) 71 | node_encodings = encodings[batch_size - 1].view(self.number_of_nodes, self.number_of_node_features)[node] 72 | 73 | # Then 74 | self.assertTrue(to.allclose(node_encoding_expected, node_encodings)) 75 | 76 | def test_encode_graph_returns_the_expected_shape(self): 77 | # Given 78 | time_steps = 1 79 | batch_size = 1 80 | all_neighbors = to.tensor([[[1, 2, -1, -1], 81 | [0, 2, -1, -1], 82 | [0, 1, 3, -1], 83 | [2, -1, -1, -1]]]) 84 | encoded_graph_shape_expected = list(BASE_GRAPH_NODE_FEATURES.view(1, -1).shape) 85 | 86 | # When 87 | _, _, encodings, _, _ = rnn_cpp.forward(to.tensor(time_steps), 88 | to.tensor(self.number_of_nodes), 89 | to.tensor(self.number_of_node_features), 90 | to.tensor(self.fully_connected_layer_output_size), 91 | to.tensor(batch_size), 92 | BASE_GRAPH_NODE_FEATURES.unsqueeze(0), 93 | all_neighbors, 94 | self.graph_encoder.w_graph_node_features, 95 | self.graph_encoder.w_graph_neighbor_messages, 96 | self.graph_encoder.u_graph_node_features, 97 | self.graph_encoder.u_graph_neighbor_messages, 98 | self.graph_encoder.linear_weight, 99 | self.graph_encoder.linear_bias) 100 | 101 | # Then 102 | self.assertEqual(encoded_graph_shape_expected, list(encodings.shape)) 103 | 104 | def test_get_the_expected_messages_from_a_node_after_one_time_step(self): 105 | time_steps = 1 106 | messages_initial = to.zeros((self.number_of_nodes, 107 | self.number_of_nodes, 108 | self.number_of_node_features), 109 | device=self.device) 110 | node_expected = 0 111 | all_neighbors = to.tensor([[1, 2, -1, -1], 112 | [0, 2, -1, -1], 113 | [0, 1, 3, -1], 114 | [2, -1, -1, -1]]) 115 | messages_from_node_expected = 
to.tensor([[0.00, 0.00], 116 | [0.45, 0.40], 117 | [0.45, 0.40], 118 | [0.00, 0.00]]) 119 | base_messages = to.matmul(self.graph_encoder.w_graph_node_features, BASE_GRAPH_NODE_FEATURES) 120 | # When 121 | messages_from_node, _ = rnn_cpp.compose_messages(time_steps, 122 | self.graph_encoder.number_of_nodes, 123 | self.graph_encoder.number_of_node_features, 124 | self.graph_encoder.w_graph_node_features, 125 | self.graph_encoder.w_graph_neighbor_messages, 126 | base_messages, 127 | all_neighbors, 128 | messages_initial) 129 | 130 | # Then 131 | self.assertTrue(to.allclose(messages_from_node_expected, messages_from_node[node_expected])) 132 | 133 | def test_forward_for_batch_size_one_and_two_steps(self): 134 | # Given 135 | batch_size = 1 136 | number_of_nodes = 4 137 | number_of_node_features = 2 138 | node_features = to.tensor([[[1.0, 2.0], 139 | [1.0, 1.0], 140 | [2.0, 0.5], 141 | [0.5, 0.5]]]) 142 | all_neighbors_input = to.tensor([[[1, 2, -1, -1], 143 | [0, 2, -1, -1], 144 | [0, 1, 3, -1], 145 | [2, -1, -1, -1]]]) 146 | 147 | # Calculations 148 | # -> Pre-loop 149 | batch = 0 150 | all_neighbors = [to.tensor([1, 2]), to.tensor([0, 2]), to.tensor([0, 1, 3]), to.tensor([2])] 151 | neighbors_slice = [[[2], [1]], [[2], [0]], [[1, 3], [0, 3], [0, 1], [1, 3], [0, 3], [0, 1]]] 152 | messages_init = to.zeros((self.number_of_nodes, self.number_of_nodes, self.number_of_node_features), 153 | device=self.device) 154 | 155 | # -> Step 0 156 | # Initialization 157 | messages_step_0_part_1 = to.zeros((self.number_of_nodes, self.number_of_nodes, self.number_of_node_features), 158 | device=self.device) 159 | messages_step_0_part_2 = to.zeros((self.number_of_nodes, self.number_of_nodes, self.number_of_node_features), 160 | device=self.device) 161 | 162 | # -> Step 0 163 | # Calculations 164 | index_pairs_expected = [[0, 2], [0, 1], [1, 2], [1, 0], [2, 1], [2, 3], [2, 0], [2, 3], [2, 0], [2, 1]] 165 | index_pairs, messages_step_0_part_1, messages_step_0_part_2 = 
self._message_calculations(all_neighbors, 166 | batch, 167 | messages_init, 168 | messages_step_0_part_1, 169 | messages_step_0_part_2, 170 | neighbors_slice, 171 | node_features, 172 | number_of_node_features, 173 | number_of_nodes) 174 | self._assert_step_0_parameters_are_correct(index_pairs, index_pairs_expected, messages_step_0_part_1, 175 | messages_step_0_part_2) 176 | 177 | # -> Step 1 178 | # Messages 179 | messages_step_0 = to.relu(to.add(messages_step_0_part_1, messages_step_0_part_2)) 180 | base_messages = self.graph_encoder.w_graph_node_features.matmul(node_features[0]) 181 | messages_from_model_step_0, _ = rnn_cpp.compose_messages(1, 182 | self.graph_encoder.number_of_nodes, 183 | self.graph_encoder.number_of_node_features, 184 | self.graph_encoder.w_graph_node_features, 185 | self.graph_encoder.w_graph_neighbor_messages, 186 | base_messages, 187 | all_neighbors_input[0], 188 | messages_init) 189 | self.assertTrue(messages_step_0.size() == messages_init.size()) 190 | self.assertTrue(np.allclose(messages_step_0.numpy(), messages_from_model_step_0.numpy(), atol=1e-02)) 191 | print("Passed first step assertions!") 192 | 193 | # -> Step 1 194 | # Initialization 195 | messages_step_1_part_1 = to.zeros((self.number_of_nodes, self.number_of_nodes, self.number_of_node_features), 196 | device=self.device) 197 | messages_step_1_part_2 = to.zeros((self.number_of_nodes, self.number_of_nodes, self.number_of_node_features), 198 | device=self.device) 199 | 200 | # -> Step 1 201 | # Calculations 202 | index_pairs_expected = [[0, 2], [0, 1], [1, 2], [1, 0], [2, 1], [2, 3], [2, 0], [2, 3], [2, 0], [2, 1]] 203 | index_pairs, messages_step_1_part_1, messages_step_1_part_2 = self._message_calculations(all_neighbors, batch, 204 | messages_step_0, 205 | messages_step_1_part_1, 206 | messages_step_1_part_2, 207 | neighbors_slice, 208 | node_features, 209 | number_of_node_features, 210 | number_of_nodes) 211 | self._assert_step_1_parameters_are_correct(index_pairs, 
index_pairs_expected, messages_step_1_part_1, 212 | messages_step_1_part_2, number_of_nodes) 213 | 214 | # -> Step 2 215 | # Messages 216 | messages_step_1 = to.relu(to.add(messages_step_1_part_1, messages_step_1_part_2)) 217 | base_messages = self.graph_encoder.w_graph_node_features.matmul(node_features[0]) 218 | messages_from_model_step_1, a = rnn_cpp.compose_messages(2, 219 | self.graph_encoder.number_of_nodes, 220 | self.graph_encoder.number_of_node_features, 221 | self.graph_encoder.w_graph_node_features, 222 | self.graph_encoder.w_graph_neighbor_messages, 223 | base_messages, 224 | all_neighbors_input[0], 225 | messages_init) 226 | self.assertTrue(messages_step_1.size() == messages_init.size()) 227 | self.assertTrue(np.allclose(messages_step_1.numpy(), messages_from_model_step_1.numpy(), atol=1e-02)) 228 | print("Passed second step assertions!") 229 | 230 | # -> Sum messages 231 | index_pairs, index_pairs_expected, messages_summed = self._sum_messages(all_neighbors, batch, 232 | messages_step_1, 233 | number_of_node_features, 234 | number_of_nodes) 235 | self.assertTrue(messages_summed.size() == to.empty(number_of_nodes, number_of_node_features).size()) 236 | self.assertTrue(np.array_equal(index_pairs_expected, np.array(index_pairs))) 237 | print("Passed sum messages assertions!") 238 | 239 | # -> Get encodings 240 | encodings = to.relu(to.add(self.graph_encoder.u_graph_node_features.matmul(node_features[batch]), 241 | messages_summed)) 242 | self.assertTrue(encodings.size() == to.empty(number_of_nodes, number_of_node_features).size()) 243 | print("Passed encodings assertions!") 244 | 245 | # -> Pass through fully connected layer 246 | weight = MULTIPLICATION_FACTOR * to.ones(self.fully_connected_layer_output_size, 247 | self.fully_connected_layer_input_size) 248 | bias = MULTIPLICATION_FACTOR * to.tensor([i for i in range(self.fully_connected_layer_output_size)]) 249 | outputs_expected = to.sigmoid(to.add(weight.matmul(encodings.view(batch_size, -1, 
1)).squeeze(), bias)) 250 | self.assertTrue(len(outputs_expected) == self.fully_connected_layer_output_size) 251 | print("Passed outputs assertions!") 252 | 253 | # When 254 | outputs = self.graph_encoder.forward(node_features, all_neighbors_input, batch_size) 255 | 256 | # Then 257 | self.assertTrue(np.allclose(outputs_expected.numpy(), outputs.numpy(), atol=1e-02)) 258 | 259 | def _sum_messages(self, all_neighbors, batch, messages_step_1, 260 | number_of_node_features, number_of_nodes): 261 | messages_summed = to.zeros(number_of_nodes, number_of_node_features, device=self.device) 262 | index_pairs_expected = np.array([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1], [2, 3], [3, 2]]) 263 | index_pairs = [] 264 | for node_id in range(number_of_nodes): 265 | if batch + node_id <= len(all_neighbors): 266 | messages_per_node = to.zeros(number_of_nodes, number_of_node_features, device=self.device) 267 | for index in range(len(all_neighbors[batch + node_id])): 268 | end_node_id = all_neighbors[batch + node_id][index].item() 269 | index_pairs.append([node_id, end_node_id]) 270 | messages_per_node[end_node_id] = self.graph_encoder.u_graph_neighbor_messages.matmul( 271 | messages_step_1[end_node_id, node_id]) 272 | messages_summed[node_id] = to.sum(messages_per_node, dim=0) 273 | return index_pairs, index_pairs_expected, messages_summed 274 | 275 | def _assert_step_1_parameters_are_correct(self, index_pairs, index_pairs_expected, messages_step_1_part_1, 276 | messages_step_1_part_2, number_of_nodes): 277 | self.assertTrue(to.allclose(messages_step_1_part_1[0, 1], to.tensor([0.45, 0.40]))) 278 | self.assertTrue(to.allclose(messages_step_1_part_1[1, 2], to.tensor([0.45, 0.40]))) 279 | self.assertTrue(to.allclose(messages_step_1_part_1[2, 3], to.tensor([0.45, 0.40]))) 280 | self.assertTrue(to.allclose(messages_step_1_part_1[3, 2], to.tensor([0.45, 0.40]))) 281 | for node_id in range(number_of_nodes): 282 | for end_node_id in range(number_of_nodes): 283 | if [node_id, 
end_node_id] in index_pairs_expected: 284 | self.assertTrue(messages_step_1_part_2[node_id, end_node_id][0].item() > 0.0) 285 | self.assertTrue(messages_step_1_part_2[node_id, end_node_id][1].item() > 0.0) 286 | else: 287 | self.assertTrue(to.allclose(messages_step_1_part_2[node_id, end_node_id], to.tensor([0.0, 0.0]))) 288 | self.assertTrue(np.array_equal(index_pairs_expected, index_pairs)) 289 | 290 | def _assert_step_0_parameters_are_correct(self, index_pairs, index_pairs_expected, messages_step_0_part_1, 291 | messages_step_0_part_2): 292 | self.assertTrue(to.allclose(messages_step_0_part_1[0, 1], to.tensor([0.45, 0.40]))) 293 | self.assertTrue(to.allclose(messages_step_0_part_1[1, 2], to.tensor([0.45, 0.40]))) 294 | self.assertTrue(to.allclose(messages_step_0_part_1[2, 3], to.tensor([0.45, 0.40]))) 295 | self.assertTrue(to.allclose(messages_step_0_part_1[3, 2], to.tensor([0.45, 0.40]))) 296 | self.assertTrue(np.array_equal(index_pairs_expected, np.array(index_pairs))) 297 | self.assertTrue(to.allclose(messages_step_0_part_2, to.zeros( 298 | (self.number_of_nodes, self.number_of_nodes, self.number_of_node_features)))) 299 | 300 | def _message_calculations(self, all_neighbors, batch, messages_init, messages_step_0_part_1, messages_step_0_part_2, 301 | neighbors_slice, node_features, number_of_node_features, number_of_nodes): 302 | index_pairs = [] 303 | base_messages = self.graph_encoder.w_graph_node_features.matmul(node_features) 304 | base_neighbor_messages = self.graph_encoder.w_graph_neighbor_messages.matmul(messages_init) 305 | for node_id in range(number_of_nodes): 306 | for index in range(len(all_neighbors[batch + node_id])): 307 | end_node_id = all_neighbors[batch + node_id][index].item() 308 | messages_from_neighbors_step_0 = to.zeros(number_of_node_features, device=self.device) 309 | messages_step_0_part_1[node_id][end_node_id] += base_messages[batch][node_id] 310 | if batch + node_id < len(neighbors_slice): 311 | for node_to_sum in 
neighbors_slice[batch + node_id][index]: 312 | index_pairs.append([node_id, node_to_sum]) 313 | messages_from_neighbors_step_0 += base_neighbor_messages[node_to_sum, node_id] 314 | messages_step_0_part_2[node_id, end_node_id] = messages_from_neighbors_step_0 315 | if batch + node_id < len(all_neighbors): 316 | pass 317 | return index_pairs, messages_step_0_part_1, messages_step_0_part_2 318 | -------------------------------------------------------------------------------- /tests/infrastructure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/infrastructure/__init__.py -------------------------------------------------------------------------------- /tests/infrastructure/test_file_system_repository.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from os import path 3 | from unittest import TestCase 4 | 5 | from message_passing_nn.infrastructure.file_system_repository import FileSystemRepository 6 | from tests.fixtures.matrices_and_vectors import BASE_GRAPH, BASE_GRAPH_NODE_FEATURES 7 | 8 | 9 | class TestTrainingDataRepository(TestCase): 10 | def setUp(self) -> None: 11 | self.dataset = 'repo-test-data' 12 | self.tests_data_directory = 'tests/test_data/' 13 | self.file_system_repository = FileSystemRepository(self.tests_data_directory, self.dataset) 14 | 15 | def test_save(self): 16 | # Given 17 | features = BASE_GRAPH_NODE_FEATURES 18 | adjacency_matrix = BASE_GRAPH 19 | labels = BASE_GRAPH.view(-1) 20 | 21 | filenames_to_save = ['code_features.pickle', 'code_adjacency-matrix.pickle', 'code_labels.pickle'] 22 | filenames_expected = [self.tests_data_directory + self.dataset + '/code_features.pickle', 23 | self.tests_data_directory + self.dataset + '/code_adjacency-matrix.pickle', 24 | self.tests_data_directory + self.dataset + 
'/code_labels.pickle'] 25 | 26 | # When 27 | self.file_system_repository.save(filenames_to_save[0], features) 28 | self.file_system_repository.save(filenames_to_save[1], adjacency_matrix) 29 | self.file_system_repository.save(filenames_to_save[2], labels) 30 | 31 | # Then 32 | path.exists(filenames_expected[0]) 33 | path.exists(filenames_expected[1]) 34 | path.exists(filenames_expected[2]) 35 | os.remove(filenames_expected[0]) 36 | os.remove(filenames_expected[1]) 37 | os.remove(filenames_expected[2]) 38 | -------------------------------------------------------------------------------- /tests/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/model/__init__.py -------------------------------------------------------------------------------- /tests/model/test_inferencer.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch as to 4 | from message_passing_nn.infrastructure.graph_dataset import GraphDataset 5 | 6 | from message_passing_nn.data.data_preprocessor import DataPreprocessor 7 | from message_passing_nn.model import Inferencer 8 | from message_passing_nn.utils.model_selector import ModelSelector 9 | from tests.fixtures.matrices_and_vectors import BASE_GRAPH_NODE_FEATURES, BASE_GRAPH 10 | 11 | 12 | class TestInferencer(TestCase): 13 | def test_do_inference(self): 14 | # Given 15 | data_preprocessor = DataPreprocessor() 16 | device = "cpu" 17 | inferencer = Inferencer(data_preprocessor, device) 18 | data_dimensions = (BASE_GRAPH_NODE_FEATURES.size(), BASE_GRAPH.size(), BASE_GRAPH.view(-1).size()) 19 | model = ModelSelector.load_model("RNN") 20 | model = model(time_steps=1, 21 | number_of_nodes=data_dimensions[1][0], 22 | number_of_node_features=data_dimensions[0][1], 23 | 
fully_connected_layer_input_size=data_dimensions[1][0] * data_dimensions[0][1], 24 | fully_connected_layer_output_size=data_dimensions[2][0]) 25 | all_neighbors = to.tensor([[1, 2, -1, -1], 26 | [0, 2, -1, -1], 27 | [0, 1, 3, -1], 28 | [2, -1, -1, -1]]) 29 | dataset = GraphDataset("") 30 | dataset.enable_test_mode() 31 | tag = 'tag' 32 | dataset.dataset = [(BASE_GRAPH_NODE_FEATURES, all_neighbors, BASE_GRAPH.view(-1), tag)] 33 | inference_data, _, _ = DataPreprocessor().train_validation_test_split(dataset, 1, 0.0, 0.0) 34 | output_label_pairs_expected = [BASE_GRAPH.view(-1), BASE_GRAPH.view(-1)] 35 | 36 | # When 37 | output_label_pairs = inferencer.do_inference(model, inference_data) 38 | 39 | # Then 40 | self.assertEqual(output_label_pairs[0][0].squeeze().size(), output_label_pairs_expected[0].size()) 41 | self.assertEqual(output_label_pairs[0][1].squeeze().size(), output_label_pairs_expected[1].size()) 42 | -------------------------------------------------------------------------------- /tests/model/test_loader.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from message_passing_nn.graph import RNNEncoder 4 | from message_passing_nn.model import Loader 5 | 6 | 7 | class TestLoader(TestCase): 8 | def test_load_model(self): 9 | # Given 10 | loader = Loader("RNN") 11 | data_dimensions = ([4, 2], [16]) 12 | path_to_model = "tests/test_data/model-checkpoints-test/configuration&id__model&" + \ 13 | "RNN__epochs&10__loss_function&MSE__optimizer&Adagrad__batch_size&" + \ 14 | "100__validation_split&0.2__test_split&0.1__time_steps&1__validation_period&" + \ 15 | "5/Epoch_5_model_state_dictionary.pth" 16 | 17 | # When 18 | model = loader.load_model(data_dimensions, path_to_model) 19 | 20 | # Then 21 | self.assertTrue(isinstance(model, RNNEncoder)) 22 | -------------------------------------------------------------------------------- /tests/model/test_trainer.py: 
-------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch as to 4 | from message_passing_nn.infrastructure.graph_dataset import GraphDataset 5 | 6 | from message_passing_nn.data.data_preprocessor import DataPreprocessor 7 | from message_passing_nn.model.trainer import Trainer 8 | from tests.fixtures.matrices_and_vectors import BASE_GRAPH, BASE_GRAPH_NODE_FEATURES 9 | 10 | 11 | class TestTrainer(TestCase): 12 | def setUp(self) -> None: 13 | time_steps = 1 14 | loss_function = "MSE" 15 | optimizer = "SGD" 16 | model = "RNN" 17 | device = "cpu" 18 | self.configuration_dictionary = {"model": model, 19 | "loss_function": loss_function, 20 | "optimizer": optimizer, 21 | "time_steps": time_steps} 22 | data_preprocessor = DataPreprocessor() 23 | self.model_trainer = Trainer(data_preprocessor, device, normalize=True) 24 | 25 | def test_instantiate_attributes(self): 26 | # Given 27 | number_of_nodes = BASE_GRAPH.size()[0] 28 | number_of_node_features = BASE_GRAPH_NODE_FEATURES.size()[1] 29 | data_dimensions = (BASE_GRAPH_NODE_FEATURES.size(), BASE_GRAPH.view(-1).size()) 30 | 31 | # When 32 | self.model_trainer.instantiate_attributes(data_dimensions, self.configuration_dictionary) 33 | 34 | # Then 35 | self.assertTrue(self.model_trainer.model.number_of_nodes == number_of_nodes) 36 | self.assertTrue( 37 | self.model_trainer.model.number_of_node_features == number_of_node_features) 38 | self.assertTrue(self.model_trainer.optimizer.param_groups) 39 | 40 | def test_do_train(self): 41 | # Given 42 | data_dimensions = (BASE_GRAPH_NODE_FEATURES.size(), BASE_GRAPH.view(-1).size()) 43 | self.model_trainer.instantiate_attributes(data_dimensions, 44 | self.configuration_dictionary) 45 | all_neighbors = to.tensor([[1, 2, -1, -1], 46 | [0, 2, -1, -1], 47 | [0, 1, 3, -1], 48 | [2, -1, -1, -1]]) 49 | dataset = GraphDataset("") 50 | dataset.enable_test_mode() 51 | tag = 'tag' 52 | dataset.dataset = 
[(BASE_GRAPH_NODE_FEATURES, all_neighbors, BASE_GRAPH.view(-1), tag)] 53 | training_data, _, _ = DataPreprocessor().train_validation_test_split(dataset, 1, 0.0, 0.0) 54 | 55 | # When 56 | training_loss = self.model_trainer.do_train(training_data=training_data, epoch=1) 57 | 58 | # Then 59 | self.assertTrue(training_loss > 0.0) 60 | 61 | def test_do_evaluate(self): 62 | # Given 63 | data_dimensions = (BASE_GRAPH_NODE_FEATURES.size(), BASE_GRAPH.view(-1).size()) 64 | self.model_trainer.instantiate_attributes(data_dimensions, 65 | self.configuration_dictionary) 66 | all_neighbors = to.tensor([[1, 2, -1, -1], 67 | [0, 2, -1, -1], 68 | [0, 1, 3, -1], 69 | [2, -1, -1, -1]]) 70 | dataset = GraphDataset("") 71 | dataset.enable_test_mode() 72 | tag = 'tag' 73 | dataset.dataset = [(BASE_GRAPH_NODE_FEATURES, all_neighbors, BASE_GRAPH.view(-1), tag)] 74 | training_data, _, _ = DataPreprocessor().train_validation_test_split(dataset, 1, 0.0, 0.0) 75 | 76 | # When 77 | validation_loss = self.model_trainer.do_evaluate(evaluation_data=training_data, epoch=1) 78 | 79 | # Then 80 | self.assertTrue(validation_loss > 0.0) 81 | -------------------------------------------------------------------------------- /tests/test_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/test_data/__init__.py -------------------------------------------------------------------------------- /tests/test_data/model-checkpoints-test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/test_data/model-checkpoints-test/__init__.py -------------------------------------------------------------------------------- 
/tests/test_data/model-checkpoints-test/configuration&id__model&RNN__epochs&10__loss_function&MSE__optimizer&Adagrad__batch_size&100__validation_split&0.2__test_split&0.1__time_steps&1__validation_period&5/Epoch_5_model_state_dictionary.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/test_data/model-checkpoints-test/configuration&id__model&RNN__epochs&10__loss_function&MSE__optimizer&Adagrad__batch_size&100__validation_split&0.2__test_split&0.1__time_steps&1__validation_period&5/Epoch_5_model_state_dictionary.pth -------------------------------------------------------------------------------- /tests/test_data/repo-test-data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/test_data/repo-test-data/__init__.py -------------------------------------------------------------------------------- /tests/test_data/training-test-data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/test_data/training-test-data/__init__.py -------------------------------------------------------------------------------- /tests/usecase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/usecase/__init__.py -------------------------------------------------------------------------------- /tests/usecase/test_grid_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 
from unittest import TestCase 4 | 5 | from message_passing_nn.data.data_preprocessor import DataPreprocessor 6 | from message_passing_nn.model.trainer import Trainer 7 | from message_passing_nn.infrastructure.file_system_repository import FileSystemRepository 8 | from message_passing_nn.usecase.grid_search import GridSearch 9 | from message_passing_nn.utils.saver import Saver 10 | from tests.fixtures.matrices_and_vectors import BASE_GRAPH, BASE_GRAPH_NODE_FEATURES 11 | 12 | 13 | class TestTraining(TestCase): 14 | def setUp(self) -> None: 15 | self.features = BASE_GRAPH_NODE_FEATURES 16 | self.adjacency_matrix = BASE_GRAPH 17 | self.labels = BASE_GRAPH.view(-1) 18 | self.dataset = 'training-test-data' 19 | self.tests_data_directory = 'tests/test_data/' 20 | tests_model_directory = 'tests/model_checkpoints' 21 | tests_results_directory = 'tests/grid_search_results' 22 | device = "cpu" 23 | self.data_path = self.tests_data_directory + self.dataset + "/" 24 | self.repository = FileSystemRepository(self.tests_data_directory, self.dataset) 25 | self.data_preprocessor = DataPreprocessor() 26 | self.data_preprocessor.enable_test_mode() 27 | self.model_trainer = Trainer(self.data_preprocessor, device) 28 | self.saver = Saver(tests_model_directory, tests_results_directory) 29 | 30 | def test_start_for_multiple_batches_of_the_same_size(self): 31 | # Given 32 | dataset_size = 6 33 | grid_search_dictionary = { 34 | "model": ["RNN"], 35 | "epochs": [10], 36 | "batch_size": [3], 37 | "validation_split": [0.2], 38 | "test_split": [0.1], 39 | "loss_function": ["MSE"], 40 | "optimizer": ["SGD"], 41 | "time_steps": [1], 42 | "validation_period": [5] 43 | } 44 | grid_search = GridSearch(self.data_path, 45 | self.data_preprocessor, 46 | self.model_trainer, 47 | grid_search_dictionary, 48 | self.saver, 49 | test_mode=True) 50 | 51 | adjacency_matrix_filenames, features_filenames, labels_filenames = self._save_test_data(dataset_size) 52 | 53 | # When 54 | losses = grid_search.start() 55 
| configuration_id = list(losses["training_loss"].keys())[0] 56 | 57 | # Then 58 | self.assertTrue(losses["training_loss"][configuration_id][grid_search_dictionary["epochs"][0]] > 0.0) 59 | self.assertTrue( 60 | losses["validation_loss"][configuration_id][grid_search_dictionary["validation_period"][0]] > 0.0) 61 | self.assertTrue(losses["test_loss"][configuration_id]["final_epoch"] > 0.0) 62 | 63 | # Tear down 64 | self._remove_files(dataset_size, features_filenames, adjacency_matrix_filenames, labels_filenames) 65 | 66 | def test_start_for_multiple_batches_of_differing_size(self): 67 | # Given 68 | dataset_size = 5 69 | grid_search_dictionary = { 70 | "model": ["RNN"], 71 | "epochs": [10], 72 | "batch_size": [3], 73 | "validation_split": [0.2], 74 | "test_split": [0.1], 75 | "loss_function": ["MSE"], 76 | "optimizer": ["SGD"], 77 | "time_steps": [1], 78 | "validation_period": [5] 79 | } 80 | grid_search = GridSearch(self.data_path, 81 | self.data_preprocessor, 82 | self.model_trainer, 83 | grid_search_dictionary, 84 | self.saver, 85 | test_mode=True) 86 | 87 | adjacency_matrix_filenames, features_filenames, labels_filenames = self._save_test_data(dataset_size) 88 | 89 | # When 90 | losses = grid_search.start() 91 | configuration_id = list(losses["training_loss"].keys())[0] 92 | 93 | # Then 94 | self.assertTrue(losses["training_loss"][configuration_id][grid_search_dictionary["epochs"][0]] > 0.0) 95 | self.assertTrue( 96 | losses["validation_loss"][configuration_id][grid_search_dictionary["validation_period"][0]] > 0.0) 97 | self.assertTrue(losses["test_loss"][configuration_id]["final_epoch"] > 0.0) 98 | 99 | # Tear down 100 | self._remove_files(dataset_size, features_filenames, adjacency_matrix_filenames, labels_filenames) 101 | 102 | def test_start_a_grid_search(self): 103 | # Given 104 | dataset_size = 6 105 | grid_search_dictionary = { 106 | "model": ["RNN"], 107 | "epochs": [10, 15], 108 | "batch_size": [3, 4], 109 | "validation_split": [0.2], 110 | 
"test_split": [0.1], 111 | "loss_function": ["MSE"], 112 | "optimizer": ["SGD"], 113 | "time_steps": [1], 114 | "validation_period": [5] 115 | } 116 | grid_search = GridSearch(self.data_path, 117 | self.data_preprocessor, 118 | self.model_trainer, 119 | grid_search_dictionary, 120 | self.saver, 121 | test_mode=True) 122 | 123 | adjacency_matrix_filenames, features_filenames, labels_filenames = self._save_test_data(dataset_size) 124 | 125 | # When 126 | losses = grid_search.start() 127 | configuration_id = list(losses["training_loss"].keys())[0] 128 | 129 | # Then 130 | self.assertTrue(losses["training_loss"][configuration_id][grid_search_dictionary["epochs"][0]] > 0.0) 131 | self.assertTrue( 132 | losses["validation_loss"][configuration_id][grid_search_dictionary["validation_period"][0]] > 0.0) 133 | self.assertTrue(losses["test_loss"][configuration_id]["final_epoch"] > 0.0) 134 | 135 | # Tear down 136 | self._remove_files(dataset_size, features_filenames, adjacency_matrix_filenames, labels_filenames) 137 | 138 | def _save_test_data(self, dataset_size): 139 | features_filenames = [str(i) + '_training_features' + '.pickle' for i in range(dataset_size)] 140 | adjacency_matrix_filenames = [str(i) + '_training_adjacency-matrix' '.pickle' for i in range(dataset_size)] 141 | labels_filenames = [str(i) + '_training_labels' '.pickle' for i in range(dataset_size)] 142 | for i in range(dataset_size): 143 | self.repository.save(features_filenames[i], self.features) 144 | self.repository.save(adjacency_matrix_filenames[i], self.adjacency_matrix) 145 | self.repository.save(labels_filenames[i], self.labels) 146 | return adjacency_matrix_filenames, features_filenames, labels_filenames 147 | 148 | def _remove_files(self, 149 | dataset_size: int, 150 | features_filenames: List[str], 151 | adjacency_matrix_filenames: List[str], 152 | labels_filenames: List[str]) -> None: 153 | for i in range(dataset_size): 154 | os.remove(self.tests_data_directory + self.dataset + "/" + 
features_filenames[i]) 155 | os.remove(self.tests_data_directory + self.dataset + "/" + adjacency_matrix_filenames[i]) 156 | os.remove(self.tests_data_directory + self.dataset + "/" + labels_filenames[i]) 157 | -------------------------------------------------------------------------------- /tests/usecase/test_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from unittest import TestCase 4 | import shutil 5 | from datetime import datetime 6 | import torch as to 7 | 8 | from message_passing_nn.data.data_preprocessor import DataPreprocessor 9 | from message_passing_nn.model import Loader, Inferencer 10 | from message_passing_nn.infrastructure.file_system_repository import FileSystemRepository 11 | from message_passing_nn.usecase import Inference 12 | from message_passing_nn.utils.saver import Saver 13 | 14 | 15 | class TestInference(TestCase): 16 | def test_start(self): 17 | # Given 18 | dataset_size = 1 19 | features = to.ones(4, 2) 20 | adjacency_matrix = to.ones(4, 4) 21 | labels = to.ones(16) 22 | dataset = 'inference-test-data' 23 | tests_data_directory = 'tests/test_data/' 24 | tests_model_directory = "tests/test_data/model-checkpoints-test/configuration&id__model&" + \ 25 | "RNN__epochs&10__loss_function&MSE__optimizer&Adagrad__batch_size&" + \ 26 | "100__validation_split&0.2__test_split&0.1__time_steps&1__validation_period&" + \ 27 | "5/Epoch_5_model_state_dictionary.pth" 28 | tests_results_directory = 'tests/results_inference' 29 | device = "cpu" 30 | repository = FileSystemRepository(tests_data_directory, dataset) 31 | data_path = tests_data_directory + dataset + "/" 32 | data_preprocessor = DataPreprocessor() 33 | data_preprocessor.enable_test_mode() 34 | loader = Loader("RNN") 35 | inferencer = Inferencer(data_preprocessor, device) 36 | saver = Saver(tests_model_directory, tests_results_directory) 37 | inference = Inference(data_path, 38 | data_preprocessor, 39 | 
loader, 40 | inferencer, 41 | saver, 42 | test_mode=True) 43 | 44 | adjacency_matrix_filenames, features_filenames, labels_filenames = self._save_test_data(adjacency_matrix, 45 | dataset_size, 46 | features, 47 | labels, 48 | repository) 49 | 50 | # When 51 | inference.start() 52 | 53 | # Then 54 | filename_expected = datetime.now().strftime("%d-%b-%YT%H_%M") + "_distance_maps.pickle" 55 | self.assertTrue(os.path.isfile(tests_results_directory + "/" + filename_expected)) 56 | 57 | # Tear down 58 | self._remove_files(dataset_size, 59 | features_filenames, 60 | adjacency_matrix_filenames, 61 | labels_filenames, 62 | tests_data_directory, 63 | dataset, 64 | tests_results_directory) 65 | 66 | @staticmethod 67 | def _save_test_data(adjacency_matrix, dataset_size, features, labels, repository): 68 | features_filenames = [str(i) + '_training_features' + '.pickle' for i in range(dataset_size)] 69 | adjacency_matrix_filenames = [str(i) + '_training_adjacency-matrix' '.pickle' for i in range(dataset_size)] 70 | labels_filenames = [str(i) + '_training_labels' '.pickle' for i in range(dataset_size)] 71 | for i in range(dataset_size): 72 | repository.save(features_filenames[i], features) 73 | repository.save(adjacency_matrix_filenames[i], adjacency_matrix) 74 | repository.save(labels_filenames[i], labels) 75 | return adjacency_matrix_filenames, features_filenames, labels_filenames 76 | 77 | @staticmethod 78 | def _remove_files(dataset_size: int, 79 | features_filenames: List[str], 80 | adjacency_matrix_filenames: List[str], 81 | labels_filenames: List[str], 82 | tests_data_directory: str, 83 | dataset: str, 84 | tests_results_directory: str) -> None: 85 | for i in range(dataset_size): 86 | os.remove(tests_data_directory + dataset + "/" + features_filenames[i]) 87 | os.remove(tests_data_directory + dataset + "/" + adjacency_matrix_filenames[i]) 88 | os.remove(tests_data_directory + dataset + "/" + labels_filenames[i]) 89 | shutil.rmtree(tests_results_directory) 90 | 
-------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kovanostra/message-passing-neural-network/5aec08bda54158d62fae220482e3eb7933436057/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/test_grid_search_parameters_parser.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from message_passing_nn.utils.grid_search_parameters_parser import GridSearchParametersParser 4 | 5 | 6 | class TestGridSearchParametersParser(TestCase): 7 | def test_get_grid_search_dictionary(self): 8 | # Given 9 | model = "RNN&GRU" 10 | epochs = "10&15&5" 11 | loss_function_selection = "MSE&CrossEntropy" 12 | optimizer_selection = "SGD&Adam" 13 | batch_size = "5" 14 | validation_split = "0.2" 15 | test_split = "0.1&0.2&2" 16 | time_steps = "10" 17 | validation_period = "5" 18 | grid_search_dictionary_expected = { 19 | "model": ["RNN", "GRU"], 20 | "epochs": [10, 11, 12, 13, 15], 21 | "loss_function": ["MSE", "CrossEntropy"], 22 | "optimizer": ["SGD", "Adam"], 23 | "batch_size": [5], 24 | "validation_split": [0.2], 25 | "test_split": [0.1, 0.2], 26 | "time_steps": [10], 27 | "validation_period": [5], 28 | } 29 | 30 | # When 31 | grid_search_dictionary = GridSearchParametersParser().get_grid_search_dictionary( 32 | model, 33 | epochs, 34 | loss_function_selection, 35 | optimizer_selection, 36 | batch_size, 37 | validation_split, 38 | test_split, 39 | time_steps, 40 | validation_period) 41 | 42 | # Then 43 | self.assertEqual(grid_search_dictionary_expected, grid_search_dictionary) 44 | --------------------------------------------------------------------------------