├── run.sh ├── resources ├── saved │ ├── karate_club.gif │ ├── train_karate_animation.mp4 │ ├── cpp_predicted.txt │ ├── py_predicted.txt │ └── saved_weights.txt ├── meta.ucidata-zachary ├── out.ucidata-zachary └── README.ucidata-zachary ├── Requirements.txt ├── Cycles_Graph_Level_Classification ├── resources │ ├── CYCLE │ │ ├── CYCLE_graph_labels.txt │ │ ├── CYCLE_node_labels.txt │ │ ├── CYCLE_graph_indicator.txt │ │ └── CYCLE_A.txt │ └── saved │ │ ├── train_CYCLE_animation.gif │ │ ├── train_CYCLE_animation.mp4 │ │ ├── cpp_predicted_cycle.txt │ │ ├── py_predicted_cycle.txt │ │ └── saved_weights_cycle.txt ├── GCN_Graph_Level_Classification_Architecture.png ├── README.md ├── arguments.py ├── main.cpp ├── GCN_Model.h ├── utils.py ├── layers.h ├── model.py ├── data.h ├── Train_Cycles.py ├── Cycle_Compare_Predictions.ipynb └── .ipynb_checkpoints │ └── Cycle_Compare_Predictions-checkpoint.ipynb ├── get_karate_club.sh ├── CMakeLists.txt ├── arguments.py ├── utils.py ├── GCN_Model.h ├── main.cpp ├── README.md ├── layers.h ├── model.py ├── data.h ├── train.py └── Compare_Predictions.ipynb /run.sh: -------------------------------------------------------------------------------- 1 | python train.py 2 | g++ main.cpp -I eigen -std=c++17 3 | ./a.out 4 | 5 | -------------------------------------------------------------------------------- /resources/saved/karate_club.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnirudhDagar/MessagePassing_for_GNNs/HEAD/resources/saved/karate_club.gif -------------------------------------------------------------------------------- /Requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | torch==1.0.0 3 | ConfigArgParse==0.13.0 4 | matplotlib==3.0.1 5 | imageio==2.4.1 6 | celluloid==0.2.0 7 | -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/CYCLE/CYCLE_graph_labels.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 0 4 | 1 5 | 1 6 | 0 7 | 1 8 | 0 9 | 1 10 | 0 11 | 1 12 | 0 13 | 0 -------------------------------------------------------------------------------- /resources/saved/train_karate_animation.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnirudhDagar/MessagePassing_for_GNNs/HEAD/resources/saved/train_karate_animation.mp4 -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/saved/train_CYCLE_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnirudhDagar/MessagePassing_for_GNNs/HEAD/Cycles_Graph_Level_Classification/resources/saved/train_CYCLE_animation.gif -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/saved/train_CYCLE_animation.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnirudhDagar/MessagePassing_for_GNNs/HEAD/Cycles_Graph_Level_Classification/resources/saved/train_CYCLE_animation.mp4 -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/GCN_Graph_Level_Classification_Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnirudhDagar/MessagePassing_for_GNNs/HEAD/Cycles_Graph_Level_Classification/GCN_Graph_Level_Classification_Architecture.png -------------------------------------------------------------------------------- /get_karate_club.sh: -------------------------------------------------------------------------------- 1 | curl -OL http://konect.cc/files/download.tsv.ucidata-zachary.tar.bz2 2 | tar xvzf download.tsv.ucidata-zachary.tar.bz2 3 | rm download.tsv.ucidata-zachary.tar.bz2 4 | mv ucidata-zachary resources 5 | 6 | echo "Successfully Downloaded the Karate Club Dataset" 7 | -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/saved/cpp_predicted_cycle.txt: -------------------------------------------------------------------------------- 1 | 0.216237 0.783763 2 | 0.210187 0.789813 3 | 0.683550 0.316450 4 | 0.144340 0.855660 5 | 0.191623 0.808377 6 | 0.655664 0.344336 7 | 0.064254 0.935746 8 | 0.636538 0.363462 9 | 0.074136 0.925864 10 | 0.768969 0.231031 11 | 0.288318 0.711682 12 | 0.882303 0.117697 13 | 0.520833 0.479167 14 | -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/CYCLE/CYCLE_node_labels.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 2 4 | 1 5 | 2 6 | 2 7 | 1 8 | 1 9 | 1 10 | 1 11 | 2 12 | 1 13 | 1 14 | 1 15 | 2 16 | 1 17 | 2 18 | 1 19 | 2 20 | 2 21 | 2 22 | 1 23 | 1 24 | 2 25 | 2 26 | 1 27 | 2 28 | 1 29 | 1 30 | 1 31 | 2 32 | 1 33 | 2 34 | 2 35 | 2 36 | 2 37 | 1 38 | 1 39 | 2 40 | 1 41 | 1 42 | 2 43 | 1 44 | 2 45 | 2 46 | 1 47 | 1 48 | 2 49 | 1 50 | 2 51 | 2 52 | 1 53 | 1 54 | 2 55 | 2 56 | 1 57 | 2 58 | 1 59 | 1 60 | 1 -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/CYCLE/CYCLE_graph_indicator.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 1 4 | 1 5 | 2 6 | 2 7 | 2 8 | 2 9 | 3 10 | 3 11 | 3 12 | 3 13 | 3 14 | 4 15 | 4 16 | 4 17 | 4 18 | 4 19 | 4 20 | 4 21 | 4 22 | 5 23 | 5 24 | 5 25 | 5 26 | 6 27 | 6 28 | 6 29 | 7 30 | 7 31 | 7 32 | 7 33 | 7 34 | 7 35 | 7 36 | 7 37 | 7 38 | 7 39 | 8 40 | 8 41 | 8 42 | 8 43 | 9 44 | 9 45 | 9 46 | 9 47 | 9 48 | 10 49 | 10 50 | 10 51 | 11 52 | 11 53 | 11 54 | 11 55 | 12 56 | 12 57 | 12 58 | 12 59 | 13 60 | 13 -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.2) 2 | 3 | set(CMAKE_CXX_STANDARD 17) 4 | 5 | project(Message_Passing_NN VERSION 1.0) 6 | 7 | 8 | INCLUDE_DIRECTORIES(eigen/) 9 | 10 | add_custom_target(copy-runtime-files ALL 11 | COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/resources ${CMAKE_BINARY_DIR}/resources 12 | DEPENDS ${MY_TARGET}) 13 | 14 | set(SOURCE_FILES main.cpp data.h layers.h GCN_Model.h) 15 | add_executable(out ${SOURCE_FILES}) 16 | 17 | install(TARGETS out DESTINATION bin) 18 | -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/saved/py_predicted_cycle.txt: -------------------------------------------------------------------------------- 1 | 0.21623682975769043 0.7837631702423096 2 | 0.2101866900920868 0.7898133397102356 3 | 0.6835503578186035 0.31644967198371887 4 | 0.14434026181697845 0.8556597232818604 5 | 0.19162258505821228 0.8083773851394653 6 | 0.6556643843650818 0.3443356454372406 7 | 0.06425371021032333 0.9357463121414185 8 | 0.636538028717041 0.36346200108528137 9 | 0.07413610070943832 0.9258638620376587 10 | 0.7689694166183472 0.23103059828281403 11 | 0.2883181571960449 0.7116818428039551 12 | 0.882302463054657 0.11769749969244003 13 | 0.5208332538604736 0.47916674613952637 14 | 15 | 16 | -------------------------------------------------------------------------------- /resources/saved/cpp_predicted.txt: -------------------------------------------------------------------------------- 1 | 5.790986 1.682824 2 | 3.795149 1.501273 3 | 3.055941 2.654548 4 | 3.595539 1.022943 5 | 3.677935 0.660215 6 | 4.157325 0.904744 7 | 4.097245 0.990750 8 | 2.937612 1.044070 9 | 1.837216 2.591389 10 | 1.233738 2.049130 11 | 3.487444 0.730906 12 | 2.713360 0.225547 13 | 2.820468 0.403513 14 | 2.870050 1.638387 15 | 0.700655 2.673667 16 | 0.886203 2.663809 17 | 3.327699 0.818192 18 | 2.493465 0.605860 19 | 0.748322 2.847911 20 | 2.208598 1.519934 21 | 0.766496 2.694864 22 | 2.676193 0.703044 23 | 0.801118 2.665944 24 | 1.172814 3.570558 25 | 1.211359 2.251581 26 | 1.030864 2.400856 27 | 0.376341 3.371932 28 | 1.399620 2.712790 29 | 1.218853 2.286704 30 | 0.655453 3.944298 31 | 1.436239 2.537111 32 | 1.889816 3.041958 33 | 1.257894 5.096536 34 | 1.657045 5.762127 35 | -------------------------------------------------------------------------------- /resources/meta.ucidata-zachary: -------------------------------------------------------------------------------- 1 | name: Zachary karate club 2 | code: ZA 3 | url: http://vlado.fmf.uni-lj.si/pub/networks/data/ucinet/ucidata.htm#zachary 4 | category: HumanSocial 5 | description: Member–member ties 6 | long-description: This is the well-known and much-used Zachary karate club network. The data was collected from the members of a university karate club by Wayne Zachary in 1977. Each node represents a member of the club, and each edge represents a tie between two members of the club. The network is undirected. An often discussed problem using this dataset is to find the two groups of people into which the karate club split after an argument between two teachers. 7 | entity-names: member 8 | relationship-names: tie 9 | extr: ucidata 10 | cite: konect:ucidata-zachary 11 | timeiso: 1977 12 | -------------------------------------------------------------------------------- /resources/out.ucidata-zachary: -------------------------------------------------------------------------------- 1 | % sym unweighted 2 | % 78 34 34 3 | 1 2 4 | 1 3 5 | 2 3 6 | 1 4 7 | 2 4 8 | 3 4 9 | 1 5 10 | 1 6 11 | 1 7 12 | 5 7 13 | 6 7 14 | 1 8 15 | 2 8 16 | 3 8 17 | 4 8 18 | 1 9 19 | 3 9 20 | 3 10 21 | 1 11 22 | 5 11 23 | 6 11 24 | 1 12 25 | 1 13 26 | 4 13 27 | 1 14 28 | 2 14 29 | 3 14 30 | 4 14 31 | 6 17 32 | 7 17 33 | 1 18 34 | 2 18 35 | 1 20 36 | 2 20 37 | 1 22 38 | 2 22 39 | 24 26 40 | 25 26 41 | 3 28 42 | 24 28 43 | 25 28 44 | 3 29 45 | 24 30 46 | 27 30 47 | 2 31 48 | 9 31 49 | 1 32 50 | 25 32 51 | 26 32 52 | 29 32 53 | 3 33 54 | 9 33 55 | 15 33 56 | 16 33 57 | 19 33 58 | 21 33 59 | 23 33 60 | 24 33 61 | 30 33 62 | 31 33 63 | 32 33 64 | 9 34 65 | 10 34 66 | 14 34 67 | 15 34 68 | 16 34 69 | 19 34 70 | 20 34 71 | 21 34 72 | 23 34 73 | 24 34 74 | 27 34 75 | 28 34 76 | 29 34 77 | 30 34 78 | 31 34 79 | 32 34 80 | 33 34 81 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | def buildParser(): 4 | parser = configargparse.ArgParser() 5 | 6 | # Data setup 7 | parser.add('--datapath', default='resources/out.ucidata-zachary', help='Saved Path Destination') 8 | parser.add('--savepath', default='resources/saved/', help='Saved Path Destination') 9 | parser.add('-w', '--weight_filename', default='saved_weights', help='Saved weights Destination') 10 | parser.add('-p', '--predictions_filename', default='py_predicted', help='Output predictions for comparision') 11 | 12 | # Training setup 13 | parser.add('-s', '--seed', help='Seed for random number generation', type=int, default=2020) 14 | parser.add('-e', '--epochs', help='Number of epochs', type=int, default=400) 15 | parser.add('--no_vis', dest='no_vis', action='store_true', help='Do not create Visualization') 16 | parser.add('--print_freq', help='Frequency of printing updates between epochs', type=int, default=20) 17 | 18 | # Optimizer setup 19 | parser.add('-l', '--lr', help='Learning rate', type=float, default=0.01) 20 | parser.add('-m', '--momentum', help='Momentum of optimizer', type=float, default=0.9) 21 | 22 | # Model setup 23 | parser.add('--hid', help='hidden dimension', type=int, default=10) 24 | parser.add('--out', help='out dimension', type=int, default=2) 25 | 26 | return parser 27 | -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/README.md: -------------------------------------------------------------------------------- 1 | ## GCN Graph Level Classification on CYCLE dataset. 2 | 3 | The GCN model used follows the following architecture with an added pooling layer and a FC layer at the end. 4 | 5 |

6 | 7 |

8 | 9 | ### Training Visualization 10 | 11 |

12 | 13 |

14 |

15 | Representation Space of the graph embeddings learnt over 400 epochs 16 |

17 | 18 | The predicted outputs from the Python forward pass and C++ forward pass are saved in `resources/saved/`. As expected, they are almost identical and the similarity between both can be seen quantitatively as well visually with scatter plots in `Cycle_Compare_Predictions.ipynb`. 19 | 20 | ## Usage 21 | ### Step 1 22 | 23 | ``` 24 | # Clone the repository 25 | git clone https://github.com/AnirudhDagar/MessagePassing_for_GNNs.git 26 | ``` 27 | 28 | ### Step 2 29 | ``` 30 | # Train the model and save the weights. 31 | python Train_Cycles.py 32 | 33 | ``` 34 | 35 | ### Step 3 36 | ``` 37 | # Use a compiler directly to compile the executables. 38 | g++ main.cpp -I eigen -std=c++17 39 | 40 | ``` 41 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def read_karate(datapath): 5 | with open(datapath) as file_content: 6 | file_content = [i for i in file_content.readlines()] 7 | 8 | # Remove meta info at document starting 9 | file_content = file_content[2:] 10 | 11 | edges = [] 12 | for i in file_content: 13 | edge_tuple = list(map(int, i.strip().split())) 14 | edges.append(edge_tuple) 15 | 16 | return edges 17 | 18 | def create_adjacency(edges, sparse=False): 19 | max_node=0 20 | for edge in edges: 21 | if edge[0] > max_node: 22 | max_node = edge[0] 23 | if edge[1] > max_node: 24 | max_node = edge[1] 25 | 26 | size = max_node 27 | adj = [[0 for i in range(size)] for j in range(size)] 28 | 29 | # Build bi-directional graph adj matrix 30 | for edge in edges: 31 | adj[edge[0]-1][edge[1]-1] = 1 32 | adj[edge[1]-1][edge[0]-1] = 1 33 | 34 | # Convert numpy array to torch tensor 35 | adj_tensor = torch.tensor(adj, dtype=torch.float32) 36 | 37 | ## Use Sparse tensors for larger adjacency matrix 38 | if sparse: 39 | adj_tensor = adj_tensor.to_sparse() 40 | 41 | return adj_tensor 42 | 43 | def zeros(tensor): 44 | if tensor is not None: 45 | tensor.data.fill_(0, dtype=torch.float32) -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/resources/CYCLE/CYCLE_A.txt: -------------------------------------------------------------------------------- 1 | 1, 2 2 | 2, 3 3 | 2, 4 4 | 1, 3 5 | 2, 1 6 | 3, 2 7 | 4, 2 8 | 3, 1 9 | 5, 6 10 | 6, 7 11 | 7, 8 12 | 8, 5 13 | 6, 5 14 | 7, 6 15 | 8, 7 16 | 5, 8 17 | 9, 10 18 | 9, 11 19 | 9, 12 20 | 9, 13 21 | 10, 9 22 | 11, 9 23 | 12, 9 24 | 13, 9 25 | 14, 15 26 | 15, 16 27 | 17, 18 28 | 18, 19 29 | 19, 20 30 | 15, 21 31 | 15, 14 32 | 16, 15 33 | 18, 17 34 | 19, 18 35 | 20, 19 36 | 21, 15 37 | 22, 23 38 | 22, 24 39 | 23, 24 40 | 24, 25 41 | 25, 23 42 | 23, 22 43 | 24, 22 44 | 24, 23 45 | 25, 24 46 | 23, 25 47 | 26, 27 48 | 27, 28 49 | 27, 26 50 | 28, 27 51 | 29, 30 52 | 29, 36 53 | 30, 31 54 | 30, 36 55 | 31, 32 56 | 32, 33 57 | 33, 34 58 | 34, 35 59 | 34, 37 60 | 34, 38 61 | 38, 37 62 | 30, 35 63 | 30, 29 64 | 36, 29 65 | 31, 30 66 | 36, 30 67 | 32, 31 68 | 33, 32 69 | 34, 33 70 | 35, 34 71 | 37, 34 72 | 38, 34 73 | 37, 38 74 | 35, 30 75 | 39, 40 76 | 40, 41 77 | 41, 42 78 | 40, 39 79 | 41, 40 80 | 42, 41 81 | 43, 44 82 | 44, 46 83 | 44, 45 84 | 46, 47 85 | 44, 43 86 | 46, 44 87 | 45, 44 88 | 47, 46 89 | 48, 49 90 | 49, 50 91 | 49, 48 92 | 50, 49 93 | 51, 52 94 | 52, 53 95 | 53, 54 96 | 52, 54 97 | 52, 51 98 | 53, 52 99 | 54, 53 100 | 54, 52 101 | 56, 55 102 | 57, 55 103 | 57, 58 104 | 55, 56 105 | 55, 57 106 | 58, 57 107 | 59, 60 108 | 60, 59 -------------------------------------------------------------------------------- /GCN_Model.h: -------------------------------------------------------------------------------- 1 | // Author: Anirudh Dagar 10/03/20 2 | 3 | /*////////////////////////////////////////////////////////////////////////// 4 | // Definition of the GCN class, which uses the graph Convolutional // 5 | // operator GCNConv from layers.h to implement // 6 | // Semi-supervised Classification with Graph Convolutional Networks" // 7 | // paper used for Message Passing // 8 | // Neural Networks. // 9 | //////////////////////////////////////////////////////////////////////////*/ 10 | 11 | #pragma once 12 | 13 | #ifndef GCN_MODEL 14 | #define GCN_MODEL 15 | 16 | #include 17 | #include 18 | #include "layers.h" 19 | 20 | 21 | class GCN 22 | { 23 | public: 24 | // Default Constructor 25 | GCNConv conv1; 26 | GCNConv conv2; 27 | GCN (Eigen::MatrixXd A, int nfeat, int nhid, int nout, std::vector weights): 28 | conv1(nfeat, nhid, A, weights[0], true), conv2(nhid, nout, A, weights[1], true) 29 | {} 30 | 31 | Eigen::MatrixXd forward(Eigen::MatrixXd x) 32 | { 33 | Eigen::MatrixXd h1 = conv1.forward(x); 34 | h1 = relu(h1); 35 | Eigen::MatrixXd h2 = conv2.forward(h1); 36 | h2 = relu(h2); 37 | 38 | return h2; 39 | } 40 | 41 | }; 42 | 43 | #endif //GCN_MODEL -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/arguments.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | def buildParser(): 4 | parser = configargparse.ArgParser() 5 | 6 | # Data setup 7 | parser.add('--datapath', default='resources/CYCLE/', help='Data Path Destination') 8 | parser.add('--savepath', default='resources/saved/', help='Saved Path Destination') 9 | parser.add('-w', '--weight_filename', default='saved_weights_cycle', help='Saved weights Destination') 10 | parser.add('-p', '--predictions_filename', default='py_predicted_cycle', help='Output predictions for comparision') 11 | 12 | # Training setup 13 | parser.add('-s', '--seed', help='Seed for random number generation', type=int, default=2020) 14 | parser.add('-e', '--epochs', help='Number of epochs', type=int, default=400) 15 | parser.add('--no_vis', dest='no_vis', action='store_true', help='Do not create Visualization') 16 | parser.add('--print_freq', help='Frequency of printing updates between epochs', type=int, default=20) 17 | 18 | # Optimizer setup 19 | parser.add('-l', '--lr', help='Learning rate', type=float, default=0.01) 20 | parser.add('-m', '--momentum', help='Momentum of optimizer', type=float, default=0.9) 21 | 22 | # Model setup 23 | parser.add('--nhid1', help='hidden dimension 1', type=int, default=6) 24 | parser.add('--nhid2', help='hidden dimension 2', type=int, default=4) 25 | parser.add('--out', help='out dimension', type=int, default=2) 26 | 27 | return parser 28 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | // Author: Anirudh Dagar 10/03/20 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "GCN_Model.h" 9 | #include "data.h" 10 | 11 | int main() 12 | { 13 | // Import Karate Club Data 14 | auto adj_karate = read_karate(); 15 | 16 | // Import Weights 17 | std::string address = "resources/saved/saved_weights.txt"; 18 | std::vector weight_vec = getWeights(address); 19 | std::cout<<"Imported PyTorch Trained Weights"< 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "GCN_Model.h" 9 | #include "data.h" 10 | 11 | int main() 12 | { 13 | // Import Cycles Graph Data 14 | auto adj_cycle = get_adj_cycle("resources/CYCLE/CYCLE_A.txt"); 15 | 16 | //Get Graph Indiacator 17 | std::vector gi_vec = get_graph_indicator("resources/CYCLE/CYCLE_graph_indicator.txt"); 18 | 19 | //Get Graph Labels 20 | std::vector gl_vec = get_graph_label("resources/CYCLE/CYCLE_graph_labels.txt"); 21 | 22 | // Import Weights 23 | std::vector weight_vec = getWeights_Cycle("resources/saved/saved_weights_cycle.txt"); 24 | std::cout<<"Imported PyTorch Trained Weights"< paper used for Message Passing // 8 | // Neural Networks. // 9 | //////////////////////////////////////////////////////////////////////////*/ 10 | 11 | #pragma once 12 | 13 | #ifndef GCN_MODEL 14 | #define GCN_MODEL 15 | 16 | #include 17 | #include 18 | #include "layers.h" 19 | 20 | 21 | class GCN 22 | { 23 | 24 | private: 25 | std::vector gi_vec; 26 | std::vector weights; 27 | 28 | public: 29 | // Default Constructor 30 | GCNConv conv1; 31 | GCNConv conv2; 32 | GCN (Eigen::MatrixXd A, int nfeat, int nhid, int nhid2, int nout, std::vector weights, std::vector gi_vec): 33 | conv1(nfeat, nhid, A, weights[0], true), conv2(nhid, nhid2, A, weights[1], true) 34 | { 35 | this->weights = weights; 36 | this->gi_vec = gi_vec; 37 | } 38 | 39 | Eigen::MatrixXd forward(Eigen::MatrixXd x) 40 | { 41 | Eigen::MatrixXd h1 = conv1.forward(x); 42 | h1 = relu(h1); 43 | Eigen::MatrixXd h2 = conv2.forward(h1); 44 | h2 = relu(h2); 45 | Eigen::MatrixXd pooled = func_pool(h2, this->gi_vec); 46 | Eigen::MatrixXd fc_out = pooled * weights[2].transpose(); 47 | 48 | return softmax(fc_out); 49 | } 50 | }; 51 | 52 | 53 | 54 | #endif //GCN_MODEL -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Message Passing GNNs C++ 2 | 3 | My experiments with Graph Neural Nets at the scratch level using C++. Graph Convolutional Network (GCN) is one of the most popular GNN architectures and is extremely powerful. Of the popular graph representation learning methods which can be found at [https://github.com/dsgiitr/graph_nets](https://github.com/dsgiitr/graph_nets), this repo aims to implement GCNs in C++. 4 | 5 | ## GCN C++ Forward Pass 6 | 7 | This is a C++ implementation using [Eigen](http://eigen.tuxfamily.org/index.php?title=Main_Page) for the **forward pass** of **Graph Convolutional Neural Networks** . The model doesn't involve any training loops and backpropagation. Pytorch is used to train the GCN model in Python and save the weights learnt after convergence. These saved weights are then imported and used in the C++ implementation of the model for the forward pass on the same dataset. 8 | 9 | The GCN architecture and PyTorch implementation are explained in this [blog](https://dsgiitr.in/blogs/gcn/) are followed. The network is a 2 layer gcn model. 10 | 11 | ### Training Visualization 12 | ![karate animation](resources/saved/karate_club.gif) 13 | 14 | 15 | The predicted outputs from the Python forward pass and C++ forward pass are saved in `resources/saved/`. As expected, they are almost identical and the similarity between both can be seen quantitatively as well visually with scatter plots in `Compare_Predictions.ipynb`. 16 | 17 | ## Usage 18 | ### Step 1 19 | 20 | ``` 21 | # Clone the repository 22 | git clone https://github.com/AnirudhDagar/MessagePassing_for_GNNs.git 23 | 24 | # Download the Karate Club Dataset 25 | bash ./get_karate_club.sh 26 | ``` 27 | 28 | ### Step 2 29 | ``` 30 | # Train the model and save the weights. 31 | python train.py 32 | 33 | ``` 34 | 35 | ### Step 3 36 | ``` 37 | # Use the CMakeLists.txt to build and run the project for C++ implementation. 38 | 39 | # In the source directory 40 | mkdir _build 41 | 42 | # Change dir into _build 43 | cd _build 44 | 45 | # Build the project 46 | cmake .. 47 | make 48 | ``` 49 | 50 | #### OR 51 | ``` 52 | # Use a compiler directly to compile the executables. 53 | g++ main.cpp -I eigen -std=c++17 54 | 55 | ``` 56 | 57 | ### Run at Once 58 | ``` 59 | # Run everything at once. 60 | bash run.sh 61 | 62 | ``` 63 | 64 | 65 | ### Requirements 66 | 67 | ``` 68 | # C++ 69 | eigen 70 | 71 | # Python 72 | numpy==1.18.1 73 | torch==1.0.0 74 | ConfigArgParse==0.13.0 75 | matplotlib==3.0.1 76 | imageio==2.4.1 77 | celluloid==0.2.0 78 | 79 | ``` 80 | 81 | ## Contributing 82 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 83 | -------------------------------------------------------------------------------- /layers.h: -------------------------------------------------------------------------------- 1 | // Author: Anirudh Dagar 10/03/20 2 | 3 | /*////////////////////////////////////////////////////////////////////////// 4 | // Definition of the GCNConv class and relu method // 5 | // for the graph convolutional operator from the paper: // 6 | // Semi-supervised Classification with Graph Convolutional Networks" // 7 | // paper used for Message Passing // 8 | // Neural Networks. // 9 | //////////////////////////////////////////////////////////////////////////*/ 10 | 11 | #pragma once 12 | 13 | #ifndef LAYERS_MPNN 14 | #define LAYERS_MPNN 15 | 16 | #include 17 | #include 18 | 19 | 20 | class GCNConv 21 | { 22 | private: 23 | Eigen::MatrixXd adj; 24 | Eigen::MatrixXd weight; 25 | int in_channels; 26 | int out_channels; 27 | int num_nodes; 28 | 29 | public: 30 | // Constructor 31 | GCNConv(int in_channels, int out_channels, Eigen::MatrixXd adj, Eigen::MatrixXd weight, 32 | bool normalize=true) 33 | { 34 | this->adj = adj; 35 | this->weight = weight; 36 | num_nodes = adj.rows(); 37 | in_channels = in_channels; 38 | out_channels = out_channels; 39 | 40 | // Add self-loops to Adjacecny Matrix 41 | this->adj = this->adj + Eigen::MatrixXd::Identity(num_nodes, num_nodes); 42 | 43 | //Degree Diagonal Matrix D 44 | Eigen::MatrixXd D = Eigen::MatrixXd::Zero(num_nodes, num_nodes); 45 | for(int i=0; iadj.rowwise().sum()(i); 52 | } 53 | } 54 | } 55 | 56 | // Symmetric Normalization of Adjacency Matrix 57 | D = D.inverse(); 58 | D = D.cwiseSqrt(); 59 | this->adj = (D * this->adj) * D; 60 | 61 | // print norm to compare with Pytorch adjacency matrix 62 | // std::cout<<"Norm of Adjacency Matrix after normalization: "<adj.norm()<weight.rows()<<", "<weight.cols(); 69 | 70 | Eigen::MatrixXd xw = x * this->weight; 71 | Eigen::MatrixXd axw = this->adj * xw; 72 | 73 | return axw; 74 | } 75 | 76 | }; 77 | 78 | Eigen::MatrixXd relu(Eigen::MatrixXd &out) 79 | { 80 | return out.array().cwiseMax(0.0); 81 | } 82 | 83 | #endif //LAYERS_MPNN -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def read_karate(datapath): 5 | with open(datapath) as file_content: 6 | file_content = [i for i in file_content.readlines()] 7 | 8 | # Remove meta info at document starting 9 | file_content = file_content[2:] 10 | 11 | edges = [] 12 | for i in file_content: 13 | edge_tuple = list(map(int, i.strip().split())) 14 | edges.append(edge_tuple) 15 | 16 | return edges 17 | 18 | def read_enzyme_A(datapath): 19 | with open(datapath) as file_content: 20 | file_content = [i for i in file_content.readlines()] 21 | 22 | edges = [] 23 | for i in file_content: 24 | i = i.replace(',', '') 25 | edge_tuple = list(map(int, i.strip().split())) 26 | edges.append(edge_tuple) 27 | return edges 28 | 29 | def read_enzyme_graph_label(datapath): 30 | with open(datapath) as file_content: 31 | file_content = [i for i in file_content.readlines()] 32 | 33 | graph_labels = [] 34 | for i in file_content: 35 | label = int(i.strip()) 36 | graph_labels.append(label) 37 | return graph_labels 38 | 39 | def read_enzyme_node_label(datapath): 40 | with open(datapath) as file_content: 41 | file_content = [i for i in file_content.readlines()] 42 | 43 | node_labels = [] 44 | for i in file_content: 45 | label = int(i.strip()) 46 | node_labels.append(label) 47 | return node_labels 48 | 49 | def create_one_hot_feats(node_labels): 50 | feats = np.zeros((len(node_labels), max(node_labels)), dtype=float) 51 | for i in range(len(node_labels)): 52 | feats[i][node_labels[i]-1]=1 53 | 54 | return feats 55 | 56 | def read_enzyme_graph_indicator(datapath): 57 | with open(datapath) as file_content: 58 | file_content = [i for i in file_content.readlines()] 59 | 60 | graph_indicator = [] 61 | for i in file_content: 62 | graph_num = int(i.strip()) 63 | graph_indicator.append(graph_num) 64 | return graph_indicator 65 | 66 | def create_adjacency(edges, sparse=False): 67 | max_node=0 68 | for edge in edges: 69 | if edge[0] > max_node: 70 | max_node = edge[0] 71 | if edge[1] > max_node: 72 | max_node = edge[1] 73 | 74 | size = max_node 75 | adj = [[0 for i in range(size)] for j in range(size)] 76 | 77 | # Build bi-directional graph adj matrix 78 | for edge in edges: 79 | adj[edge[0]-1][edge[1]-1] = 1 80 | adj[edge[1]-1][edge[0]-1] = 1 81 | 82 | # Convert numpy array to torch tensor 83 | adj_tensor = torch.tensor(adj, dtype=torch.float32) 84 | 85 | ## Use Sparse tensors for larger adjacency matrix 86 | if sparse: 87 | adj_tensor = adj_tensor.to_sparse() 88 | 89 | return adj_tensor 90 | 91 | def zeros(tensor): 92 | if tensor is not None: 93 | tensor.data.fill_(0, dtype=torch.float32) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | from torch.nn.init import xavier_uniform_ as xavier 6 | from utils import zeros 7 | 8 | 9 | class GCNConv(nn.Module): 10 | """ 11 | The graph convolutional operator from the "Semi-supervised 12 | Classification with Graph Convolutional Networks" 13 | https://arxiv.org/abs/1609.02907 14 | 15 | Args: 16 | in_channels (int): Size of each input sample. 17 | out_channels (int): Size of each output sample. 18 | bias (bool, optional): The layer will not learn any bias if `False`. 19 | (default: `True`) 20 | normalize (bool, optional): Add Self Loops & Apply symmetric normalization. 21 | (default: `True`) 22 | """ 23 | def __init__(self, adj, in_channels, out_channels, normalize=True, use_bias=False, glorot=False): 24 | super().__init__() 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | 28 | self.weight = Parameter(torch.rand(self.in_channels, self.out_channels, requires_grad=True, dtype=torch.float32)) 29 | 30 | if use_bias: 31 | self.bias = Parameter(torch.FloatTensor(self.out_channels)) 32 | else: 33 | self.register_parameter('bias', None) 34 | 35 | if glorot: 36 | # Set parameter initializations to Glorot 37 | self.reset_parameters() 38 | 39 | if normalize: 40 | # Add self loop. `A = Adjacency Mat + Identity Mat` 41 | self.adj = adj + torch.eye(adj.size(0), dtype=torch.float32) 42 | 43 | # Diagonal Node Degree Matrix D 44 | self.D = torch.diag(torch.sum(self.adj, 1)) 45 | 46 | # Normalize the adjacency matrix using D 47 | self.D = self.D.inverse().sqrt() 48 | self.adj = torch.mm(torch.mm(self.D, self.adj), self.D) 49 | 50 | else: 51 | self.adj = adj 52 | 53 | 54 | def forward(self, x): 55 | x = torch.mm(x, self.weight) 56 | out = torch.mm(self.adj, x) 57 | 58 | return out 59 | 60 | 61 | def reset_parameters(self): 62 | xavier(self.weight) 63 | zeros(self.bias) 64 | print("Glorot Initialized Weights") 65 | 66 | 67 | def __repr__(self): 68 | return '{} (InChannels:{}, OutChannels:{})'.format( 69 | self.__class__.__name__, self.in_channels, 70 | self.out_channels) 71 | 72 | 73 | # NOTE: Skipping the Dropout layer. 74 | class GCN(nn.Module): 75 | def __init__(self, A, nfeat, nhid, nout): 76 | super(GCN, self).__init__() 77 | self.conv1 = GCNConv(A, nfeat, nhid) 78 | self.conv2 = GCNConv(A, nhid, nout) 79 | 80 | def forward(self, x): 81 | h1 = F.relu(self.conv1(x)) 82 | h2 = F.relu(self.conv2(h1)) 83 | 84 | return h2 85 | -------------------------------------------------------------------------------- /data.h: -------------------------------------------------------------------------------- 1 | // Author: Anirudh Dagar 10/03/20 2 | 3 | /*////////////////////////////////////////////////////////////////////////// 4 | // C++ Data utils to read in arbitrary sized graph data, // 5 | // eg: Karate Club Dataset. Also provides functionality to read // 6 | // the trained weight matrices into an Eigen Matrix for the forward // 7 | // pass of Message Passing Neural Networks eg GCN // 8 | //////////////////////////////////////////////////////////////////////////*/ 9 | 10 | #pragma once 11 | 12 | #ifndef DATA_UTILS 13 | #define DATA_UTILS 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | auto read_karate() { 23 | std::string line; 24 | std::ifstream myfile ("resources/out.ucidata-zachary"); 25 | 26 | int edges, num_nodes; 27 | 28 | if (getline (myfile, line)) 29 | { 30 | // Remove the first line 31 | // pass; 32 | } 33 | // Get number of edges and num nodes 34 | if (getline (myfile, line)) 35 | { 36 | std::stringstream temp(line.substr(2,line.length()-2)); 37 | temp>>edges; 38 | temp>>num_nodes; 39 | 40 | } 41 | 42 | Eigen::MatrixXd data = Eigen::MatrixXd::Zero(num_nodes, num_nodes); 43 | for(int i=0; i>x; 46 | myfile>>y; 47 | data(x-1,y-1) = 1; 48 | data(y-1,x-1) = 1; 49 | } 50 | 51 | // std::cout<>rows; 62 | myfile>>cols; 63 | Eigen::MatrixXd weights1 = Eigen::MatrixXd::Zero(rows, cols); 64 | for(int i=0; i> weights1(i,j); 67 | } 68 | } 69 | myfile>>rows; 70 | myfile>>cols; 71 | Eigen::MatrixXd weights2 = Eigen::MatrixXd::Zero(rows, cols); 72 | for(int i=0; i> weights2(i,j); 75 | } 76 | } 77 | myfile.close(); 78 | 79 | std::vector out_weights; 80 | out_weights.push_back(weights1); 81 | out_weights.push_back(weights2); 82 | 83 | return out_weights; 84 | } 85 | 86 | 87 | void writeTofile(std::string name, Eigen::MatrixXd matrix) 88 | { 89 | std::ofstream file(name.c_str()); 90 | 91 | for(int i=0; i paper used for Message Passing // 8 | // Neural Networks. // 9 | //////////////////////////////////////////////////////////////////////////*/ 10 | 11 | #ifndef LAYERS_MPNN 12 | #define LAYERS_MPNN 13 | 14 | #include 15 | #include 16 | 17 | 18 | class GCNConv 19 | { 20 | private: 21 | Eigen::MatrixXd adj; 22 | Eigen::MatrixXd weight; 23 | int in_channels; 24 | int out_channels; 25 | int num_nodes; 26 | 27 | public: 28 | // Constructor 29 | GCNConv(int in_channels, int out_channels, Eigen::MatrixXd adj, Eigen::MatrixXd weight, 30 | bool normalize=true) 31 | { 32 | this->adj = adj; 33 | this->weight = weight; 34 | num_nodes = adj.rows(); 35 | in_channels = in_channels; 36 | out_channels = out_channels; 37 | 38 | // Add self-loops to Adjacecny Matrix 39 | this->adj = this->adj + Eigen::MatrixXd::Identity(num_nodes, num_nodes); 40 | 41 | //Degree Diagonal Matrix D 42 | Eigen::MatrixXd D = Eigen::MatrixXd::Zero(num_nodes, num_nodes); 43 | for(int i=0; iadj.rowwise().sum()(i); 50 | } 51 | } 52 | } 53 | 54 | // Symmetric Normalization of Adjacency Matrix 55 | D = D.inverse(); 56 | D = D.cwiseSqrt(); 57 | this->adj = (D * this->adj) * D; 58 | 59 | // print norm to compare with Pytorch adjacency matrix 60 | // std::cout<<"Norm of Adjacency Matrix after normalization: "<adj.norm()<weight.rows()<<", "<weight.cols(); 67 | 68 | Eigen::MatrixXd xw = x * this->weight; 69 | Eigen::MatrixXd axw = this->adj * xw; 70 | std::cout<<"\n Done Convolution:\n"; 71 | return axw; 72 | } 73 | 74 | }; 75 | 76 | auto func_pool(Eigen::MatrixXd &X, std::vector gi_vec) 77 | { 78 | int nodes = X.rows(); // nodes = 60 79 | int cols = X.cols(); // cols = 4 80 | int num_graphs = *max_element(gi_vec.begin(), gi_vec.end()); //num_graphs = 13 81 | 82 | Eigen::MatrixXd pooled = Eigen::MatrixXd::Zero(num_graphs, cols); 83 | int flag = gi_vec[0]; 84 | 85 | 86 | for(int i=0; i 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | Eigen::MatrixXd get_adj_cycle(std::string str_name) 14 | { 15 | std::string line; 16 | std::ifstream myfile(str_name); 17 | 18 | int edges = 0, num_nodes = 0, maxx = 0; 19 | while(getline (myfile, line)){ 20 | edges++; 21 | int x, y =0; 22 | int ptr = 0; 23 | for(int i=0; iy)?x:y; 32 | num_nodes = (maxx>num_nodes)? maxx: num_nodes; 33 | } 34 | std::ifstream again("CYCLE/CYCLE_A.txt"); 35 | Eigen::MatrixXd data = Eigen::MatrixXd::Zero(num_nodes, num_nodes); 36 | while(getline (again, line)) 37 | { 38 | int x, y =0; 39 | int ptr = 0; 40 | for(int i=0; i>num_weights; 63 | std::vector out_weights; 64 | while(num_weights--) 65 | { 66 | myfile>>rows; 67 | myfile>>cols; 68 | Eigen::MatrixXd weights = Eigen::MatrixXd::Zero(rows, cols); 69 | for(int i=0; i> weights(i,j); 72 | } 73 | } 74 | out_weights.push_back(weights); 75 | } 76 | 77 | std::cout< get_graph_indicator(std::string strname) 85 | { 86 | std::string line; 87 | std::ifstream myfile(strname); 88 | 89 | std::vector vec; 90 | int x; 91 | while(getline (myfile, line)){ 92 | vec.push_back(stoi(line)); 93 | } 94 | return vec; 95 | } 96 | //////////////////////////////////// 97 | 98 | std::vector get_graph_label(std::string strname){ 99 | std::string line; 100 | std::ifstream myfile(strname); 101 | 102 | std::vector vec; 103 | int x; 104 | while(getline (myfile, line)){ 105 | vec.push_back(stoi(line)); 106 | } 107 | return vec; 108 | } 109 | 110 | ////////////////////////////////////// 111 | // Uncomment if node labels given 112 | 113 | // auto get_nl(){ 114 | // std::string line; 115 | // std::ifstream myfile("CYCLE/CYCLE_node_labels.txt"); 116 | 117 | // std::vector vec; 118 | // int x; 119 | // while(getline (myfile, line)){ 120 | // vec.push_back(stoi(line)); 121 | // } 122 | // int len = vec.size(); 123 | // Eigen::MatrixXd data = Eigen::MatrixXd::Zero(len, 3); 124 | // for(int i=0; i" 132 | ] 133 | }, 134 | "execution_count": 5, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | }, 138 | { 139 | "data": { 140 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAPCklEQVR4nO3dX2jd533H8fe3ik00mtYDqxeW3doDR9Qkoy4idPii2ZpNTi4c00IXQy7CsgQKKRsNApuO0KYXdSdWtgsP6u2iUGhTdxjhERcNFpdCqIcVlNTYQcVzk8bHF1FD1Juojex9d6Ejc6Qc+Rw55+9z3i8Q6Pf8Hs7vy4P84efn+Z3fE5mJJKn/faTbBUiSWsNAl6RCGOiSVAgDXZIKYaBLUiHu6taFt2/fnrt37+7W5SWpL73yyiu/zcyReue6Fui7d+9mdna2W5eXpL4UEW9udM4pF0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhuvbFolaYnqswNTPP9cUldmwbZnJijMP7R7tdliR1Rd8G+vRchWOnL7K0fBOAyuISx05fBDDUJQ2kvp1ymZqZvxXmq5aWbzI1M9+liiSpu/o20K8vLtVtrywuceD4S0zPVTpckSR1V98G+o5twxueW51+MdQlDZK+DfTJiTGGtwxteN7pF0mDpm8XRVcXPqdm5qlsMP2y0bSMJJWob+/QYSXUXz76F4xuMP1yu2kZSSpNXwf6qnrTL8NbhpicGOtSRZLUeX075VKrdvrFLxlJGlRFBDqshLoBLmmQFTHlIkky0CWpGAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKkQxXyzqFLe9k9SrDPRNcNs7Sb3MKZdNcNs7Sb3MQN+Ejd6v7nvXJfUCA30TNnq/uu9dl9QLDPRN8L3rknqZi6Kb4HvXJfUyA32TfO+6pF7llIskFcJAl6RCGOiSVAjn0LvA1wdIagcDvcN8fYCkdmlqyiUiDkbEfERciYijdc5/MiLORcRcRPwyIh5pfall8PUBktqlYaBHxBBwAngY2AcciYh967r9A3AqM/cDjwH/2upCS+HrAyS1SzN36A8AVzLzama+D7wAPLquTwIfq/7+ceB660osi68PkNQuzQT6KPBWzfG1alutbwCPR8Q14Czw1XofFBFPR8RsRMwuLCzcQbn9r9HrA6bnKhw4/hJ7jr7IgeMvMT1X6UaZkvpQqx5bPAJ8PzN3Ao8AP4iID3x2Zp7MzPHMHB8ZGWnRpfvL4f2jfPuL9zO6bZgARrcN8+0v3s/h/aO3Fkwri0skKwumf//jV/nMN//LYJfUUDNPuVSAXTXHO6tttZ4EDgJk5i8i4m5gO/B2K4oszUavD6i3YAqwuLTskzCSGmrmDv0CsDci9kTEVlYWPc+s6/Mb4AsAEfFp4G5gMOdUPoTbLYz6JIykRhoGembeAJ4BZoDXWXma5VJEPB8Rh6rdngWeiojXgB8BT2RmtqvoUjVaGPVJGEm309QXizLzLCuLnbVtz9X8fhk40NrSBs/kxNiaLx2t55Mwkm7Hb4r2kNX58W/+5yXefW95zTk30pDUiC/n6jGH948y99xf8c9//Zm6T8JI0ka8Q+9RbqQhabO8Q5ekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEL4TdEBNz1XYWpmnuuLS+zYNszkxJjfUJX6lIE+wFZ3SFp9u2NlccmNNKQ+5pTLAKu3Q5IbaUj9y0AfYBttmOFGGlJ/MtAH2EYbZriRhtSfDPQBNjkxxvCWoTVtbqQh9S8XRQfY6sKnT7lIZTDQB5wbaUjlcMpFkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1Ihmgr0iDgYEfMRcSUijm7Q58sRcTkiLkXED1tbpiSpkYZ7ikbEEHAC+EvgGnAhIs5k5uWaPnuBY8CBzHw3Ij7RroIlSfU1c4f+AHAlM69m5vvAC8Cj6/o8BZzIzHcBMvPt1pYpSWqkmUAfBd6qOb5Wbat1L3BvRLwcEecj4mC9D4qIpyNiNiJmFxYW7qxiSVJdrVoUvQvYCzwIHAH+LSK2re+UmSczczwzx0dGRlp0aUkSNBfoFWBXzfHOaluta8CZzFzOzF8Dv2Il4CVJHdJMoF8A9kbEnojYCjwGnFnXZ5qVu3MiYjsrUzBXW1inJKmBhoGemTeAZ4AZ4HXgVGZeiojnI+JQtdsM8E5EXAbOAZOZ+U67ipYkfVBkZlcuPD4+nrOzs125tiT1q4h4JTPH653zm6KSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIe7qdgFSaabnKkzNzHN9cYkd24aZnBjj8P7RbpelAWCgSy00PVfh2OmLLC3fBKCyuMSx0xcBDHW1nVMuUgtNzczfCvNVS8s3mZqZ71JFGiQGutRC1xeXNtUutZKBLrXQjm3Dm2qXWslAl1pocmKM4S1Da9qGtwwxOTHWpYo0SFwUlVpodeHTp1zUDQa61GKH948a4OoKp1wkqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCNBXoEXEwIuYj4kpEHL1Nvy9FREbEeOtKlCQ1o2GgR8QQcAJ4GNgHHImIfXX63QP8HfA/rS5SktRYM3foDwBXMvNqZr4PvAA8Wqfft4DvAL9vYX2SpCY1E+ijwFs1x9eqbbdExGeBXZn54u0+KCKejojZiJhdWFjYdLGSpI196EXRiPgI8F3g2UZ9M/NkZo5n5vjIyMiHvbQkqUYzgV4BdtUc76y2rboHuA/4WUS8AXwOOOPCqCR1VjOBfgHYGxF7ImIr8BhwZvVkZv4uM7dn5u7M3A2cBw5l5mxbKpYk1dUw0DPzBvAMMAO8DpzKzEsR8XxEHGp3gZKk5jS1p2hmngXOrmt7boO+D374siS12/Rcxc2sC+Mm0dIAmp6rcOz0RZaWbwJQWVzi2OmLAIZ6H/Or/9IAmpqZvxXmq5aWbzI1M9+litQKBro0gK4vLm2qXf3BQJcG0I5tw5tqV38w0KUBNDkxxvCWoTVtw1uGmJwYA1bm2A8cf4k9R1/kwPGXmJ6r1PsY9RgXRaUBtLrwWe8pFxdM+5eBLg2ow/tH6wb07RZMDfTe5pSLpDVcMO1fBrqkNVww7V8GuqQ1Gi2Yqnc5hy5pjdstmK7n6wN6i4Eu6QM2WjCt5dMwvccpF0l3xNcH9B4DXdId8WmY3mOgS7ojPg3Tewx0SXfEp2F6j4uiku7IZp6GUWcY6JLuWDNPw6hznHKRpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiF8OZekvuR+ph9koEvqO+5nWp9TLpL6jvuZ1megS+o77mdan4Euqe+4n2l9BrqkvuN+pvW5KCqp77ifaX1NBXpEHAT+BRgC/j0zj687/zXgb4EbwALwN5n5ZotrlaRb3M/0gxpOuUTEEHACeBjYBxyJiH3rus0B45n5p8B/AP/Y6kIlSbfXzBz6A8CVzLyame8DLwCP1nbIzHOZ+V718Dyws7VlSpIaaSbQR4G3ao6vVds28iTw03onIuLpiJiNiNmFhYXmq5QkNdTSp1wi4nFgHJiqdz4zT2bmeGaOj4yMtPLSkjTwmlkUrQC7ao53VtvWiIiHgK8Dn8/MP7SmPElSs5q5Q78A7I2IPRGxFXgMOFPbISL2A98DDmXm260vU5LUSMNAz8wbwDPADPA6cCozL0XE8xFxqNptCvgo8JOIeDUizmzwcZKkNmnqOfTMPAucXdf2XM3vD7W4LknSJvnVf0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRBuEi1JbTI9V+noRtYGuiS1wfRchWOnL7K0fBOAyuISx05fBGhbqDvlIkltMDUzfyvMVy0t32RqZr5t1zTQJakNri8ubaq9FQx0SWqDHduGN9XeCga6JLXB5MQYw1uG1rQNbxlicmKsbdd0UVSS2mB14dOnXCSpAIf3j7Y1wNdzykWSCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUiMjM7lw4YgF4sysX703bgd92u4ge5xg15hjdXgnj86nMHKl3omuBrrUiYjYzx7tdRy9zjBpzjG6v9PFxykWSCmGgS1IhDPTecbLbBfQBx6gxx+j2ih4f59AlqRDeoUtSIQx0SSqEgd5hEXEwIuYj4kpEHK1z/msRcTkifhkR/x0Rn+pGnd3UaIxq+n0pIjIiin0MrZ5mxicivlz9O7oUET/sdI3d1sS/s09GxLmImKv+W3ukG3W2XGb606EfYAj4X+BPgK3Aa8C+dX3+HPij6u9fAX7c7bp7bYyq/e4Bfg6cB8a7XXcvjQ+wF5gD/rh6/Ilu192DY3QS+Er1933AG92uuxU/3qF31gPAlcy8mpnvAy8Aj9Z2yMxzmfle9fA8sLPDNXZbwzGq+hbwHeD3nSyuBzQzPk8BJzLzXYDMfLvDNXZbM2OUwMeqv38cuN7B+trGQO+sUeCtmuNr1baNPAn8tK0V9Z6GYxQRnwV2ZeaLnSysRzTzN3QvcG9EvBwR5yPiYMeq6w3NjNE3gMcj4hpwFvhqZ0prLzeJ7lER8TgwDny+27X0koj4CPBd4Ikul9LL7mJl2uVBVv6H9/OIuD8zF7taVW85Anw/M/8pIv4M+EFE3JeZ/9ftwj4M79A7qwLsqjneWW1bIyIeAr4OHMrMP3Sotl7RaIzuAe4DfhYRbwCfA84M0MJoM39D14Azmbmcmb8GfsVKwA+KZsboSeAUQGb+AriblRd39TUDvbMuAHsjYk9EbAUeA87UdoiI/cD3WAnzQZv7hAZjlJm/y8ztmbk7M3ezss5wKDNnu1NuxzX8GwKmWbk7JyK2szIFc7WTRXZZM2P0G+ALABHxaVYCfaGjVbaBgd5BmXkDeAaYAV4HTmXmpYh4PiIOVbtNAR8FfhIRr0bE+j/EojU5RgOryfGZAd6JiMvAOWAyM9/pTsWd1+QYPQs8FRGvAT8CnsjqIy/9zK/+S1IhvEOXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQ/w81ztLjYDMkvgAAAABJRU5ErkJggg==\n", 141 | "text/plain": [ 142 | "
" 143 | ] 144 | }, 145 | "metadata": { 146 | "needs_background": "light" 147 | }, 148 | "output_type": "display_data" 149 | } 150 | ], 151 | "source": [ 152 | "plt.scatter(py_pred[:,0], py_pred[:,1])" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "metadata": { 159 | "scrolled": true 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "" 166 | ] 167 | }, 168 | "execution_count": 6, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | }, 172 | { 173 | "data": { 174 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAPCklEQVR4nO3dX2jd533H8fe3ik00mtYDqxeW3doDR9Qkoy4idPii2ZpNTi4c00IXQy7CsgQKKRsNApuO0KYXdSdWtgsP6u2iUGhTdxjhERcNFpdCqIcVlNTYQcVzk8bHF1FD1Juojex9d6Ejc6Qc+Rw55+9z3i8Q6Pf8Hs7vy4P84efn+Z3fE5mJJKn/faTbBUiSWsNAl6RCGOiSVAgDXZIKYaBLUiHu6taFt2/fnrt37+7W5SWpL73yyiu/zcyReue6Fui7d+9mdna2W5eXpL4UEW9udM4pF0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhuvbFolaYnqswNTPP9cUldmwbZnJijMP7R7tdliR1Rd8G+vRchWOnL7K0fBOAyuISx05fBDDUJQ2kvp1ymZqZvxXmq5aWbzI1M9+liiSpu/o20K8vLtVtrywuceD4S0zPVTpckSR1V98G+o5twxueW51+MdQlDZK+DfTJiTGGtwxteN7pF0mDpm8XRVcXPqdm5qlsMP2y0bSMJJWob+/QYSXUXz76F4xuMP1yu2kZSSpNXwf6qnrTL8NbhpicGOtSRZLUeX075VKrdvrFLxlJGlRFBDqshLoBLmmQFTHlIkky0CWpGAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKkQxXyzqFLe9k9SrDPRNcNs7Sb3MKZdNcNs7Sb3MQN+Ejd6v7nvXJfUCA30TNnq/uu9dl9QLDPRN8L3rknqZi6Kb4HvXJfUyA32TfO+6pF7llIskFcJAl6RCGOiSVAjn0LvA1wdIagcDvcN8fYCkdmlqyiUiDkbEfERciYijdc5/MiLORcRcRPwyIh5pfall8PUBktqlYaBHxBBwAngY2AcciYh967r9A3AqM/cDjwH/2upCS+HrAyS1SzN36A8AVzLzama+D7wAPLquTwIfq/7+ceB660osi68PkNQuzQT6KPBWzfG1alutbwCPR8Q14Czw1XofFBFPR8RsRMwuLCzcQbn9r9HrA6bnKhw4/hJ7jr7IgeMvMT1X6UaZkvpQqx5bPAJ8PzN3Ao8AP4iID3x2Zp7MzPHMHB8ZGWnRpfvL4f2jfPuL9zO6bZgARrcN8+0v3s/h/aO3Fkwri0skKwumf//jV/nMN//LYJfUUDNPuVSAXTXHO6tttZ4EDgJk5i8i4m5gO/B2K4oszUavD6i3YAqwuLTskzCSGmrmDv0CsDci9kTEVlYWPc+s6/Mb4AsAEfFp4G5gMOdUPoTbLYz6JIykRhoGembeAJ4BZoDXWXma5VJEPB8Rh6rdngWeiojXgB8BT2RmtqvoUjVaGPVJGEm309QXizLzLCuLnbVtz9X8fhk40NrSBs/kxNiaLx2t55Mwkm7Hb4r2kNX58W/+5yXefW95zTk30pDUiC/n6jGH948y99xf8c9//Zm6T8JI0ka8Q+9RbqQhabO8Q5ekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEL4TdEBNz1XYWpmnuuLS+zYNszkxJjfUJX6lIE+wFZ3SFp9u2NlccmNNKQ+5pTLAKu3Q5IbaUj9y0AfYBttmOFGGlJ/MtAH2EYbZriRhtSfDPQBNjkxxvCWoTVtbqQh9S8XRQfY6sKnT7lIZTDQB5wbaUjlcMpFkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1Ihmgr0iDgYEfMRcSUijm7Q58sRcTkiLkXED1tbpiSpkYZ7ikbEEHAC+EvgGnAhIs5k5uWaPnuBY8CBzHw3Ij7RroIlSfU1c4f+AHAlM69m5vvAC8Cj6/o8BZzIzHcBMvPt1pYpSWqkmUAfBd6qOb5Wbat1L3BvRLwcEecj4mC9D4qIpyNiNiJmFxYW7qxiSVJdrVoUvQvYCzwIHAH+LSK2re+UmSczczwzx0dGRlp0aUkSNBfoFWBXzfHOaluta8CZzFzOzF8Dv2Il4CVJHdJMoF8A9kbEnojYCjwGnFnXZ5qVu3MiYjsrUzBXW1inJKmBhoGemTeAZ4AZ4HXgVGZeiojnI+JQtdsM8E5EXAbOAZOZ+U67ipYkfVBkZlcuPD4+nrOzs125tiT1q4h4JTPH653zm6KSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIe7qdgFSaabnKkzNzHN9cYkd24aZnBjj8P7RbpelAWCgSy00PVfh2OmLLC3fBKCyuMSx0xcBDHW1nVMuUgtNzczfCvNVS8s3mZqZ71JFGiQGutRC1xeXNtUutZKBLrXQjm3Dm2qXWslAl1pocmKM4S1Da9qGtwwxOTHWpYo0SFwUlVpodeHTp1zUDQa61GKH948a4OoKp1wkqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCNBXoEXEwIuYj4kpEHL1Nvy9FREbEeOtKlCQ1o2GgR8QQcAJ4GNgHHImIfXX63QP8HfA/rS5SktRYM3foDwBXMvNqZr4PvAA8Wqfft4DvAL9vYX2SpCY1E+ijwFs1x9eqbbdExGeBXZn54u0+KCKejojZiJhdWFjYdLGSpI196EXRiPgI8F3g2UZ9M/NkZo5n5vjIyMiHvbQkqUYzgV4BdtUc76y2rboHuA/4WUS8AXwOOOPCqCR1VjOBfgHYGxF7ImIr8BhwZvVkZv4uM7dn5u7M3A2cBw5l5mxbKpYk1dUw0DPzBvAMMAO8DpzKzEsR8XxEHGp3gZKk5jS1p2hmngXOrmt7boO+D374siS12/Rcxc2sC+Mm0dIAmp6rcOz0RZaWbwJQWVzi2OmLAIZ6H/Or/9IAmpqZvxXmq5aWbzI1M9+litQKBro0gK4vLm2qXf3BQJcG0I5tw5tqV38w0KUBNDkxxvCWoTVtw1uGmJwYA1bm2A8cf4k9R1/kwPGXmJ6r1PsY9RgXRaUBtLrwWe8pFxdM+5eBLg2ow/tH6wb07RZMDfTe5pSLpDVcMO1fBrqkNVww7V8GuqQ1Gi2Yqnc5hy5pjdstmK7n6wN6i4Eu6QM2WjCt5dMwvccpF0l3xNcH9B4DXdId8WmY3mOgS7ojPg3Tewx0SXfEp2F6j4uiku7IZp6GUWcY6JLuWDNPw6hznHKRpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiF8OZekvuR+ph9koEvqO+5nWp9TLpL6jvuZ1megS+o77mdan4Euqe+4n2l9BrqkvuN+pvW5KCqp77ifaX1NBXpEHAT+BRgC/j0zj687/zXgb4EbwALwN5n5ZotrlaRb3M/0gxpOuUTEEHACeBjYBxyJiH3rus0B45n5p8B/AP/Y6kIlSbfXzBz6A8CVzLyame8DLwCP1nbIzHOZ+V718Dyws7VlSpIaaSbQR4G3ao6vVds28iTw03onIuLpiJiNiNmFhYXmq5QkNdTSp1wi4nFgHJiqdz4zT2bmeGaOj4yMtPLSkjTwmlkUrQC7ao53VtvWiIiHgK8Dn8/MP7SmPElSs5q5Q78A7I2IPRGxFXgMOFPbISL2A98DDmXm260vU5LUSMNAz8wbwDPADPA6cCozL0XE8xFxqNptCvgo8JOIeDUizmzwcZKkNmnqOfTMPAucXdf2XM3vD7W4LknSJvnVf0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRBuEi1JbTI9V+noRtYGuiS1wfRchWOnL7K0fBOAyuISx05fBGhbqDvlIkltMDUzfyvMVy0t32RqZr5t1zTQJakNri8ubaq9FQx0SWqDHduGN9XeCga6JLXB5MQYw1uG1rQNbxlicmKsbdd0UVSS2mB14dOnXCSpAIf3j7Y1wNdzykWSCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUiMjM7lw4YgF4sysX703bgd92u4ge5xg15hjdXgnj86nMHKl3omuBrrUiYjYzx7tdRy9zjBpzjG6v9PFxykWSCmGgS1IhDPTecbLbBfQBx6gxx+j2ih4f59AlqRDeoUtSIQx0SSqEgd5hEXEwIuYj4kpEHK1z/msRcTkifhkR/x0Rn+pGnd3UaIxq+n0pIjIiin0MrZ5mxicivlz9O7oUET/sdI3d1sS/s09GxLmImKv+W3ukG3W2XGb606EfYAj4X+BPgK3Aa8C+dX3+HPij6u9fAX7c7bp7bYyq/e4Bfg6cB8a7XXcvjQ+wF5gD/rh6/Ilu192DY3QS+Er1933AG92uuxU/3qF31gPAlcy8mpnvAy8Aj9Z2yMxzmfle9fA8sLPDNXZbwzGq+hbwHeD3nSyuBzQzPk8BJzLzXYDMfLvDNXZbM2OUwMeqv38cuN7B+trGQO+sUeCtmuNr1baNPAn8tK0V9Z6GYxQRnwV2ZeaLnSysRzTzN3QvcG9EvBwR5yPiYMeq6w3NjNE3gMcj4hpwFvhqZ0prLzeJ7lER8TgwDny+27X0koj4CPBd4Ikul9LL7mJl2uVBVv6H9/OIuD8zF7taVW85Anw/M/8pIv4M+EFE3JeZ/9ftwj4M79A7qwLsqjneWW1bIyIeAr4OHMrMP3Sotl7RaIzuAe4DfhYRbwCfA84M0MJoM39D14Azmbmcmb8GfsVKwA+KZsboSeAUQGb+AriblRd39TUDvbMuAHsjYk9EbAUeA87UdoiI/cD3WAnzQZv7hAZjlJm/y8ztmbk7M3ezss5wKDNnu1NuxzX8GwKmWbk7JyK2szIFc7WTRXZZM2P0G+ALABHxaVYCfaGjVbaBgd5BmXkDeAaYAV4HTmXmpYh4PiIOVbtNAR8FfhIRr0bE+j/EojU5RgOryfGZAd6JiMvAOWAyM9/pTsWd1+QYPQs8FRGvAT8CnsjqIy/9zK/+S1IhvEOXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQ/w81ztLjYDMkvgAAAABJRU5ErkJggg==\n", 175 | "text/plain": [ 176 | "
" 177 | ] 178 | }, 179 | "metadata": { 180 | "needs_background": "light" 181 | }, 182 | "output_type": "display_data" 183 | } 184 | ], 185 | "source": [ 186 | "plt.scatter(cpp_pred[:,0], cpp_pred[:,1])" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "### Error" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 7, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "err_arr = py_pred-cpp_pred" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "metadata": { 209 | "scrolled": false 210 | }, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "text/plain": [ 215 | "array([[-1.70242310e-07, 1.70242310e-07],\n", 216 | " [-3.09907913e-07, 3.39710236e-07],\n", 217 | " [ 3.57818604e-07, -3.28016281e-07],\n", 218 | " [ 2.61816978e-07, -2.76718140e-07],\n", 219 | " [-4.14941788e-07, 3.85139465e-07],\n", 220 | " [ 3.84365082e-07, -3.54562759e-07],\n", 221 | " [-2.89789677e-07, 3.12141418e-07],\n", 222 | " [ 2.87170410e-08, 1.08528136e-09],\n", 223 | " [ 1.00709438e-07, -1.37962341e-07],\n", 224 | " [ 4.16618347e-07, -4.01717186e-07],\n", 225 | " [ 1.57196045e-07, -1.57196045e-07],\n", 226 | " [-5.36945343e-07, 4.99692440e-07],\n", 227 | " [ 2.53860474e-07, -2.53860474e-07]])" 228 | ] 229 | }, 230 | "execution_count": 8, 231 | "metadata": {}, 232 | "output_type": "execute_result" 233 | } 234 | ], 235 | "source": [ 236 | "err_arr" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "### Norm" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 9, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "Norms for python predictions and c++ predictions: 2.948778180317615 2.9487779999525228\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "py_pred_norm = np.linalg.norm(py_pred)\n", 261 | "cpp_pred_norm = np.linalg.norm(cpp_pred)\n", 262 | "\n", 263 | "print(\"Norms for python predictions and c++ predictions:\", py_pred_norm, cpp_pred_norm)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "### np.allclose()" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 10, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "True" 282 | ] 283 | }, 284 | "execution_count": 10, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | } 288 | ], 289 | "source": [ 290 | "# True if two arrays are element-wise equal within a tolerance.\n", 291 | "np.allclose(py_pred, cpp_pred) # default rtol=1e-05 " 292 | ] 293 | } 294 | ], 295 | "metadata": { 296 | "kernelspec": { 297 | "display_name": "Python 3", 298 | "language": "python", 299 | "name": "python3" 300 | }, 301 | "language_info": { 302 | "codemirror_mode": { 303 | "name": "ipython", 304 | "version": 3 305 | }, 306 | "file_extension": ".py", 307 | "mimetype": "text/x-python", 308 | "name": "python", 309 | "nbconvert_exporter": "python", 310 | "pygments_lexer": "ipython3", 311 | "version": "3.6.9" 312 | } 313 | }, 314 | "nbformat": 4, 315 | "nbformat_minor": 2 316 | } 317 | -------------------------------------------------------------------------------- /Cycles_Graph_Level_Classification/.ipynb_checkpoints/Cycle_Compare_Predictions-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import numpy as np\n", 11 | "import os\n", 12 | "import sys\n", 13 | "import matplotlib.pyplot as plt" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "### Load Predictions" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "datapath_py =\"resources/saved/py_predicted_cycle.txt\"\n", 30 | "datapath_cpp =\"resources/saved/cpp_predicted_cycle.txt\"\n", 31 | "\n", 32 | "def read_predictions(datapath):\n", 33 | " with open(datapath) as file_content:\n", 34 | " file_content = [i for i in file_content.readlines()]\n", 35 | " \n", 36 | " out=[]\n", 37 | " for i in file_content:\n", 38 | " if i == '\\n':\n", 39 | " print(\"Skip Empty line\")\n", 40 | " continue\n", 41 | " out_pair = list(map(float, i.split()))\n", 42 | " out.append(out_pair)\n", 43 | " \n", 44 | " return np.array(out)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "Skip Empty line\n", 57 | "Skip Empty line\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "py_pred = read_predictions(datapath_py)\n", 63 | "cpp_pred = read_predictions(datapath_cpp)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### Print Predictions" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Python fwd pass: [[0.21623683 0.78376317]\n", 83 | " [0.21018669 0.78981334]\n", 84 | " [0.68355036 0.31644967]\n", 85 | " [0.14434026 0.85565972]\n", 86 | " [0.19162259 0.80837739]\n", 87 | " [0.65566438 0.34433565]\n", 88 | " [0.06425371 0.93574631]\n", 89 | " [0.63653803 0.363462 ]\n", 90 | " [0.0741361 0.92586386]\n", 91 | " [0.76896942 0.2310306 ]\n", 92 | " [0.28831816 0.71168184]\n", 93 | " [0.88230246 0.1176975 ]\n", 94 | " [0.52083325 0.47916675]]\n", 95 | "C++ fwd pass: [[0.216237 0.783763]\n", 96 | " [0.210187 0.789813]\n", 97 | " [0.68355 0.31645 ]\n", 98 | " [0.14434 0.85566 ]\n", 99 | " [0.191623 0.808377]\n", 100 | " [0.655664 0.344336]\n", 101 | " [0.064254 0.935746]\n", 102 | " [0.636538 0.363462]\n", 103 | " [0.074136 0.925864]\n", 104 | " [0.768969 0.231031]\n", 105 | " [0.288318 0.711682]\n", 106 | " [0.882303 0.117697]\n", 107 | " [0.520833 0.479167]]\n" 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "print(\"Python fwd pass: \", py_pred)\n", 113 | "print(\"C++ fwd pass: \", cpp_pred)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "### Scatter Plots" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "" 132 | ] 133 | }, 134 | "execution_count": 5, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | }, 138 | { 139 | "data": { 140 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAPCklEQVR4nO3dX2jd533H8fe3ik00mtYDqxeW3doDR9Qkoy4idPii2ZpNTi4c00IXQy7CsgQKKRsNApuO0KYXdSdWtgsP6u2iUGhTdxjhERcNFpdCqIcVlNTYQcVzk8bHF1FD1Juojex9d6Ejc6Qc+Rw55+9z3i8Q6Pf8Hs7vy4P84efn+Z3fE5mJJKn/faTbBUiSWsNAl6RCGOiSVAgDXZIKYaBLUiHu6taFt2/fnrt37+7W5SWpL73yyiu/zcyReue6Fui7d+9mdna2W5eXpL4UEW9udM4pF0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhuvbFolaYnqswNTPP9cUldmwbZnJijMP7R7tdliR1Rd8G+vRchWOnL7K0fBOAyuISx05fBDDUJQ2kvp1ymZqZvxXmq5aWbzI1M9+liiSpu/o20K8vLtVtrywuceD4S0zPVTpckSR1V98G+o5twxueW51+MdQlDZK+DfTJiTGGtwxteN7pF0mDpm8XRVcXPqdm5qlsMP2y0bSMJJWob+/QYSXUXz76F4xuMP1yu2kZSSpNXwf6qnrTL8NbhpicGOtSRZLUeX075VKrdvrFLxlJGlRFBDqshLoBLmmQFTHlIkky0CWpGAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKkQxXyzqFLe9k9SrDPRNcNs7Sb3MKZdNcNs7Sb3MQN+Ejd6v7nvXJfUCA30TNnq/uu9dl9QLDPRN8L3rknqZi6Kb4HvXJfUyA32TfO+6pF7llIskFcJAl6RCGOiSVAjn0LvA1wdIagcDvcN8fYCkdmlqyiUiDkbEfERciYijdc5/MiLORcRcRPwyIh5pfall8PUBktqlYaBHxBBwAngY2AcciYh967r9A3AqM/cDjwH/2upCS+HrAyS1SzN36A8AVzLzama+D7wAPLquTwIfq/7+ceB660osi68PkNQuzQT6KPBWzfG1alutbwCPR8Q14Czw1XofFBFPR8RsRMwuLCzcQbn9r9HrA6bnKhw4/hJ7jr7IgeMvMT1X6UaZkvpQqx5bPAJ8PzN3Ao8AP4iID3x2Zp7MzPHMHB8ZGWnRpfvL4f2jfPuL9zO6bZgARrcN8+0v3s/h/aO3Fkwri0skKwumf//jV/nMN//LYJfUUDNPuVSAXTXHO6tttZ4EDgJk5i8i4m5gO/B2K4oszUavD6i3YAqwuLTskzCSGmrmDv0CsDci9kTEVlYWPc+s6/Mb4AsAEfFp4G5gMOdUPoTbLYz6JIykRhoGembeAJ4BZoDXWXma5VJEPB8Rh6rdngWeiojXgB8BT2RmtqvoUjVaGPVJGEm309QXizLzLCuLnbVtz9X8fhk40NrSBs/kxNiaLx2t55Mwkm7Hb4r2kNX58W/+5yXefW95zTk30pDUiC/n6jGH948y99xf8c9//Zm6T8JI0ka8Q+9RbqQhabO8Q5ekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEL4TdEBNz1XYWpmnuuLS+zYNszkxJjfUJX6lIE+wFZ3SFp9u2NlccmNNKQ+5pTLAKu3Q5IbaUj9y0AfYBttmOFGGlJ/MtAH2EYbZriRhtSfDPQBNjkxxvCWoTVtbqQh9S8XRQfY6sKnT7lIZTDQB5wbaUjlcMpFkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1Ihmgr0iDgYEfMRcSUijm7Q58sRcTkiLkXED1tbpiSpkYZ7ikbEEHAC+EvgGnAhIs5k5uWaPnuBY8CBzHw3Ij7RroIlSfU1c4f+AHAlM69m5vvAC8Cj6/o8BZzIzHcBMvPt1pYpSWqkmUAfBd6qOb5Wbat1L3BvRLwcEecj4mC9D4qIpyNiNiJmFxYW7qxiSVJdrVoUvQvYCzwIHAH+LSK2re+UmSczczwzx0dGRlp0aUkSNBfoFWBXzfHOaluta8CZzFzOzF8Dv2Il4CVJHdJMoF8A9kbEnojYCjwGnFnXZ5qVu3MiYjsrUzBXW1inJKmBhoGemTeAZ4AZ4HXgVGZeiojnI+JQtdsM8E5EXAbOAZOZ+U67ipYkfVBkZlcuPD4+nrOzs125tiT1q4h4JTPH653zm6KSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIe7qdgFSaabnKkzNzHN9cYkd24aZnBjj8P7RbpelAWCgSy00PVfh2OmLLC3fBKCyuMSx0xcBDHW1nVMuUgtNzczfCvNVS8s3mZqZ71JFGiQGutRC1xeXNtUutZKBLrXQjm3Dm2qXWslAl1pocmKM4S1Da9qGtwwxOTHWpYo0SFwUlVpodeHTp1zUDQa61GKH948a4OoKp1wkqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCNBXoEXEwIuYj4kpEHL1Nvy9FREbEeOtKlCQ1o2GgR8QQcAJ4GNgHHImIfXX63QP8HfA/rS5SktRYM3foDwBXMvNqZr4PvAA8Wqfft4DvAL9vYX2SpCY1E+ijwFs1x9eqbbdExGeBXZn54u0+KCKejojZiJhdWFjYdLGSpI196EXRiPgI8F3g2UZ9M/NkZo5n5vjIyMiHvbQkqUYzgV4BdtUc76y2rboHuA/4WUS8AXwOOOPCqCR1VjOBfgHYGxF7ImIr8BhwZvVkZv4uM7dn5u7M3A2cBw5l5mxbKpYk1dUw0DPzBvAMMAO8DpzKzEsR8XxEHGp3gZKk5jS1p2hmngXOrmt7boO+D374siS12/Rcxc2sC+Mm0dIAmp6rcOz0RZaWbwJQWVzi2OmLAIZ6H/Or/9IAmpqZvxXmq5aWbzI1M9+litQKBro0gK4vLm2qXf3BQJcG0I5tw5tqV38w0KUBNDkxxvCWoTVtw1uGmJwYA1bm2A8cf4k9R1/kwPGXmJ6r1PsY9RgXRaUBtLrwWe8pFxdM+5eBLg2ow/tH6wb07RZMDfTe5pSLpDVcMO1fBrqkNVww7V8GuqQ1Gi2Yqnc5hy5pjdstmK7n6wN6i4Eu6QM2WjCt5dMwvccpF0l3xNcH9B4DXdId8WmY3mOgS7ojPg3Tewx0SXfEp2F6j4uiku7IZp6GUWcY6JLuWDNPw6hznHKRpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiF8OZekvuR+ph9koEvqO+5nWp9TLpL6jvuZ1megS+o77mdan4Euqe+4n2l9BrqkvuN+pvW5KCqp77ifaX1NBXpEHAT+BRgC/j0zj687/zXgb4EbwALwN5n5ZotrlaRb3M/0gxpOuUTEEHACeBjYBxyJiH3rus0B45n5p8B/AP/Y6kIlSbfXzBz6A8CVzLyame8DLwCP1nbIzHOZ+V718Dyws7VlSpIaaSbQR4G3ao6vVds28iTw03onIuLpiJiNiNmFhYXmq5QkNdTSp1wi4nFgHJiqdz4zT2bmeGaOj4yMtPLSkjTwmlkUrQC7ao53VtvWiIiHgK8Dn8/MP7SmPElSs5q5Q78A7I2IPRGxFXgMOFPbISL2A98DDmXm260vU5LUSMNAz8wbwDPADPA6cCozL0XE8xFxqNptCvgo8JOIeDUizmzwcZKkNmnqOfTMPAucXdf2XM3vD7W4LknSJvnVf0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRBuEi1JbTI9V+noRtYGuiS1wfRchWOnL7K0fBOAyuISx05fBGhbqDvlIkltMDUzfyvMVy0t32RqZr5t1zTQJakNri8ubaq9FQx0SWqDHduGN9XeCga6JLXB5MQYw1uG1rQNbxlicmKsbdd0UVSS2mB14dOnXCSpAIf3j7Y1wNdzykWSCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUiMjM7lw4YgF4sysX703bgd92u4ge5xg15hjdXgnj86nMHKl3omuBrrUiYjYzx7tdRy9zjBpzjG6v9PFxykWSCmGgS1IhDPTecbLbBfQBx6gxx+j2ih4f59AlqRDeoUtSIQx0SSqEgd5hEXEwIuYj4kpEHK1z/msRcTkifhkR/x0Rn+pGnd3UaIxq+n0pIjIiin0MrZ5mxicivlz9O7oUET/sdI3d1sS/s09GxLmImKv+W3ukG3W2XGb606EfYAj4X+BPgK3Aa8C+dX3+HPij6u9fAX7c7bp7bYyq/e4Bfg6cB8a7XXcvjQ+wF5gD/rh6/Ilu192DY3QS+Er1933AG92uuxU/3qF31gPAlcy8mpnvAy8Aj9Z2yMxzmfle9fA8sLPDNXZbwzGq+hbwHeD3nSyuBzQzPk8BJzLzXYDMfLvDNXZbM2OUwMeqv38cuN7B+trGQO+sUeCtmuNr1baNPAn8tK0V9Z6GYxQRnwV2ZeaLnSysRzTzN3QvcG9EvBwR5yPiYMeq6w3NjNE3gMcj4hpwFvhqZ0prLzeJ7lER8TgwDny+27X0koj4CPBd4Ikul9LL7mJl2uVBVv6H9/OIuD8zF7taVW85Anw/M/8pIv4M+EFE3JeZ/9ftwj4M79A7qwLsqjneWW1bIyIeAr4OHMrMP3Sotl7RaIzuAe4DfhYRbwCfA84M0MJoM39D14Azmbmcmb8GfsVKwA+KZsboSeAUQGb+AriblRd39TUDvbMuAHsjYk9EbAUeA87UdoiI/cD3WAnzQZv7hAZjlJm/y8ztmbk7M3ezss5wKDNnu1NuxzX8GwKmWbk7JyK2szIFc7WTRXZZM2P0G+ALABHxaVYCfaGjVbaBgd5BmXkDeAaYAV4HTmXmpYh4PiIOVbtNAR8FfhIRr0bE+j/EojU5RgOryfGZAd6JiMvAOWAyM9/pTsWd1+QYPQs8FRGvAT8CnsjqIy/9zK/+S1IhvEOXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQ/w81ztLjYDMkvgAAAABJRU5ErkJggg==\n", 141 | "text/plain": [ 142 | "
" 143 | ] 144 | }, 145 | "metadata": { 146 | "needs_background": "light" 147 | }, 148 | "output_type": "display_data" 149 | } 150 | ], 151 | "source": [ 152 | "plt.scatter(py_pred[:,0], py_pred[:,1])" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "metadata": { 159 | "scrolled": true 160 | }, 161 | "outputs": [ 162 | { 163 | "data": { 164 | "text/plain": [ 165 | "" 166 | ] 167 | }, 168 | "execution_count": 6, 169 | "metadata": {}, 170 | "output_type": "execute_result" 171 | }, 172 | { 173 | "data": { 174 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAPCklEQVR4nO3dX2jd533H8fe3ik00mtYDqxeW3doDR9Qkoy4idPii2ZpNTi4c00IXQy7CsgQKKRsNApuO0KYXdSdWtgsP6u2iUGhTdxjhERcNFpdCqIcVlNTYQcVzk8bHF1FD1Juojex9d6Ejc6Qc+Rw55+9z3i8Q6Pf8Hs7vy4P84efn+Z3fE5mJJKn/faTbBUiSWsNAl6RCGOiSVAgDXZIKYaBLUiHu6taFt2/fnrt37+7W5SWpL73yyiu/zcyReue6Fui7d+9mdna2W5eXpL4UEW9udM4pF0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhuvbFolaYnqswNTPP9cUldmwbZnJijMP7R7tdliR1Rd8G+vRchWOnL7K0fBOAyuISx05fBDDUJQ2kvp1ymZqZvxXmq5aWbzI1M9+liiSpu/o20K8vLtVtrywuceD4S0zPVTpckSR1V98G+o5twxueW51+MdQlDZK+DfTJiTGGtwxteN7pF0mDpm8XRVcXPqdm5qlsMP2y0bSMJJWob+/QYSXUXz76F4xuMP1yu2kZSSpNXwf6qnrTL8NbhpicGOtSRZLUeX075VKrdvrFLxlJGlRFBDqshLoBLmmQFTHlIkky0CWpGAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKkQxXyzqFLe9k9SrDPRNcNs7Sb3MKZdNcNs7Sb3MQN+Ejd6v7nvXJfUCA30TNnq/uu9dl9QLDPRN8L3rknqZi6Kb4HvXJfUyA32TfO+6pF7llIskFcJAl6RCGOiSVAjn0LvA1wdIagcDvcN8fYCkdmlqyiUiDkbEfERciYijdc5/MiLORcRcRPwyIh5pfall8PUBktqlYaBHxBBwAngY2AcciYh967r9A3AqM/cDjwH/2upCS+HrAyS1SzN36A8AVzLzama+D7wAPLquTwIfq/7+ceB660osi68PkNQuzQT6KPBWzfG1alutbwCPR8Q14Czw1XofFBFPR8RsRMwuLCzcQbn9r9HrA6bnKhw4/hJ7jr7IgeMvMT1X6UaZkvpQqx5bPAJ8PzN3Ao8AP4iID3x2Zp7MzPHMHB8ZGWnRpfvL4f2jfPuL9zO6bZgARrcN8+0v3s/h/aO3Fkwri0skKwumf//jV/nMN//LYJfUUDNPuVSAXTXHO6tttZ4EDgJk5i8i4m5gO/B2K4oszUavD6i3YAqwuLTskzCSGmrmDv0CsDci9kTEVlYWPc+s6/Mb4AsAEfFp4G5gMOdUPoTbLYz6JIykRhoGembeAJ4BZoDXWXma5VJEPB8Rh6rdngWeiojXgB8BT2RmtqvoUjVaGPVJGEm309QXizLzLCuLnbVtz9X8fhk40NrSBs/kxNiaLx2t55Mwkm7Hb4r2kNX58W/+5yXefW95zTk30pDUiC/n6jGH948y99xf8c9//Zm6T8JI0ka8Q+9RbqQhabO8Q5ekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEL4TdEBNz1XYWpmnuuLS+zYNszkxJjfUJX6lIE+wFZ3SFp9u2NlccmNNKQ+5pTLAKu3Q5IbaUj9y0AfYBttmOFGGlJ/MtAH2EYbZriRhtSfDPQBNjkxxvCWoTVtbqQh9S8XRQfY6sKnT7lIZTDQB5wbaUjlcMpFkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1Ihmgr0iDgYEfMRcSUijm7Q58sRcTkiLkXED1tbpiSpkYZ7ikbEEHAC+EvgGnAhIs5k5uWaPnuBY8CBzHw3Ij7RroIlSfU1c4f+AHAlM69m5vvAC8Cj6/o8BZzIzHcBMvPt1pYpSWqkmUAfBd6qOb5Wbat1L3BvRLwcEecj4mC9D4qIpyNiNiJmFxYW7qxiSVJdrVoUvQvYCzwIHAH+LSK2re+UmSczczwzx0dGRlp0aUkSNBfoFWBXzfHOaluta8CZzFzOzF8Dv2Il4CVJHdJMoF8A9kbEnojYCjwGnFnXZ5qVu3MiYjsrUzBXW1inJKmBhoGemTeAZ4AZ4HXgVGZeiojnI+JQtdsM8E5EXAbOAZOZ+U67ipYkfVBkZlcuPD4+nrOzs125tiT1q4h4JTPH653zm6KSVAgDXZIKYaBLUiEMdEkqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIe7qdgFSaabnKkzNzHN9cYkd24aZnBjj8P7RbpelAWCgSy00PVfh2OmLLC3fBKCyuMSx0xcBDHW1nVMuUgtNzczfCvNVS8s3mZqZ71JFGiQGutRC1xeXNtUutZKBLrXQjm3Dm2qXWslAl1pocmKM4S1Da9qGtwwxOTHWpYo0SFwUlVpodeHTp1zUDQa61GKH948a4OoKp1wkqRAGuiQVwkCXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCNBXoEXEwIuYj4kpEHL1Nvy9FREbEeOtKlCQ1o2GgR8QQcAJ4GNgHHImIfXX63QP8HfA/rS5SktRYM3foDwBXMvNqZr4PvAA8Wqfft4DvAL9vYX2SpCY1E+ijwFs1x9eqbbdExGeBXZn54u0+KCKejojZiJhdWFjYdLGSpI196EXRiPgI8F3g2UZ9M/NkZo5n5vjIyMiHvbQkqUYzgV4BdtUc76y2rboHuA/4WUS8AXwOOOPCqCR1VjOBfgHYGxF7ImIr8BhwZvVkZv4uM7dn5u7M3A2cBw5l5mxbKpYk1dUw0DPzBvAMMAO8DpzKzEsR8XxEHGp3gZKk5jS1p2hmngXOrmt7boO+D374siS12/Rcxc2sC+Mm0dIAmp6rcOz0RZaWbwJQWVzi2OmLAIZ6H/Or/9IAmpqZvxXmq5aWbzI1M9+litQKBro0gK4vLm2qXf3BQJcG0I5tw5tqV38w0KUBNDkxxvCWoTVtw1uGmJwYA1bm2A8cf4k9R1/kwPGXmJ6r1PsY9RgXRaUBtLrwWe8pFxdM+5eBLg2ow/tH6wb07RZMDfTe5pSLpDVcMO1fBrqkNVww7V8GuqQ1Gi2Yqnc5hy5pjdstmK7n6wN6i4Eu6QM2WjCt5dMwvccpF0l3xNcH9B4DXdId8WmY3mOgS7ojPg3Tewx0SXfEp2F6j4uiku7IZp6GUWcY6JLuWDNPw6hznHKRpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQBrokFcJAl6RCGOiSVAgDXZIKYaBLUiF8OZekvuR+ph9koEvqO+5nWp9TLpL6jvuZ1megS+o77mdan4Euqe+4n2l9BrqkvuN+pvW5KCqp77ifaX1NBXpEHAT+BRgC/j0zj687/zXgb4EbwALwN5n5ZotrlaRb3M/0gxpOuUTEEHACeBjYBxyJiH3rus0B45n5p8B/AP/Y6kIlSbfXzBz6A8CVzLyame8DLwCP1nbIzHOZ+V718Dyws7VlSpIaaSbQR4G3ao6vVds28iTw03onIuLpiJiNiNmFhYXmq5QkNdTSp1wi4nFgHJiqdz4zT2bmeGaOj4yMtPLSkjTwmlkUrQC7ao53VtvWiIiHgK8Dn8/MP7SmPElSs5q5Q78A7I2IPRGxFXgMOFPbISL2A98DDmXm260vU5LUSMNAz8wbwDPADPA6cCozL0XE8xFxqNptCvgo8JOIeDUizmzwcZKkNmnqOfTMPAucXdf2XM3vD7W4LknSJvnVf0kqhIEuSYUw0CWpEAa6JBXCQJekQhjoklQIA12SCmGgS1IhDHRJKoSBLkmFMNAlqRBuEi1JbTI9V+noRtYGuiS1wfRchWOnL7K0fBOAyuISx05fBGhbqDvlIkltMDUzfyvMVy0t32RqZr5t1zTQJakNri8ubaq9FQx0SWqDHduGN9XeCga6JLXB5MQYw1uG1rQNbxlicmKsbdd0UVSS2mB14dOnXCSpAIf3j7Y1wNdzykWSCmGgS1IhDHRJKoSBLkmFMNAlqRAGuiQVwkCXpEIY6JJUiMjM7lw4YgF4sysX703bgd92u4ge5xg15hjdXgnj86nMHKl3omuBrrUiYjYzx7tdRy9zjBpzjG6v9PFxykWSCmGgS1IhDPTecbLbBfQBx6gxx+j2ih4f59AlqRDeoUtSIQx0SSqEgd5hEXEwIuYj4kpEHK1z/msRcTkifhkR/x0Rn+pGnd3UaIxq+n0pIjIiin0MrZ5mxicivlz9O7oUET/sdI3d1sS/s09GxLmImKv+W3ukG3W2XGb606EfYAj4X+BPgK3Aa8C+dX3+HPij6u9fAX7c7bp7bYyq/e4Bfg6cB8a7XXcvjQ+wF5gD/rh6/Ilu192DY3QS+Er1933AG92uuxU/3qF31gPAlcy8mpnvAy8Aj9Z2yMxzmfle9fA8sLPDNXZbwzGq+hbwHeD3nSyuBzQzPk8BJzLzXYDMfLvDNXZbM2OUwMeqv38cuN7B+trGQO+sUeCtmuNr1baNPAn8tK0V9Z6GYxQRnwV2ZeaLnSysRzTzN3QvcG9EvBwR5yPiYMeq6w3NjNE3gMcj4hpwFvhqZ0prLzeJ7lER8TgwDny+27X0koj4CPBd4Ikul9LL7mJl2uVBVv6H9/OIuD8zF7taVW85Anw/M/8pIv4M+EFE3JeZ/9ftwj4M79A7qwLsqjneWW1bIyIeAr4OHMrMP3Sotl7RaIzuAe4DfhYRbwCfA84M0MJoM39D14Azmbmcmb8GfsVKwA+KZsboSeAUQGb+AriblRd39TUDvbMuAHsjYk9EbAUeA87UdoiI/cD3WAnzQZv7hAZjlJm/y8ztmbk7M3ezss5wKDNnu1NuxzX8GwKmWbk7JyK2szIFc7WTRXZZM2P0G+ALABHxaVYCfaGjVbaBgd5BmXkDeAaYAV4HTmXmpYh4PiIOVbtNAR8FfhIRr0bE+j/EojU5RgOryfGZAd6JiMvAOWAyM9/pTsWd1+QYPQs8FRGvAT8CnsjqIy/9zK/+S1IhvEOXpEIY6JJUCANdkgphoEtSIQx0SSqEgS5JhTDQJakQ/w81ztLjYDMkvgAAAABJRU5ErkJggg==\n", 175 | "text/plain": [ 176 | "
" 177 | ] 178 | }, 179 | "metadata": { 180 | "needs_background": "light" 181 | }, 182 | "output_type": "display_data" 183 | } 184 | ], 185 | "source": [ 186 | "plt.scatter(cpp_pred[:,0], cpp_pred[:,1])" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "### Error" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 7, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "err_arr = py_pred-cpp_pred" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 8, 208 | "metadata": { 209 | "scrolled": false 210 | }, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "text/plain": [ 215 | "array([[-1.70242310e-07, 1.70242310e-07],\n", 216 | " [-3.09907913e-07, 3.39710236e-07],\n", 217 | " [ 3.57818604e-07, -3.28016281e-07],\n", 218 | " [ 2.61816978e-07, -2.76718140e-07],\n", 219 | " [-4.14941788e-07, 3.85139465e-07],\n", 220 | " [ 3.84365082e-07, -3.54562759e-07],\n", 221 | " [-2.89789677e-07, 3.12141418e-07],\n", 222 | " [ 2.87170410e-08, 1.08528136e-09],\n", 223 | " [ 1.00709438e-07, -1.37962341e-07],\n", 224 | " [ 4.16618347e-07, -4.01717186e-07],\n", 225 | " [ 1.57196045e-07, -1.57196045e-07],\n", 226 | " [-5.36945343e-07, 4.99692440e-07],\n", 227 | " [ 2.53860474e-07, -2.53860474e-07]])" 228 | ] 229 | }, 230 | "execution_count": 8, 231 | "metadata": {}, 232 | "output_type": "execute_result" 233 | } 234 | ], 235 | "source": [ 236 | "err_arr" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "### Norm" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 9, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "Norms for python predictions and c++ predictions: 2.948778180317615 2.9487779999525228\n" 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "py_pred_norm = np.linalg.norm(py_pred)\n", 261 | "cpp_pred_norm = np.linalg.norm(cpp_pred)\n", 262 | "\n", 263 | "print(\"Norms for python predictions and c++ predictions:\", py_pred_norm, cpp_pred_norm)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "### np.allclose()" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 10, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "True" 282 | ] 283 | }, 284 | "execution_count": 10, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | } 288 | ], 289 | "source": [ 290 | "# True if two arrays are element-wise equal within a tolerance.\n", 291 | "np.allclose(py_pred, cpp_pred) # default rtol=1e-05 " 292 | ] 293 | } 294 | ], 295 | "metadata": { 296 | "kernelspec": { 297 | "display_name": "Python 3", 298 | "language": "python", 299 | "name": "python3" 300 | }, 301 | "language_info": { 302 | "codemirror_mode": { 303 | "name": "ipython", 304 | "version": 3 305 | }, 306 | "file_extension": ".py", 307 | "mimetype": "text/x-python", 308 | "name": "python", 309 | "nbconvert_exporter": "python", 310 | "pygments_lexer": "ipython3", 311 | "version": "3.6.9" 312 | } 313 | }, 314 | "nbformat": 4, 315 | "nbformat_minor": 2 316 | } 317 | -------------------------------------------------------------------------------- /Compare_Predictions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import numpy as np\n", 11 | "import os\n", 12 | "import sys\n", 13 | "import matplotlib.pyplot as plt" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "### Load Predictions" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "datapath_py =\"resources/saved/py_predicted.txt\"\n", 30 | "datapath_cpp =\"resources/saved/cpp_predicted.txt\"\n", 31 | "\n", 32 | "def read_predictions(datapath):\n", 33 | " with open(datapath) as file_content:\n", 34 | " file_content = [i for i in file_content.readlines()]\n", 35 | " \n", 36 | " out=[]\n", 37 | " for i in file_content:\n", 38 | " if i == '\\n':\n", 39 | " print(\"Skip Empty line\")\n", 40 | " continue\n", 41 | " out_pair = list(map(float, i.split()))\n", 42 | " out.append(out_pair)\n", 43 | " \n", 44 | " return np.array(out)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "Skip Empty line\n", 57 | "Skip Empty line\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "py_pred = read_predictions(datapath_py)\n", 63 | "cpp_pred = read_predictions(datapath_cpp)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### Print Predictions" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stdout", 80 | "output_type": "stream", 81 | "text": [ 82 | "Python fwd pass: [[5.79098511 1.68282413]\n", 83 | " [3.79514885 1.50127339]\n", 84 | " [3.05594134 2.65454769]\n", 85 | " [3.59553909 1.0229435 ]\n", 86 | " [3.67793489 0.66021532]\n", 87 | " [4.15732527 0.90474343]\n", 88 | " [4.09724522 0.99075013]\n", 89 | " [2.9376111 1.04406965]\n", 90 | " [1.8372159 2.59138894]\n", 91 | " [1.23373771 2.04912949]\n", 92 | " [3.48744416 0.73090637]\n", 93 | " [2.71336007 0.22554693]\n", 94 | " [2.82046795 0.40351319]\n", 95 | " [2.87005043 1.63838708]\n", 96 | " [0.70065492 2.67366743]\n", 97 | " [0.88620341 2.66380835]\n", 98 | " [3.32769847 0.81819201]\n", 99 | " [2.49346495 0.60585958]\n", 100 | " [0.74832189 2.84791136]\n", 101 | " [2.2085979 1.51993418]\n", 102 | " [0.7664963 2.69486427]\n", 103 | " [2.67619324 0.70304364]\n", 104 | " [0.80111802 2.66594386]\n", 105 | " [1.17281389 3.57055759]\n", 106 | " [1.21135879 2.25158095]\n", 107 | " [1.03086376 2.40085602]\n", 108 | " [0.37634149 3.37193155]\n", 109 | " [1.39961958 2.71279001]\n", 110 | " [1.21885276 2.28670406]\n", 111 | " [0.65545332 3.94429827]\n", 112 | " [1.43623936 2.53711128]\n", 113 | " [1.88981557 3.04195762]\n", 114 | " [1.25789416 5.09653568]\n", 115 | " [1.65704513 5.76212788]]\n", 116 | "C++ fwd pass: [[5.790986 1.682824]\n", 117 | " [3.795149 1.501273]\n", 118 | " [3.055941 2.654548]\n", 119 | " [3.595539 1.022943]\n", 120 | " [3.677935 0.660215]\n", 121 | " [4.157325 0.904744]\n", 122 | " [4.097245 0.99075 ]\n", 123 | " [2.937612 1.04407 ]\n", 124 | " [1.837216 2.591389]\n", 125 | " [1.233738 2.04913 ]\n", 126 | " [3.487444 0.730906]\n", 127 | " [2.71336 0.225547]\n", 128 | " [2.820468 0.403513]\n", 129 | " [2.87005 1.638387]\n", 130 | " [0.700655 2.673667]\n", 131 | " [0.886203 2.663809]\n", 132 | " [3.327699 0.818192]\n", 133 | " [2.493465 0.60586 ]\n", 134 | " [0.748322 2.847911]\n", 135 | " [2.208598 1.519934]\n", 136 | " [0.766496 2.694864]\n", 137 | " [2.676193 0.703044]\n", 138 | " [0.801118 2.665944]\n", 139 | " [1.172814 3.570558]\n", 140 | " [1.211359 2.251581]\n", 141 | " [1.030864 2.400856]\n", 142 | " [0.376341 3.371932]\n", 143 | " [1.39962 2.71279 ]\n", 144 | " [1.218853 2.286704]\n", 145 | " [0.655453 3.944298]\n", 146 | " [1.436239 2.537111]\n", 147 | " [1.889816 3.041958]\n", 148 | " [1.257894 5.096536]\n", 149 | " [1.657045 5.762127]]\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "print(\"Python fwd pass: \", py_pred)\n", 155 | "print(\"C++ fwd pass: \", cpp_pred)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "### Scatter Plots" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 5, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/plain": [ 173 | "" 174 | ] 175 | }, 176 | "execution_count": 5, 177 | "metadata": {}, 178 | "output_type": "execute_result" 179 | }, 180 | { 181 | "data": { 182 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD6CAYAAACIyQ0UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAREUlEQVR4nO3db4hd9Z3H8c/HyRRH6zoLXoqZmI2PBrSyjnuRlhTZjbSxW9cN0gftYmHLQp64xdJlSrJPWmEXhUDpPljKBrXbpbayxBio7ZotJKVr2dpOHG2MMUtxFTO2ZEodqnVYx/G7D+ZOzCT3z7lzz7nnd859vyA4c3Mm+V7Rz/nd7/me33FECACQrsvKLgAA0B1BDQCJI6gBIHEENQAkjqAGgMQR1ACQuExBbXvS9iHbL9k+bfujRRcGAFizJeNx/yTpqYj4tO0PSLqi28HXXHNN7NixY9DaAGBknDhx4jcR0Wj3ez2D2vbVkm6T9NeSFBHvSHqn28/s2LFDc3Nz/VcKACPK9qudfi9L6+N6SYuSvml73vZDtq9s85fstT1ne25xcXGAcgEAF8oS1Fsk3SLpGxExI+n3kvZdfFBEHIyIZkQ0G422q3cAwCZkCeqzks5GxDOt7w9pLbgBAEPQM6gj4teSXrM93XrpdkkvFloVAOC8rFMfX5D0aGvi42VJny+uJADAhTIFdUQ8J6lZcC0AgDayrqhRsiPzCzpw9IxeX1rW1skJze6e1p6ZqbLLAjAEBHUFHJlf0P7DJ7W8sipJWlha1v7DJyWJsAZGAHt9VMCBo2fOh/S65ZVVHTh6pqSKAAwTQV0Bry8t9/U6gHohqCtg6+REX68DqBeCugJmd09rYnxsw2sT42Oa3T3d4ScA1AkXEytg/YIhUx/AaCKoK2LPzBTBDIwoWh8AkDiCGgASR1ADQOIIagBIHEENAIkjqAEgcQQ1ACSOoAaAxBHUAJA4ghoAEkdQA0DiCGoASBxBDQCJI6gBIHEENQAkjv2oc3ZkfoEN/gHkiqDO0ZH5Be0/fPL8E8MXlpa1//BJSSKsAWwarY8cHTh65nxIr1teWdWBo2dKqghAHRDUOXp9abmv1wEgi0ytD9uvSHpT0qqkdyOiWWRRVbV1ckILbUJ56+RECdUAqIt+VtR/FhE3E9Kdze6e1sT42IbXJsbHNLt7uqSKANQBFxNztH7BkKkPAHlyRPQ+yP5fSW9ICkn/EhEH2xyzV9JeSdq+ffufvPrqqzmXCgD1ZftEp45F1tbHxyLiFkmflHSv7dsuPiAiDkZEMyKajUZjgHIBABfKFNQRsdD65zlJT0i6tciiAADv6xnUtq+0fdX615I+IemFogsDAKzJcjHxQ5KesL1+/Hci4qlCqwIAnNczqCPiZUl/PIRaAABtcGciACSOoAaAxBHUAJA4ghoAEkdQA0DiCGoASBxBDQCJI6gBIHEENQAkjqAGgMQR1ACQOIIaABJHUANA4ghqAEgcQQ0AiSOoASBxBDUAJI6gBoDEEdQAkDiCGgASR1ADQOIIagBIHEENAIkjqAEgcQQ1ACSOoAaAxGUOattjtudtP1lkQQCAjfpZUd8n6XRRhQAA2tuS5SDb2yR9StI/SvpSoRVt0pH5BR04ekavLy1r6+SEZndPa8/MVNllAcDAMgW1pK9L+rKkqzodYHuvpL2StH379sEr68OR+QXtP3xSyyurkqSFpWXtP3xSkghrAJXXs/Vh+05J5yLiRLfjIuJgRDQjotloNHIrMIsDR8+cD+l1yyurOnD0zFDrAIAiZOlR75R0l+1XJD0maZftbxdaVZ9eX1ru63UAqJKeQR0R+yNiW0TskPQZScci4p7CK+vD1smJvl4HgCqpxRz17O5pTYyPbXhtYnxMs7unS6oIAPKT9WKiJCkifiTpR4VUMoD1C4ZMfQCoo76COmV7ZqYIZgC1VIvWBwDUGUENAIkjqAEgcQQ1ACSOoAaAxBHUAJC42oznlYEd+wAMA0G9SezYB2BYaH1sEjv2ARgWgnqT2LEPwLAQ1JvEjn0AhoWg3iR27AMwLFxM3CR27AMwLMkEdRVH3dixD8AwJBHUjLoBQGdJ9KgZdQOAzpIIakbdAKCzJIKaUTcA6CyJoGbUDQA6S+JiIqNuANBZEkEtMeoGAJ0k0foAAHRGUANA4ghqAEhcMj1q5K+Kt+UDuBRBXVPclg/UR8/Wh+3Lbf/M9vO2T9m+fxiFYTDclg/UR5YV9f9J2hURb9kel/S07f+IiJ8WXBsGwG35QH30XFHHmrda3463fkWhVWFg3JYP1EemqQ/bY7afk3RO0g8j4pk2x+y1PWd7bnFxMe860SduywfqI1NQR8RqRNwsaZukW21/uM0xByOiGRHNRqORd53o056ZKT1w902ampyQJU1NTuiBu2/iQiJQQX1NfUTEku3jku6Q9EIxJeVn1MfTuC0fqIcsUx8N25OtryckfVzSS0UXNqj18bSFpWWF3h9POzK/UHZpANCXLK2PayUdt/0LST/XWo/6yWLLGhzjaQDqomfrIyJ+IWlmCLXkivE0AHVR270+GE8DUBe1DepUx9OOzC9o54PHdP2+72vng8fomQPoqXZ7fVw46XH1xLis0Nsr70mSLh8v97zE/hsANsMR+d9k2Gw2Y25uLvc/t5eLg7CTP7xiXF/5ixslDffxXzsfPKaFNj3yqckJ/WTfrsL+3lE26iOaqA7bJyKi2e73arWibjfp0c4bb69o9tDzUkgr762dqIaxuuUC53DxCQZ1UasedT+Bt7Ia50N6XdHje8O4wEkP/H2MaKIuahXUeQRekavboi9wcpPPRnyCQV3UKqjbBWG/ihzfu3D/DUkas8+v8PIIU1aQGzGiibqoVVC324jono9s1+TE+CXHjo9Z45d5w2vDGN/bMzN1/oSyGhv744OGNSvIjVId0QT6VauLiVL7jYj+Yc9Nba/+S8Od+ljXbeU7yN+/dXKi7VTJqK4g1/9dMvWBqqtdUHfSaSe5Mv6nLWrlO7t7+pLxxFFfQbKDIOqgVq2Pqiiqd8oe1EA9jcyKejOKulmiyJUvK0igfgjqDoq8WYLeKYB+ENQdFHXBbx0rXwBZ0aPugFE3AKlgRd1BP6NubPwDoEisqDvIerPEkfkFzR56fsNt27OHnh/Z27YB5I+g7iDrqNv93zulldWNmzutrIbu/96pIVYLoM5ofXSR5YLfG2+v9PU6APSLFTUAJI6gHlC7DZ+6vQ4A/SKoB/TVu268ZBe+8cusr951Y0kVAagbetQD4i5DAEUjqHPAXYYAikRQ54SbXgAUpWeP2vZ1to/bftH2Kdv3DaOwKuFZhQCKlOVi4ruS/i4ibpD0EUn32r6h2LKqhWcVAihSz6COiF9FxLOtr9+UdFoSn+kvwAZOAIrU13ie7R2SZiQ90+b39tqesz23uLiYT3UVwdOuARQpc1Db/qCkxyV9MSJ+d/HvR8TBiGhGRLPRaORZY/J42jWAImWa+rA9rrWQfjQiDhdbUvUwSw2gSD2D2rYlPSzpdER8rfiSqolZagBFydL62Cnpc5J22X6u9evPC64LANDSc0UdEU9Lcq/jAADFYFMmAEgcQQ0AiSOoASBxbMoEAAMqelM2ghqVxG6FSMX6pmzr+/2sb8omKbf/Jml9oHLYrRApGcambAQ1KofdCpGSYWzKRlCjctitECkZxqZsBDV6OjK/oJ0PHtP1+76vnQ8eK73FwG6FSMkwNmUjqNFViv3gOu5WmNrJENntmZnSA3ffpKnJCVnS1OSEHrj7JqY+MDzd+sFlTVnUbbfCYUwNoFhFb8pGUKOrVPvBddqtMMWTIdJC6wNd0Q8uXqonQ6SDoEZXdewHp4aTIXohqNHVMC6UjDpOhuiFHjV6qlM/OEV1uziK/BHUQAI4GaIbWh8AkDiCGgASR1ADQOIIagBIHEENAIkjqAEgcQQ1ACSOoAaAxHHDC2qBh92izghqVF6q+zlz8kBeerY+bD9i+5ztF4ZRENCvFB92W8aTcXhKTH1l6VH/q6Q7Cq4D2LQU93Me9skjxUemIT89gzoifizpt0OoBdiUFPdzHvbJI8VPFchPblMftvfanrM9t7i4mNcfC/SU4n7Owz55dDoBLCwt0wapgdyCOiIORkQzIpqNRiOvPxboKcWHGwz75NHtBEAbpPqY+kAtpLaf87AfBjC7e3rD5MvFeFhutRHUQBeDjNgN8+Rx4YlhIcGLqxhMlvG870r6b0nTts/a/pviywLKV7VJij0zU/rJvl2aSvDiKgaTZerjsxFxbUSMR8S2iHh4GIUBZavqJEWKF1cxGFofQAdZR+xSuwORh+XWD0GNZKQWeFsnJ9r2ey9sIaR6+3pqF1cxGHbPQxJS7AdnaSFUtT2CamFFjSR0C7yyVoZZWgj93IGY2icGVAdBjSSkuF+H1LuFkKU9IqXbIkE10PpA4bLs6pbifh1ZZJ2woEWCQRDUKFTW3nNVR8qy3r6e6icGVAOtDxQqa++5yiNlWSYssrZIgHYIahSqn5VknUfK2u3FUYVPDEgDrQ8Uqqq957yluMMfqoMVNQrFSvJ9df7EgGIR1ChUlXvPQCoIahQur5UkN4xgVBHUqARuGMEo42IiKoEbRjDKCGpUAjeMYJQR1KgExvwwyghqVEJVbzEH8sDFRCSn23QHUx8YRQQ1ktJruoNgxiii9YGkMN0BXIqgRlKY7gAuRVAjKUx3AJciqJEUpjuAS3ExEUlhugO4FEGN5DDdAWyUqfVh+w7bZ2z/0va+oosCALyvZ1DbHpP0z5I+KekGSZ+1fUPRhQEA1mRZUd8q6ZcR8XJEvCPpMUl/WWxZAIB1WYJ6StJrF3x/tvXaBrb32p6zPbe4uJhXfQAw8nIbz4uIgxHRjIhmo9HI648FgJGXJagXJF13wffbWq8BAIbAEdH9AHuLpP+RdLvWAvrnkv4qIk51+ZlFSa/mWGdZrpH0m7KLKBDvr9p4f9V28fv7o4ho247oOUcdEe/a/ltJRyWNSXqkW0i3fqYWvQ/bcxHRLLuOovD+qo33V239vL9MN7xExA8k/WCgqgAAm8JeHwCQOIK6u4NlF1Aw3l+18f6qLfP763kxEQBQLlbUAJA4ghoAEkdQt2H7EdvnbL9Qdi1FsH2d7eO2X7R9yvZ9ZdeUJ9uX2/6Z7edb7+/+smsqgu0x2/O2nyy7lrzZfsX2SdvP2Z4ru5682Z60fcj2S7ZP2/5o1+PpUV/K9m2S3pL0bxHx4bLryZvtayVdGxHP2r5K0glJeyLixZJLy4VtS7oyIt6yPS7paUn3RcRPSy4tV7a/JKkp6Q8i4s6y68mT7VckNSOilje82P6WpP+KiIdsf0DSFRGx1Ol4VtRtRMSPJf227DqKEhG/iohnW1+/Kem02my0VVWx5q3Wt+OtX7VakdjeJulTkh4quxb0x/bVkm6T9LAkRcQ73UJaIqhHnu0dkmYkPVNuJflqtQWek3RO0g8jolbvT9LXJX1Z0ntlF1KQkPSftk/Y3lt2MTm7XtKipG+2WlcP2b6y2w8Q1CPM9gclPS7pixHxu7LryVNErEbEzVrbROxW27VpYdm+U9K5iDhRdi0F+lhE3KK1B5bc22pH1sUWSbdI+kZEzEj6vaSuT84iqEdUq3f7uKRHI+Jw2fUUpfWR8rikO8quJUc7Jd3V6uM+JmmX7W+XW1K+ImKh9c9zkp7Q2gNM6uKspLMXfMo7pLXg7oigHkGti20PSzodEV8ru5682W7Ynmx9PSHp45JeKreq/ETE/ojYFhE7JH1G0rGIuKfksnJj+8rWRW61WgKfkFSbCayI+LWk12xPt166XVLXC/k8hbwN29+V9KeSrrF9VtJXIuLhcqvK1U5Jn5N0stXHlaS/b22+VQfXSvpW63mfl0n694io3QhbjX1I0hNr6wltkfSdiHiq3JJy9wVJj7YmPl6W9PluBzOeBwCJo/UBAIkjqAEgcQQ1ACSOoAaAxBHUAJA4ghoAEkdQA0Di/h85K3tJLjWDlQAAAABJRU5ErkJggg==\n", 183 | "text/plain": [ 184 | "
" 185 | ] 186 | }, 187 | "metadata": { 188 | "needs_background": "light" 189 | }, 190 | "output_type": "display_data" 191 | } 192 | ], 193 | "source": [ 194 | "plt.scatter(py_pred[:,0], py_pred[:,1])" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 6, 200 | "metadata": { 201 | "scrolled": true 202 | }, 203 | "outputs": [ 204 | { 205 | "data": { 206 | "text/plain": [ 207 | "" 208 | ] 209 | }, 210 | "execution_count": 6, 211 | "metadata": {}, 212 | "output_type": "execute_result" 213 | }, 214 | { 215 | "data": { 216 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD6CAYAAACIyQ0UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAREUlEQVR4nO3db4hd9Z3H8c/HyRRH6zoLXoqZmI2PBrSyjnuRlhTZjbSxW9cN0gftYmHLQp64xdJlSrJPWmEXhUDpPljKBrXbpbayxBio7ZotJKVr2dpOHG2MMUtxFTO2ZEodqnVYx/G7D+ZOzCT3z7lzz7nnd859vyA4c3Mm+V7Rz/nd7/me33FECACQrsvKLgAA0B1BDQCJI6gBIHEENQAkjqAGgMQR1ACQuExBbXvS9iHbL9k+bfujRRcGAFizJeNx/yTpqYj4tO0PSLqi28HXXHNN7NixY9DaAGBknDhx4jcR0Wj3ez2D2vbVkm6T9NeSFBHvSHqn28/s2LFDc3Nz/VcKACPK9qudfi9L6+N6SYuSvml73vZDtq9s85fstT1ne25xcXGAcgEAF8oS1Fsk3SLpGxExI+n3kvZdfFBEHIyIZkQ0G422q3cAwCZkCeqzks5GxDOt7w9pLbgBAEPQM6gj4teSXrM93XrpdkkvFloVAOC8rFMfX5D0aGvi42VJny+uJADAhTIFdUQ8J6lZcC0AgDayrqhRsiPzCzpw9IxeX1rW1skJze6e1p6ZqbLLAjAEBHUFHJlf0P7DJ7W8sipJWlha1v7DJyWJsAZGAHt9VMCBo2fOh/S65ZVVHTh6pqSKAAwTQV0Bry8t9/U6gHohqCtg6+REX68DqBeCugJmd09rYnxsw2sT42Oa3T3d4ScA1AkXEytg/YIhUx/AaCKoK2LPzBTBDIwoWh8AkDiCGgASR1ADQOIIagBIHEENAIkjqAEgcQQ1ACSOoAaAxBHUAJA4ghoAEkdQA0DiCGoASBxBDQCJI6gBIHEENQAkjv2oc3ZkfoEN/gHkiqDO0ZH5Be0/fPL8E8MXlpa1//BJSSKsAWwarY8cHTh65nxIr1teWdWBo2dKqghAHRDUOXp9abmv1wEgi0ytD9uvSHpT0qqkdyOiWWRRVbV1ckILbUJ56+RECdUAqIt+VtR/FhE3E9Kdze6e1sT42IbXJsbHNLt7uqSKANQBFxNztH7BkKkPAHlyRPQ+yP5fSW9ICkn/EhEH2xyzV9JeSdq+ffufvPrqqzmXCgD1ZftEp45F1tbHxyLiFkmflHSv7dsuPiAiDkZEMyKajUZjgHIBABfKFNQRsdD65zlJT0i6tciiAADv6xnUtq+0fdX615I+IemFogsDAKzJcjHxQ5KesL1+/Hci4qlCqwIAnNczqCPiZUl/PIRaAABtcGciACSOoAaAxBHUAJA4ghoAEkdQA0DiCGoASBxBDQCJI6gBIHEENQAkjqAGgMQR1ACQOIIaABJHUANA4ghqAEgcQQ0AiSOoASBxBDUAJI6gBoDEEdQAkDiCGgASR1ADQOIIagBIHEENAIkjqAEgcQQ1ACSOoAaAxGUOattjtudtP1lkQQCAjfpZUd8n6XRRhQAA2tuS5SDb2yR9StI/SvpSoRVt0pH5BR04ekavLy1r6+SEZndPa8/MVNllAcDAMgW1pK9L+rKkqzodYHuvpL2StH379sEr68OR+QXtP3xSyyurkqSFpWXtP3xSkghrAJXXs/Vh+05J5yLiRLfjIuJgRDQjotloNHIrMIsDR8+cD+l1yyurOnD0zFDrAIAiZOlR75R0l+1XJD0maZftbxdaVZ9eX1ru63UAqJKeQR0R+yNiW0TskPQZScci4p7CK+vD1smJvl4HgCqpxRz17O5pTYyPbXhtYnxMs7unS6oIAPKT9WKiJCkifiTpR4VUMoD1C4ZMfQCoo76COmV7ZqYIZgC1VIvWBwDUGUENAIkjqAEgcQQ1ACSOoAaAxBHUAJC42oznlYEd+wAMA0G9SezYB2BYaH1sEjv2ARgWgnqT2LEPwLAQ1JvEjn0AhoWg3iR27AMwLFxM3CR27AMwLMkEdRVH3dixD8AwJBHUjLoBQGdJ9KgZdQOAzpIIakbdAKCzJIKaUTcA6CyJoGbUDQA6S+JiIqNuANBZEkEtMeoGAJ0k0foAAHRGUANA4ghqAEhcMj1q5K+Kt+UDuBRBXVPclg/UR8/Wh+3Lbf/M9vO2T9m+fxiFYTDclg/UR5YV9f9J2hURb9kel/S07f+IiJ8WXBsGwG35QH30XFHHmrda3463fkWhVWFg3JYP1EemqQ/bY7afk3RO0g8j4pk2x+y1PWd7bnFxMe860SduywfqI1NQR8RqRNwsaZukW21/uM0xByOiGRHNRqORd53o056ZKT1w902ampyQJU1NTuiBu2/iQiJQQX1NfUTEku3jku6Q9EIxJeVn1MfTuC0fqIcsUx8N25OtryckfVzSS0UXNqj18bSFpWWF3h9POzK/UHZpANCXLK2PayUdt/0LST/XWo/6yWLLGhzjaQDqomfrIyJ+IWlmCLXkivE0AHVR270+GE8DUBe1DepUx9OOzC9o54PHdP2+72vng8fomQPoqXZ7fVw46XH1xLis0Nsr70mSLh8v97zE/hsANsMR+d9k2Gw2Y25uLvc/t5eLg7CTP7xiXF/5ixslDffxXzsfPKaFNj3yqckJ/WTfrsL+3lE26iOaqA7bJyKi2e73arWibjfp0c4bb69o9tDzUkgr762dqIaxuuUC53DxCQZ1UasedT+Bt7Ia50N6XdHje8O4wEkP/H2MaKIuahXUeQRekavboi9wcpPPRnyCQV3UKqjbBWG/ihzfu3D/DUkas8+v8PIIU1aQGzGiibqoVVC324jono9s1+TE+CXHjo9Z45d5w2vDGN/bMzN1/oSyGhv744OGNSvIjVId0QT6VauLiVL7jYj+Yc9Nba/+S8Od+ljXbeU7yN+/dXKi7VTJqK4g1/9dMvWBqqtdUHfSaSe5Mv6nLWrlO7t7+pLxxFFfQbKDIOqgVq2Pqiiqd8oe1EA9jcyKejOKulmiyJUvK0igfgjqDoq8WYLeKYB+ENQdFHXBbx0rXwBZ0aPugFE3AKlgRd1BP6NubPwDoEisqDvIerPEkfkFzR56fsNt27OHnh/Z27YB5I+g7iDrqNv93zulldWNmzutrIbu/96pIVYLoM5ofXSR5YLfG2+v9PU6APSLFTUAJI6gHlC7DZ+6vQ4A/SKoB/TVu268ZBe+8cusr951Y0kVAagbetQD4i5DAEUjqHPAXYYAikRQ54SbXgAUpWeP2vZ1to/bftH2Kdv3DaOwKuFZhQCKlOVi4ruS/i4ibpD0EUn32r6h2LKqhWcVAihSz6COiF9FxLOtr9+UdFoSn+kvwAZOAIrU13ie7R2SZiQ90+b39tqesz23uLiYT3UVwdOuARQpc1Db/qCkxyV9MSJ+d/HvR8TBiGhGRLPRaORZY/J42jWAImWa+rA9rrWQfjQiDhdbUvUwSw2gSD2D2rYlPSzpdER8rfiSqolZagBFydL62Cnpc5J22X6u9evPC64LANDSc0UdEU9Lcq/jAADFYFMmAEgcQQ0AiSOoASBxbMoEAAMqelM2ghqVxG6FSMX6pmzr+/2sb8omKbf/Jml9oHLYrRApGcambAQ1KofdCpGSYWzKRlCjctitECkZxqZsBDV6OjK/oJ0PHtP1+76vnQ8eK73FwG6FSMkwNmUjqNFViv3gOu5WmNrJENntmZnSA3ffpKnJCVnS1OSEHrj7JqY+MDzd+sFlTVnUbbfCYUwNoFhFb8pGUKOrVPvBddqtMMWTIdJC6wNd0Q8uXqonQ6SDoEZXdewHp4aTIXohqNHVMC6UjDpOhuiFHjV6qlM/OEV1uziK/BHUQAI4GaIbWh8AkDiCGgASR1ADQOIIagBIHEENAIkjqAEgcQQ1ACSOoAaAxHHDC2qBh92izghqVF6q+zlz8kBeerY+bD9i+5ztF4ZRENCvFB92W8aTcXhKTH1l6VH/q6Q7Cq4D2LQU93Me9skjxUemIT89gzoifizpt0OoBdiUFPdzHvbJI8VPFchPblMftvfanrM9t7i4mNcfC/SU4n7Owz55dDoBLCwt0wapgdyCOiIORkQzIpqNRiOvPxboKcWHGwz75NHtBEAbpPqY+kAtpLaf87AfBjC7e3rD5MvFeFhutRHUQBeDjNgN8+Rx4YlhIcGLqxhMlvG870r6b0nTts/a/pviywLKV7VJij0zU/rJvl2aSvDiKgaTZerjsxFxbUSMR8S2iHh4GIUBZavqJEWKF1cxGFofQAdZR+xSuwORh+XWD0GNZKQWeFsnJ9r2ey9sIaR6+3pqF1cxGHbPQxJS7AdnaSFUtT2CamFFjSR0C7yyVoZZWgj93IGY2icGVAdBjSSkuF+H1LuFkKU9IqXbIkE10PpA4bLs6pbifh1ZZJ2woEWCQRDUKFTW3nNVR8qy3r6e6icGVAOtDxQqa++5yiNlWSYssrZIgHYIahSqn5VknUfK2u3FUYVPDEgDrQ8Uqqq957yluMMfqoMVNQrFSvJ9df7EgGIR1ChUlXvPQCoIahQur5UkN4xgVBHUqARuGMEo42IiKoEbRjDKCGpUAjeMYJQR1KgExvwwyghqVEJVbzEH8sDFRCSn23QHUx8YRQQ1ktJruoNgxiii9YGkMN0BXIqgRlKY7gAuRVAjKUx3AJciqJEUpjuAS3ExEUlhugO4FEGN5DDdAWyUqfVh+w7bZ2z/0va+oosCALyvZ1DbHpP0z5I+KekGSZ+1fUPRhQEA1mRZUd8q6ZcR8XJEvCPpMUl/WWxZAIB1WYJ6StJrF3x/tvXaBrb32p6zPbe4uJhXfQAw8nIbz4uIgxHRjIhmo9HI648FgJGXJagXJF13wffbWq8BAIbAEdH9AHuLpP+RdLvWAvrnkv4qIk51+ZlFSa/mWGdZrpH0m7KLKBDvr9p4f9V28fv7o4ho247oOUcdEe/a/ltJRyWNSXqkW0i3fqYWvQ/bcxHRLLuOovD+qo33V239vL9MN7xExA8k/WCgqgAAm8JeHwCQOIK6u4NlF1Aw3l+18f6qLfP763kxEQBQLlbUAJA4ghoAEkdQt2H7EdvnbL9Qdi1FsH2d7eO2X7R9yvZ9ZdeUJ9uX2/6Z7edb7+/+smsqgu0x2/O2nyy7lrzZfsX2SdvP2Z4ru5682Z60fcj2S7ZP2/5o1+PpUV/K9m2S3pL0bxHx4bLryZvtayVdGxHP2r5K0glJeyLixZJLy4VtS7oyIt6yPS7paUn3RcRPSy4tV7a/JKkp6Q8i4s6y68mT7VckNSOilje82P6WpP+KiIdsf0DSFRGx1Ol4VtRtRMSPJf227DqKEhG/iohnW1+/Kem02my0VVWx5q3Wt+OtX7VakdjeJulTkh4quxb0x/bVkm6T9LAkRcQ73UJaIqhHnu0dkmYkPVNuJflqtQWek3RO0g8jolbvT9LXJX1Z0ntlF1KQkPSftk/Y3lt2MTm7XtKipG+2WlcP2b6y2w8Q1CPM9gclPS7pixHxu7LryVNErEbEzVrbROxW27VpYdm+U9K5iDhRdi0F+lhE3KK1B5bc22pH1sUWSbdI+kZEzEj6vaSuT84iqEdUq3f7uKRHI+Jw2fUUpfWR8rikO8quJUc7Jd3V6uM+JmmX7W+XW1K+ImKh9c9zkp7Q2gNM6uKspLMXfMo7pLXg7oigHkGti20PSzodEV8ru5682W7Ynmx9PSHp45JeKreq/ETE/ojYFhE7JH1G0rGIuKfksnJj+8rWRW61WgKfkFSbCayI+LWk12xPt166XVLXC/k8hbwN29+V9KeSrrF9VtJXIuLhcqvK1U5Jn5N0stXHlaS/b22+VQfXSvpW63mfl0n694io3QhbjX1I0hNr6wltkfSdiHiq3JJy9wVJj7YmPl6W9PluBzOeBwCJo/UBAIkjqAEgcQQ1ACSOoAaAxBHUAJA4ghoAEkdQA0Di/h85K3tJLjWDlQAAAABJRU5ErkJggg==\n", 217 | "text/plain": [ 218 | "
" 219 | ] 220 | }, 221 | "metadata": { 222 | "needs_background": "light" 223 | }, 224 | "output_type": "display_data" 225 | } 226 | ], 227 | "source": [ 228 | "plt.scatter(cpp_pred[:,0], cpp_pred[:,1])" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "### Error" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 7, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "err_arr = py_pred-cpp_pred" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 8, 250 | "metadata": { 251 | "scrolled": false 252 | }, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "text/plain": [ 257 | "array([[-8.92578125e-07, 1.34826660e-07],\n", 258 | " [-1.50512695e-07, 3.93630981e-07],\n", 259 | " [ 3.43307495e-07, -3.08654785e-07],\n", 260 | " [ 9.30175781e-08, 4.96704102e-07],\n", 261 | " [-1.14974976e-07, 3.18202972e-07],\n", 262 | " [ 2.67791748e-07, -5.67001343e-07],\n", 263 | " [ 2.16369629e-07, 1.33991241e-07],\n", 264 | " [-8.96942139e-07, -3.52210999e-07],\n", 265 | " [-9.95788574e-08, -5.91888427e-08],\n", 266 | " [-2.92861938e-07, -5.13916016e-07],\n", 267 | " [ 1.62368774e-07, 3.67301941e-07],\n", 268 | " [ 7.11822508e-08, -7.37400055e-08],\n", 269 | " [-5.10864258e-08, 1.93130493e-07],\n", 270 | " [ 4.30297852e-07, 8.40072631e-08],\n", 271 | " [-7.60841370e-08, 4.30877686e-07],\n", 272 | " [ 4.08241272e-07, -6.54205322e-07],\n", 273 | " [-5.30838013e-07, 5.15747067e-09],\n", 274 | " [-5.32531739e-08, -4.22344208e-07],\n", 275 | " [-1.09169006e-07, 3.57879639e-07],\n", 276 | " [-1.01516723e-07, 1.77398682e-07],\n", 277 | " [ 3.00697327e-07, 2.73071289e-07],\n", 278 | " [ 2.37304687e-07, -3.60340118e-07],\n", 279 | " [ 1.62429810e-08, -1.38992310e-07],\n", 280 | " [-1.07635498e-07, -4.05700684e-07],\n", 281 | " [-2.14370728e-07, -4.64019774e-08],\n", 282 | " [-2.38098145e-07, 1.80664061e-08],\n", 283 | " [ 4.91937637e-07, -4.47113037e-07],\n", 284 | " [-4.20684815e-07, 1.23596191e-08],\n", 285 | " [-2.41592407e-07, 6.34155275e-08],\n", 286 | " [ 3.24317932e-07, 2.67364502e-07],\n", 287 | " [ 3.61763000e-07, 2.82348633e-07],\n", 288 | " [-4.31076050e-07, -3.83193970e-07],\n", 289 | " [ 1.58363342e-07, -3.17321778e-07],\n", 290 | " [ 1.25961304e-07, 8.76281738e-07]])" 291 | ] 292 | }, 293 | "execution_count": 8, 294 | "metadata": {}, 295 | "output_type": "execute_result" 296 | } 297 | ], 298 | "source": [ 299 | "err_arr" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "### Norm" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 9, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "Norms for python predictions and c++ predictions: 20.656949141902103 20.65694939092029\n" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "py_pred_norm = np.linalg.norm(py_pred)\n", 324 | "cpp_pred_norm = np.linalg.norm(cpp_pred)\n", 325 | "\n", 326 | "print(\"Norms for python predictions and c++ predictions:\", py_pred_norm, cpp_pred_norm)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "### np.allclose()" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 10, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "data": { 343 | "text/plain": [ 344 | "True" 345 | ] 346 | }, 347 | "execution_count": 10, 348 | "metadata": {}, 349 | "output_type": "execute_result" 350 | } 351 | ], 352 | "source": [ 353 | "# True if two arrays are element-wise equal within a tolerance.\n", 354 | "np.allclose(py_pred, cpp_pred) # default rtol=1e-05 " 355 | ] 356 | } 357 | ], 358 | "metadata": { 359 | "kernelspec": { 360 | "display_name": "Python 3", 361 | "language": "python", 362 | "name": "python3" 363 | }, 364 | "language_info": { 365 | "codemirror_mode": { 366 | "name": "ipython", 367 | "version": 3 368 | }, 369 | "file_extension": ".py", 370 | "mimetype": "text/x-python", 371 | "name": "python", 372 | "nbconvert_exporter": "python", 373 | "pygments_lexer": "ipython3", 374 | "version": "3.7.5" 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 2 379 | } 380 | --------------------------------------------------------------------------------