├── code ├── node2vec │ ├── stdafx.cpp │ ├── node2vec │ ├── stdafx.h │ ├── Makefile │ ├── Makefile.ex │ ├── graph │ │ └── karate.edgelist │ ├── targetver.h │ ├── ReadMe.txt │ ├── node2vec.cpp │ └── emb │ │ └── karate.emb ├── read_results.py ├── get_histograms.py ├── get_node2vec.py ├── main.py └── main_data_augmentation.py ├── results_graph_cnn_github.png ├── image_example_graph_cnn_github.png ├── datasets ├── data_as_adj │ └── README.txt ├── raw_node2vec │ └── README.txt ├── tensors │ └── README.txt ├── results │ ├── collab_2017-08-08_11_23_48_parameters.json │ ├── reddit_multi_5K_2017-08-08_13_27_19_parameters.json │ ├── imdb_action_romance_2017-08-08_11_06_36_parameters.json │ ├── reddit_subreddit_10K_2017-08-08_15_38_03_parameters.json │ ├── reddit_iama_askreddit_atheism_trollx_2017-08-08_12_35_28_parameters.json │ └── imdb_action_romance_augmentation_2017-08-09_15_34_09_parameters.json └── classes │ ├── imdb_action_romance │ └── imdb_action_romance_classes.txt │ ├── reddit_iama_askreddit_atheism_trollx │ └── reddit_iama_askreddit_atheism_trollx_classes.txt │ ├── collab │ └── collab_classes.txt │ └── reddit_multi_5K │ └── reddit_multi_5K_classes.txt ├── baselines ├── graphlet_kernel.py └── wl_subtree_kernel.py └── README.md /code/node2vec/stdafx.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | -------------------------------------------------------------------------------- /code/node2vec/node2vec: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tixierae/graph_2D_CNN/HEAD/code/node2vec/node2vec -------------------------------------------------------------------------------- /code/node2vec/stdafx.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "targetver.h" 4 | 5 | #include "Snap.h" 6 | 7 | 
-------------------------------------------------------------------------------- /code/node2vec/Makefile: -------------------------------------------------------------------------------- 1 | include ../../Makefile.config 2 | include Makefile.ex 3 | include ../Makefile.exmain 4 | -------------------------------------------------------------------------------- /results_graph_cnn_github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tixierae/graph_2D_CNN/HEAD/results_graph_cnn_github.png -------------------------------------------------------------------------------- /image_example_graph_cnn_github.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tixierae/graph_2D_CNN/HEAD/image_example_graph_cnn_github.png -------------------------------------------------------------------------------- /datasets/data_as_adj/README.txt: -------------------------------------------------------------------------------- 1 | # adjacency matrices can be downloaded at: https://www.dropbox.com/s/hn34a8trrpjik1i/data_as_adj.zip?dl=0 -------------------------------------------------------------------------------- /datasets/raw_node2vec/README.txt: -------------------------------------------------------------------------------- 1 | # pre-computed node2vec embeddings can be downloaded at: https://www.dropbox.com/s/acdz80qu7pod88l/raw_node2vec.zip?dl=0 -------------------------------------------------------------------------------- /datasets/tensors/README.txt: -------------------------------------------------------------------------------- 1 | # pre-computed image representations of graphs (histograms) can be downloaded at: https://www.dropbox.com/s/qqnfrk4798gnlg2/tensors.zip?dl=0 -------------------------------------------------------------------------------- /code/node2vec/Makefile.ex: 
-------------------------------------------------------------------------------- 1 | MAIN = node2vec 2 | DEPH = $(EXSNAPADV)/n2v.h $(EXSNAPADV)/word2vec.h $(EXSNAPADV)/biasedrandomwalk.h 3 | DEPCPP = $(EXSNAPADV)/n2v.cpp $(EXSNAPADV)/word2vec.cpp $(EXSNAPADV)/biasedrandomwalk.cpp 4 | CXXFLAGS += $(CXXOPENMP) 5 | -------------------------------------------------------------------------------- /datasets/results/collab_2017-08-08_11_23_48_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "dataset": "collab", 4 | "definition": 9, 5 | "dim_ordering": "th", 6 | "drop_rate": 0.3, 7 | "my_optimizer": "adam", 8 | "my_patience": 5, 9 | "n_channels": 5, 10 | "n_folds": 10, 11 | "n_repeats": 3, 12 | "nb_epochs": 50, 13 | "p": "0.25", 14 | "path_root": "/home/antoine/Desktop/graph_2D_CNN/datasets/", 15 | "q": "4" 16 | } -------------------------------------------------------------------------------- /datasets/results/reddit_multi_5K_2017-08-08_13_27_19_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "dataset": "reddit_multi_5K", 4 | "definition": 9, 5 | "dim_ordering": "th", 6 | "drop_rate": 0.3, 7 | "my_optimizer": "adam", 8 | "my_patience": 5, 9 | "n_channels": 2, 10 | "n_folds": 10, 11 | "n_repeats": 3, 12 | "nb_epochs": 50, 13 | "p": "4", 14 | "path_root": "/home/antoine/Desktop/graph_2D_CNN/datasets/", 15 | "q": "0.25" 16 | } -------------------------------------------------------------------------------- /datasets/results/imdb_action_romance_2017-08-08_11_06_36_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "dataset": "imdb_action_romance", 4 | "definition": 14, 5 | "dim_ordering": "th", 6 | "drop_rate": 0.3, 7 | "my_optimizer": "adam", 8 | "my_patience": 5, 9 | "n_channels": 5, 10 | "n_folds": 10, 11 | "n_repeats": 3, 12 | 
"nb_epochs": 50, 13 | "p": "1", 14 | "path_root": "/home/antoine/Desktop/graph_2D_CNN/datasets/", 15 | "q": "1" 16 | } -------------------------------------------------------------------------------- /datasets/results/reddit_subreddit_10K_2017-08-08_15_38_03_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "dataset": "reddit_subreddit_10K", 4 | "definition": 9, 5 | "dim_ordering": "th", 6 | "drop_rate": 0.3, 7 | "my_optimizer": "adam", 8 | "my_patience": 5, 9 | "n_channels": 5, 10 | "n_folds": 10, 11 | "n_repeats": 3, 12 | "nb_epochs": 50, 13 | "p": "1", 14 | "path_root": "/home/antoine/Desktop/graph_2D_CNN/datasets/", 15 | "q": "1" 16 | } -------------------------------------------------------------------------------- /datasets/results/reddit_iama_askreddit_atheism_trollx_2017-08-08_12_35_28_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "dataset": "reddit_iama_askreddit_atheism_trollx", 4 | "definition": 9, 5 | "dim_ordering": "th", 6 | "drop_rate": 0.3, 7 | "my_optimizer": "adam", 8 | "my_patience": 5, 9 | "n_channels": 5, 10 | "n_folds": 10, 11 | "n_repeats": 3, 12 | "nb_epochs": 50, 13 | "p": "2", 14 | "path_root": "/home/antoine/Desktop/graph_2D_CNN/datasets/", 15 | "q": "0.5" 16 | } -------------------------------------------------------------------------------- /datasets/results/imdb_action_romance_augmentation_2017-08-09_15_34_09_parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 32, 3 | "dataset": "imdb_action_romance", 4 | "definition": 14, 5 | "dim_ordering": "th", 6 | "drop_rate": 0.3, 7 | "my_optimizer": "adam", 8 | "my_patience": 5, 9 | "n_bootstrap": 0.1, 10 | "n_channels": 5, 11 | "n_folds": 10, 12 | "n_repeats": 3, 13 | "nb_epochs": 50, 14 | "p": "1", 15 | "path_root": "/home/antoine/Desktop/graph_2D_CNN/datasets/", 
16 | "q": "1" 17 | } -------------------------------------------------------------------------------- /code/node2vec/graph/karate.edgelist: -------------------------------------------------------------------------------- 1 | 1 32 2 | 1 22 3 | 1 20 4 | 1 18 5 | 1 14 6 | 1 13 7 | 1 12 8 | 1 11 9 | 1 9 10 | 1 8 11 | 1 7 12 | 1 6 13 | 1 5 14 | 1 4 15 | 1 3 16 | 1 2 17 | 2 31 18 | 2 22 19 | 2 20 20 | 2 18 21 | 2 14 22 | 2 8 23 | 2 4 24 | 2 3 25 | 3 14 26 | 3 9 27 | 3 10 28 | 3 33 29 | 3 29 30 | 3 28 31 | 3 8 32 | 3 4 33 | 4 14 34 | 4 13 35 | 4 8 36 | 5 11 37 | 5 7 38 | 6 17 39 | 6 11 40 | 6 7 41 | 7 17 42 | 9 34 43 | 9 33 44 | 9 33 45 | 10 34 46 | 14 34 47 | 15 34 48 | 15 33 49 | 16 34 50 | 16 33 51 | 19 34 52 | 19 33 53 | 20 34 54 | 21 34 55 | 21 33 56 | 23 34 57 | 23 33 58 | 24 30 59 | 24 34 60 | 24 33 61 | 24 28 62 | 24 26 63 | 25 32 64 | 25 28 65 | 25 26 66 | 26 32 67 | 27 34 68 | 27 30 69 | 28 34 70 | 29 34 71 | 29 32 72 | 30 34 73 | 30 33 74 | 31 34 75 | 31 33 76 | 32 34 77 | 32 33 78 | 33 34 79 | -------------------------------------------------------------------------------- /code/node2vec/targetver.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // The following macros define the minimum required platform. The minimum required platform 4 | // is the earliest version of Windows, Internet Explorer etc. that has the necessary features to run 5 | // your application. The macros work by enabling all features available on platform versions up to and 6 | // including the version specified. 7 | 8 | // Modify the following defines if you have to target a platform prior to the ones specified below. 9 | // Refer to MSDN for the latest info on corresponding values for different platforms. 10 | #ifndef _WIN32_WINNT // Specifies that the minimum required platform is Windows Vista. 11 | #define _WIN32_WINNT 0x0600 // Change this to the appropriate value to target other versions of Windows. 
12 | #endif 13 | 14 | -------------------------------------------------------------------------------- /code/read_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | path_to_results = './results/' 6 | 7 | results_names = os.listdir(path_to_results) 8 | 9 | my_prec = 2 # desired precision 10 | 11 | for name in results_names: 12 | with open(path_to_results + name, 'r') as my_file: 13 | tmp = json.load(my_file) 14 | vals = [elt[1] for elt in tmp['outputs']] # 'outputs' contains loss, accuracy for each repeat of each fold 15 | vals = [val*100 for val in vals] 16 | print '=======',name,'=======' 17 | print 'mean:', round(np.mean(vals),my_prec) 18 | print 'median:', round(np.median(vals),my_prec) 19 | print 'max:', round(max(vals),my_prec) 20 | print 'min:', round(min(vals),my_prec) 21 | print 'stdev', round(np.std(vals),my_prec) 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /code/node2vec/ReadMe.txt: -------------------------------------------------------------------------------- 1 | ======================================================================== 2 | Node2vec 3 | ======================================================================== 4 | 5 | node2vec is an algorithmic framework for representational learning on graphs. Given any graph, it can learn continuous feature representations for the nodes, which can then be used for various downstream machine learning tasks. 6 | 7 | The code works under Windows with Visual Studio or Cygwin with GCC, 8 | Mac OS X, Linux and other Unix variants with GCC. Make sure that a 9 | C++ compiler is installed on the system. Visual Studio project files 10 | and makefiles are provided. For makefiles, compile the code with 11 | "make all". 
12 | 13 | ///////////////////////////////////////////////////////////////////////////// 14 | 15 | Parameters: 16 | Input graph path (-i:) 17 | Output graph path (-o:) 18 | Number of dimensions. Default is 128 (-d:) 19 | Length of walk per source. Default is 80 (-l:) 20 | Number of walks per source. Default is 10 (-r:) 21 | Context size for optimization. Default is 10 (-k:) 22 | Number of epochs in SGD. Default is 1 (-e:) 23 | Return hyperparameter. Default is 1 (-p:) 24 | Inout hyperparameter. Default is 1 (-q:) 25 | Verbose output. (-v) 26 | Graph is directed. (-dr) 27 | Graph is weighted. (-w) 28 | 29 | ///////////////////////////////////////////////////////////////////////////// 30 | 31 | Usage: 32 | ./node2vec -i:graph/karate.edgelist -o:emb/karate.emb -l:3 -d:24 -p:0.3 -dr -v 33 | -------------------------------------------------------------------------------- /code/node2vec/node2vec.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | 3 | #include "n2v.h" 4 | 5 | #ifdef USE_OPENMP 6 | #include 7 | #endif 8 | 9 | void ParseArgs(int& argc, char* argv[], TStr& InFile, TStr& OutFile, 10 | int& Dimensions, int& WalkLen, int& NumWalks, int& WinSize, int& Iter, 11 | bool& Verbose, double& ParamP, double& ParamQ, bool& Directed, bool& Weighted) { 12 | Env = TEnv(argc, argv, TNotify::StdNotify); 13 | Env.PrepArgs(TStr::Fmt("\nAn algorithmic framework for representational learning on graphs.")); 14 | InFile = Env.GetIfArgPrefixStr("-i:", "graph/karate.edgelist", 15 | "Input graph path"); 16 | OutFile = Env.GetIfArgPrefixStr("-o:", "emb/karate.emb", 17 | "Output graph path"); 18 | Dimensions = Env.GetIfArgPrefixInt("-d:", 128, 19 | "Number of dimensions. Default is 128"); 20 | WalkLen = Env.GetIfArgPrefixInt("-l:", 80, 21 | "Length of walk per source. Default is 80"); 22 | NumWalks = Env.GetIfArgPrefixInt("-r:", 10, 23 | "Number of walks per source. 
Default is 10"); 24 | WinSize = Env.GetIfArgPrefixInt("-k:", 10, 25 | "Context size for optimization. Default is 10"); 26 | Iter = Env.GetIfArgPrefixInt("-e:", 1, 27 | "Number of epochs in SGD. Default is 1"); 28 | ParamP = Env.GetIfArgPrefixFlt("-p:", 1, 29 | "Return hyperparameter. Default is 1"); 30 | ParamQ = Env.GetIfArgPrefixFlt("-q:", 1, 31 | "Inout hyperparameter. Default is 1"); 32 | Verbose = Env.IsArgStr("-v", "Verbose output."); 33 | Directed = Env.IsArgStr("-dr", "Graph is directed."); 34 | Weighted = Env.IsArgStr("-w", "Graph is weighted."); 35 | } 36 | 37 | void ReadGraph(TStr& InFile, bool& Directed, bool& Weighted, bool& Verbose, PWNet& InNet) { 38 | TFIn FIn(InFile); 39 | int64 LineCnt = 0; 40 | try { 41 | while (!FIn.Eof()) { 42 | TStr Ln; 43 | FIn.GetNextLn(Ln); 44 | TStr Line, Comment; 45 | Ln.SplitOnCh(Line,'#',Comment); 46 | TStrV Tokens; 47 | Line.SplitOnWs(Tokens); 48 | if(Tokens.Len()<2){ continue; } 49 | int64 SrcNId = Tokens[0].GetInt(); 50 | int64 DstNId = Tokens[1].GetInt(); 51 | double Weight = 1.0; 52 | if (Weighted) { Weight = Tokens[2].GetFlt(); } 53 | if (!InNet->IsNode(SrcNId)){ InNet->AddNode(SrcNId); } 54 | if (!InNet->IsNode(DstNId)){ InNet->AddNode(DstNId); } 55 | InNet->AddEdge(SrcNId,DstNId,Weight); 56 | if (!Directed){ InNet->AddEdge(DstNId,SrcNId,Weight); } 57 | LineCnt++; 58 | } 59 | if (Verbose) { printf("Read %lld lines from %s\n", (long long)LineCnt, InFile.CStr()); } 60 | } catch (PExcept Except) { 61 | if (Verbose) { 62 | printf("Read %lld lines from %s, then %s\n", (long long)LineCnt, InFile.CStr(), 63 | Except->GetStr().CStr()); 64 | } 65 | } 66 | } 67 | 68 | void WriteOutput(TStr& OutFile, TIntFltVH& EmbeddingsHV) { 69 | TFOut FOut(OutFile); 70 | bool First = 1; 71 | for (int i = EmbeddingsHV.FFirstKeyId(); EmbeddingsHV.FNextKeyId(i);) { 72 | if (First) { 73 | FOut.PutInt(EmbeddingsHV.Len()); 74 | FOut.PutCh(' '); 75 | FOut.PutInt(EmbeddingsHV[i].Len()); 76 | FOut.PutLn(); 77 | First = 0; 78 | } 79 | 
FOut.PutInt(EmbeddingsHV.GetKey(i)); 80 | for (int64 j = 0; j < EmbeddingsHV[i].Len(); j++) { 81 | FOut.PutCh(' '); 82 | FOut.PutFlt(EmbeddingsHV[i][j]); 83 | } 84 | FOut.PutLn(); 85 | } 86 | } 87 | 88 | int main(int argc, char* argv[]) { 89 | TStr InFile,OutFile; 90 | int Dimensions, WalkLen, NumWalks, WinSize, Iter; 91 | double ParamP, ParamQ; 92 | bool Directed, Weighted, Verbose; 93 | ParseArgs(argc, argv, InFile, OutFile, Dimensions, WalkLen, NumWalks, WinSize, 94 | Iter, Verbose, ParamP, ParamQ, Directed, Weighted); 95 | PWNet InNet = PWNet::New(); 96 | TIntFltVH EmbeddingsHV; 97 | ReadGraph(InFile, Directed, Weighted, Verbose, InNet); 98 | node2vec(InNet, ParamP, ParamQ, Dimensions, WalkLen, NumWalks, WinSize, Iter, 99 | Verbose, EmbeddingsHV); 100 | WriteOutput(OutFile, EmbeddingsHV); 101 | return 0; 102 | } 103 | -------------------------------------------------------------------------------- /baselines/graphlet_kernel.py: -------------------------------------------------------------------------------- 1 | """ 2 | - Sample code for: 3 | "Classifying Graphs as Images with Convolutional Neural Networks" 4 | arXiv preprint arXiv:1708.02218 5 | 6 | - Computes the graphlet kernel by sampling "num_samples" graphlets of size 6 from each graph. 
7 | 8 | - Datasets should be placed into the "datasets" folder and can be downloaded from: https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets 9 | 10 | - To run, use the following command: 11 | python graphlet_kernel.py dataset num_samples 12 | """ 13 | 14 | import networkx as nx 15 | import numpy as np 16 | import sys 17 | from math import factorial 18 | from sympy.utilities.iterables import multiset_permutations 19 | 20 | np.random.seed(None) 21 | 22 | def load_data(ds_name): 23 | node2graph = {} 24 | Gs = [] 25 | 26 | with open("../datasets/%s/%s_graph_indicator.txt"%(ds_name,ds_name), "rb") as f: 27 | c = 1 28 | for line in f: 29 | node2graph[c] = int(line[:-1]) 30 | if not node2graph[c] == len(Gs): 31 | Gs.append(nx.Graph()) 32 | Gs[-1].add_node(c) 33 | c += 1 34 | 35 | with open("../datasets/%s/%s_A.txt"%(ds_name,ds_name), "rb") as f: 36 | for line in f: 37 | edge = line[:-1].split(",") 38 | edge[1] = edge[1].replace(" ", "") 39 | Gs[node2graph[int(edge[0])]-1].add_edge(int(edge[0]), int(edge[1])) 40 | 41 | labels = [] 42 | with open("../datasets/%s/%s_graph_labels.txt"%(ds_name,ds_name), "rb") as f: 43 | for line in f: 44 | labels.append(int(line[:-1])) 45 | 46 | labels = np.array(labels, dtype=np.float) 47 | return Gs, labels 48 | 49 | 50 | def generate_permutation_matrix(): 51 | P = np.zeros((2**15,2**15),dtype=np.uint8) 52 | 53 | for a in range(2): 54 | for b in range(2): 55 | for c in range(2): 56 | for d in range(2): 57 | for e in range(2): 58 | for f in range(2): 59 | for g in range(2): 60 | for h in range(2): 61 | for i in range(2): 62 | for j in range(2): 63 | for k in range(2): 64 | for l in range(2): 65 | for m in range(2): 66 | for n in range(2): 67 | for o in range(2): 68 | A = np.array([[0,a,b,c,d,e],[a,0,f,g,h,i],[b,f,0,j,k,l],[c,g,j,0,m,n],[d,h,k,m,0,o],[e,i,l,n,o,0]]) 69 | 70 | perms = multiset_permutations(np.array(range(6),dtype=np.uint8)) 71 | Per = np.zeros((factorial(6),6),dtype=np.uint8) 72 | ind = 0 73 | for permutation 
in perms: 74 | Per[ind,:] = permutation 75 | ind += 1 76 | 77 | for p in range(factorial(6)): 78 | A_per = A[np.ix_(Per[p,:],Per[p,:])] 79 | P[graphlet_type(A), graphlet_type(A_per)] = 1 80 | return P 81 | 82 | 83 | def graphlet_type(A): 84 | factor = 2**np.array(range(15)) 85 | 86 | upper = np.hstack((A[0,1:6],A[1,2:6],A[2,3:6],A[3,4:6],A[4,5])) 87 | result = np.sum(factor*upper) 88 | 89 | return int(result) 90 | 91 | 92 | def graphlet_kernel(graphs, num_samples): 93 | N = len(graphs) 94 | 95 | Phi = np.zeros((N,2**15)) 96 | 97 | P = generate_permutation_matrix() 98 | 99 | for i in range(len(graphs)): 100 | n = graphs[i].number_of_nodes() 101 | if n >= 6: 102 | A = nx.to_numpy_matrix(graphs[i]) 103 | A = np.asarray(A, dtype=np.uint8) 104 | for j in range(num_samples): 105 | r = np.random.permutation(n) 106 | window = A[np.ix_(r[:6],r[:6])] 107 | Phi[i, graphlet_type(window)] += 1 108 | 109 | Phi[i,:] /= num_samples 110 | 111 | K = np.dot(Phi,np.dot(P,np.transpose(Phi))) 112 | return K 113 | 114 | 115 | if __name__ == "__main__": 116 | # read the parameters 117 | ds_name = sys.argv[1] 118 | num_samples = int(sys.argv[2]) 119 | 120 | graphs, labels = load_data(ds_name) 121 | np.save(ds_name+"_labels", labels) 122 | 123 | print("Building graphlet kernel for "+ds_name) 124 | 125 | K = graphlet_kernel(graphs, num_samples) 126 | np.save(ds_name+"_graphlet_"+str(num_samples)+"_samples", K) 127 | -------------------------------------------------------------------------------- /baselines/wl_subtree_kernel.py: -------------------------------------------------------------------------------- 1 | """ 2 | - Sample code for: 3 | "Classifying Graphs as Images with Convolutional Neural Networks" 4 | arXiv preprint arXiv:1708.02218 5 | 6 | - Computes the Weisfeiler-Lehman subtree kernel with h iterations for a set of graphs. Each vertex is assigned its degree as label. 
7 | 8 | - Datasets should be placed into the "datasets" folder and can be downloaded from: https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets 9 | 10 | - To run, use the following command: 11 | python wl_subtree_kernel.py dataset h 12 | 13 | - Part of the code is modified from: 14 | "Deep graph kernels" 15 | Proceedings of the 21th International Conference on Knowledge Discovery and Data Mining, 2015. 16 | """ 17 | 18 | import networkx as nx 19 | import numpy as np 20 | import sys 21 | from collections import defaultdict 22 | import copy 23 | 24 | np.random.seed(None) 25 | 26 | def load_data(ds_name): 27 | node2graph = {} 28 | Gs = [] 29 | 30 | with open("../datasets/%s/%s_graph_indicator.txt"%(ds_name,ds_name), "rb") as f: 31 | c = 1 32 | for line in f: 33 | node2graph[c] = int(line[:-1]) 34 | if not node2graph[c] == len(Gs): 35 | Gs.append(nx.Graph()) 36 | Gs[-1].add_node(c) 37 | c += 1 38 | 39 | with open("../datasets/%s/%s_A.txt"%(ds_name,ds_name), "rb") as f: 40 | for line in f: 41 | edge = line[:-1].split(",") 42 | edge[1] = edge[1].replace(" ", "") 43 | Gs[node2graph[int(edge[0])]-1].add_edge(int(edge[0]), int(edge[1])) 44 | 45 | labels = [] 46 | with open("../datasets/%s/%s_graph_labels.txt"%(ds_name,ds_name), "rb") as f: 47 | for line in f: 48 | labels.append(int(line[:-1])) 49 | 50 | labels = np.array(labels, dtype=np.float) 51 | return Gs, labels 52 | 53 | 54 | def wl_subtree_kernel(graphs, h): 55 | N = len(graphs) 56 | 57 | labels = {} 58 | label_lookup = {} 59 | label_counter = 0 60 | 61 | for G in graphs: 62 | for node in G.nodes(): 63 | G.node[node]['label'] = G.degree(node) 64 | 65 | orig_graph_map = {it: {i: defaultdict(lambda: 0) for i in range(N)} for it in range(-1, h)} 66 | 67 | # initial labeling 68 | ind = 0 69 | for G in graphs: 70 | labels[ind] = np.zeros(G.number_of_nodes(), dtype = np.int32) 71 | node2index = {} 72 | for node in G.nodes(): 73 | node2index[node] = len(node2index) 74 | 75 | for node in G.nodes(): 76 | label 
= G.node[node]['label'] 77 | if not label_lookup.has_key(label): 78 | label_lookup[label] = len(label_lookup) 79 | 80 | labels[ind][node2index[node]] = label_lookup[label] 81 | orig_graph_map[-1][ind][label] = orig_graph_map[-1][ind].get(label, 0) + 1 82 | 83 | ind += 1 84 | 85 | compressed_labels = copy.deepcopy(labels) 86 | 87 | # WL iterations 88 | for it in range(h): 89 | unique_labels_per_h = set() 90 | label_lookup = {} 91 | ind = 0 92 | for G in graphs: 93 | node2index = {} 94 | for node in G.nodes(): 95 | node2index[node] = len(node2index) 96 | 97 | for node in G.nodes(): 98 | node_label = tuple([labels[ind][node2index[node]]]) 99 | neighbors = G.neighbors(node) 100 | if len(neighbors) > 0: 101 | neighbors_label = tuple([labels[ind][node2index[neigh]] for neigh in neighbors]) 102 | node_label = str(node_label) + "-" + str(sorted(neighbors_label)) 103 | if not label_lookup.has_key(node_label): 104 | label_lookup[node_label] = len(label_lookup) 105 | 106 | compressed_labels[ind][node2index[node]] = label_lookup[node_label] 107 | orig_graph_map[it][ind][node_label] = orig_graph_map[it][ind].get(node_label, 0) + 1 108 | 109 | ind +=1 110 | 111 | labels = copy.deepcopy(compressed_labels) 112 | 113 | K = np.zeros((N, N)) 114 | for it in range(-1, h): 115 | for i in range(N): 116 | for j in range(N): 117 | common_keys = set(orig_graph_map[it][i].keys()) & set(orig_graph_map[it][j].keys()) 118 | K[i][j] += sum([orig_graph_map[it][i].get(k,0)*orig_graph_map[it][j].get(k,0) for k in common_keys]) 119 | 120 | return K 121 | 122 | 123 | if __name__ == "__main__": 124 | # read the parameters 125 | ds_name = sys.argv[1] 126 | h = int(sys.argv[2]) 127 | 128 | graphs, labels = load_data(ds_name) 129 | np.save(ds_name+"_labels", labels) 130 | 131 | print("Building wl subtree kernel for "+ds_name) 132 | 133 | K = wl_subtree_kernel(graphs, h) 134 | np.save(ds_name+"_wl_subtree_"+str(h), K) -------------------------------------------------------------------------------- 
/code/get_histograms.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import numpy as np 5 | import time as t 6 | from multiprocessing import Pool, cpu_count 7 | from functools import partial 8 | 9 | # ============================================================================= 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | # positional arguments (required) 14 | parser.add_argument('path_to_node2vec', type=str, help='path to the root folder where node2vec arrays are stored for all datasets') 15 | parser.add_argument('path_to_hist', type=str, help='path to the root folder where histograms should be written for all datasets') 16 | parser.add_argument('dataset', type=str, help='name of the dataset. Must correspond to a valid value that matches names of files in node2vec folder') 17 | parser.add_argument('p', type=str, help='p parameter of node2vec. Must correspond to a valid value that matches names of files in node2vec folder') 18 | parser.add_argument('q', type=str, help='q parameter of node2vec. Must correspond to a valid value that matches names of files in node2vec folder') 19 | parser.add_argument('definition', type=int, help='definition. E.g., 14 for 14:1. Must correspond to a valid value that matches names of files in node2vec folder') 20 | parser.add_argument('max_n_channels', type=int, help='maximum number of channels that we will be able to pass to the network. 
Must not exceed half the depth of the tensors in node2vec folder') 21 | 22 | args = parser.parse_args() 23 | 24 | # convert command line arguments 25 | path_to_node2vec = args.path_to_node2vec 26 | path_to_hist = args.path_to_hist 27 | dataset = args.dataset 28 | p = args.p 29 | q = args.q 30 | definition = args.definition 31 | max_n_channels = args.max_n_channels 32 | 33 | # command line example: python get_histograms.py /home/antoine/Desktop/graph_2D_CNN/datasets/raw_node2vec/ /home/antoine/Desktop/graph_2D_CNN/datasets/tensors/ imdb_action_romance 1 1 14 5 34 | 35 | # ============================================================================= 36 | 37 | def atoi(text): 38 | return int(text) if text.isdigit() else text 39 | 40 | def natural_keys(text): 41 | return [atoi(c) for c in re.split('(\d+)', text)] 42 | 43 | def get_hist_node2vec(emb,d,my_min,my_max,definition): 44 | # d should be an even integer 45 | img_dim = int(np.arange(my_min, my_max+0.05,(my_max+0.05-my_min)/float(definition*(my_max+0.05-my_min))).shape[0]-1) 46 | my_bins = np.linspace(my_min,my_max,img_dim) # to have middle bin centered on zero 47 | Hs = [] 48 | for i in range(0,d,2): 49 | H, xedges, yedges = np.histogram2d(x=emb[:,i],y=emb[:,i+1],bins=my_bins, normed=False) 50 | Hs.append(H) 51 | Hs = np.array(Hs) 52 | return Hs 53 | 54 | def to_parallelize(my_file_name,dataset,n_dim,my_min,my_max,my_def,path_read,path_write): 55 | path_write_dataset = path_write + dataset + '/node2vec_hist/' 56 | 57 | p_value = [elt for elt in my_file_name.split('_') if elt.startswith('p=')][0] 58 | q_value = [elt for elt in my_file_name.split('_') if elt.startswith('q=')][0] 59 | real_idx = my_file_name.split('.npy')[0].split('_')[-1:][0] 60 | emb = np.load(path_read + my_file_name) 61 | emb = emb[:,:n_dim] 62 | my_hist = get_hist_node2vec(emb=emb,d=n_dim,my_min=my_min,my_max=my_max,definition=my_def) 63 | 64 | np.save(path_write_dataset + dataset + '_' + str(my_def) + ':1'+ '_' + p_value + '_' + q_value + '_' 
+ real_idx, my_hist, allow_pickle=False) 65 | if int(real_idx) % 1000 == 0: 66 | print 'done', my_hist.shape 67 | 68 | # ============================================================================= 69 | 70 | def main(): 71 | t_start = t.time() 72 | 73 | n_dim = 2*max_n_channels 74 | 75 | all_file_names = os.listdir(path_to_node2vec + dataset + '/') 76 | print '===== total number of files in folder: =====', len(all_file_names) 77 | 78 | file_names_filtered = [elt for elt in all_file_names if dataset in elt and 'p=' + p in elt and 'q=' + q in elt] 79 | file_names_filtered.sort(key=natural_keys) 80 | print '===== number of files after filtering: =====', len(file_names_filtered) 81 | print '*** head ***' 82 | print file_names_filtered[:5] 83 | print '*** tail ***' 84 | print file_names_filtered[-5:] 85 | 86 | # load tensors 87 | tensors = [] 88 | for idx, name in enumerate(file_names_filtered): 89 | tensor = np.load(path_to_node2vec + dataset + '/' + name) 90 | tensors.append(tensor[:,:n_dim]) 91 | if idx % round(len(file_names_filtered)/10) == 0: 92 | print idx 93 | 94 | print 'tensors loaded' 95 | 96 | full = np.concatenate(tensors) 97 | my_max = np.amax(full) 98 | my_min = np.amin(full) 99 | print 'range:', my_max, my_min 100 | 101 | to_parallelize_partial = partial(to_parallelize, dataset=dataset, n_dim=n_dim, my_min=my_min, my_max=my_max, my_def=definition, path_read=path_to_node2vec + dataset + '/',path_write=path_to_hist) 102 | 103 | n_jobs = 2*cpu_count() 104 | 105 | print 'creating', n_jobs, 'jobs' 106 | 107 | pool = Pool(processes=n_jobs) 108 | pool.map(to_parallelize_partial, file_names_filtered) 109 | pool.close() 110 | 111 | print 'done in ', round(t.time() - t_start,4) 112 | 113 | if __name__ == "__main__": 114 | main() -------------------------------------------------------------------------------- /code/get_node2vec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import re 5 
| import igraph 6 | import numpy as np 7 | from subprocess import call 8 | from sklearn.decomposition import PCA 9 | 10 | import tempfile 11 | import shutil 12 | 13 | from multiprocessing import Pool, cpu_count 14 | from functools import partial 15 | import shelve 16 | import time as t 17 | import datetime 18 | 19 | # ============================================================================= 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | # positional arguments (required) 24 | parser.add_argument('path_node2vec', type=str, help='path to node2vec executable') 25 | parser.add_argument('path_read', type=str, help='path to adjacency matrices') 26 | parser.add_argument('path_write', type=str, help='path to folder where node2vec embeddings should be saved') 27 | parser.add_argument('path_stats', type=str, help='path to folder where statistics should be saved') 28 | parser.add_argument('dataset', type=str, help='name of the dataset. Must correspond to a valid value that matches an adjacency matrix folder') 29 | parser.add_argument('p', type=str, help='p parameter of node2vec') 30 | parser.add_argument('q', type=str, help='q parameter of node2vec') 31 | 32 | # optional arguments 33 | parser.add_argument('--max_n_channels', type=int, default=5, help='maximum number of channels that we will be able to pass to the network') 34 | 35 | args = parser.parse_args() 36 | 37 | # convert command line arguments 38 | path_node2vec = args.path_node2vec 39 | path_read = args.path_read 40 | path_write = args.path_write 41 | path_stats = args.path_stats 42 | dataset = args.dataset 43 | p = args.p 44 | q = args.q 45 | max_n_channels = args.max_n_channels 46 | 47 | # command line example: python get_node2vec.py /home/antoine/Desktop/snap-master/examples/node2vec/ /home/antoine/Desktop/share_ubuntu/datasets/data_as_adj/ /home/antoine/Desktop/graph_2D_CNN/datasets/raw_node2vec/ /home/antoine/Desktop/graph_2D_CNN/datasets/stats/ imdb_action_romance 1 1 48 | 49 | # 
============================================================================= 50 | 51 | def atoi(text): # digit-only strings become ints so natural_keys compares them numerically 52 | return int(text) if text.isdigit() else text 53 | 54 | def natural_keys(text): # natural-sort key, e.g. 'graph_2' sorts before 'graph_10' 55 | return [atoi(c) for c in re.split('(\d+)', text)] # NOTE(review): prefer the raw string r'(\d+)' to avoid invalid-escape warnings on Python 3 56 | 57 | def get_embeddings_node2vec(g,d,p,q,path_node2vec): # Run the node2vec executable on g's edge list inside a temporary directory, read the embeddings back sorted by node index, and return them reduced to d PCA components. 58 | my_pca = PCA(n_components=d) 59 | my_edgelist = igraph.Graph.get_edgelist(g) 60 | # create temp dir to write and read from 61 | tmpdir = tempfile.mkdtemp() # NOTE(review): only removed by the shutil.rmtree call below; if node2vec or np.loadtxt fails first, the directory is leaked (wrap in try/finally) 62 | # create subdirs for node2vec 63 | os.makedirs(tmpdir + '/graph/') 64 | os.makedirs(tmpdir + '/emb/') 65 | # write edge list 66 | with open(tmpdir + '/graph/input.edgelist', 'w') as my_file: 67 | my_file.write('\n'.join('%s %s' % x for x in my_edgelist)) 68 | # execute node2vec 69 | call([path_node2vec + 'node2vec -i:' + tmpdir + '/graph/input.edgelist' + ' -o:' + tmpdir + '/emb/output.emb' + ' -p:' + p + ' -q:' + q],shell=True) # NOTE(review): shell=True with a concatenated string breaks on paths containing spaces and lets p/q inject shell syntax; the exit status is ignored, so a failed run presumably surfaces only as np.loadtxt failing on a missing output file 70 | # read back results 71 | emb = np.loadtxt(tmpdir + '/emb/output.emb',skiprows=1) # skiprows=1 skips the header line (node count, dimension) of the .emb output 72 | # sort by increasing node index and keep only coordinates 73 | emb = emb[emb[:,0].argsort(),1:] # NOTE(review): nodes that never appear in the edge list presumably get no row, so emb can have fewer rows than g has vertices 74 | # remove temp dir 75 | shutil.rmtree(tmpdir) 76 | # perform PCA on the embeddings to align and reduce dim 77 | pca_output = my_pca.fit_transform(emb) # NOTE(review): if emb has fewer than d rows, PCA(n_components=d) may raise or return fewer than d columns depending on the scikit-learn version 78 | return pca_output 79 | 80 | def to_parallelize(file_name,p,q,dataset,path_read,path_write): # Pool worker: embed one adjacency-matrix file, save the embedding with np.save, and return [n_nodes, n_edges, excluded, excluded_exc] for the stats file. 81 | excluded = '' 82 | excluded_exc = '' 83 | 84 | idx = file_name.split('.txt')[0].split('_')[-1:][0] # graph index: the text after the last '_' in the file name, without '.txt' 85 | 86 | adj_mat = np.loadtxt(path_read + dataset + '/' + file_name) 87 | g = igraph.Graph.Adjacency(adj_mat.tolist(),mode='UNDIRECTED') 88 | if len(g.vs)<(max_n_channels*2): # flag graphs with fewer nodes than the required min number of dims. NOTE(review): the file is only recorded here; the embedding below is still computed and saved, so confirm whether an early return was intended 89 | excluded = file_name 90 | try: 91 | emb = get_embeddings_node2vec(g,d=max(20,max_n_channels*2),p=p,q=q,path_node2vec=path_node2vec) # relies on the module-level globals max_n_channels and path_node2vec 92 | np.save(path_write + dataset + '/' + dataset + '_node2vec_raw_p=' + p + '_q=' + q + '_' + idx, emb, allow_pickle=False) 93 | except Exception, e: # any failure (node2vec, file I/O, PCA) is printed and the file recorded in excluded_exc 94 | print e 95 |
excluded_exc = file_name 96 | 97 | return [len(g.vs),len(g.es),excluded,excluded_exc] 98 | 99 | # ============================================================================= 100 | 101 | def main(): # Embed every adjacency-matrix file of the dataset in parallel and write the per-graph stats returned by to_parallelize to path_stats. 102 | my_date_time = '_'.join(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S").split()) 103 | 104 | file_names = os.listdir(path_read + dataset + '/') 105 | file_names.sort(key=natural_keys) 106 | print '===== number of graphs: =====', len(file_names) 107 | print '*** head ***' 108 | print file_names[:5] 109 | print '*** tail ***' 110 | print file_names[-5:] 111 | 112 | # map 'to_parallelize' over all files 113 | to_parallelize_partial = partial(to_parallelize,p=p,q=q,dataset=dataset,path_read=path_read,path_write=path_write) 114 | 115 | n_jobs = cpu_count() 116 | 117 | print 'using', n_jobs, 'cores' 118 | t_start = t.time() 119 | 120 | pool = Pool(processes=n_jobs) 121 | lol = pool.map(to_parallelize_partial, file_names) 122 | pool.close() 123 | 124 | print 'type', type(lol) 125 | print 'len', len(lol) 126 | print 'len lol[0]', len(lol[0]) # NOTE(review): raises IndexError when the dataset folder is empty 127 | print lol[0] 128 | 129 | stats_array = np.array(lol) # mixed ints and strings give a string array, hence fmt='%s' below 130 | print 'shape', stats_array.shape 131 | 132 | np.savetxt(path_stats + dataset + '/' + dataset + '_' + my_date_time + '.txt', stats_array, fmt='%s') 133 | 134 | print 'done in ', round(t.time() - t_start,4) 135 | 136 | if __name__ == "__main__": 137 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph classification with 2D CNNs ![GitHub stars](https://img.shields.io/github/stars/tixierae/graph_2D_CNN.svg?style=plastic) ![GitHub forks](https://img.shields.io/github/forks/tixierae/graph_2D_CNN.svg?color=blue&style=plastic) 2 | 3 |
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-classification-with-2d-convolutional/graph-classification-on-collab)](https://paperswithcode.com/sota/graph-classification-on-collab?p=graph-classification-with-2d-convolutional) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-classification-with-2d-convolutional/graph-classification-on-re-m12k)](https://paperswithcode.com/sota/graph-classification-on-re-m12k?p=graph-classification-with-2d-convolutional) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-classification-with-2d-convolutional/graph-classification-on-re-m5k)](https://paperswithcode.com/sota/graph-classification-on-re-m5k?p=graph-classification-with-2d-convolutional) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/graph-classification-with-2d-convolutional/graph-classification-on-imdb-b)](https://paperswithcode.com/sota/graph-classification-on-imdb-b?p=graph-classification-with-2d-convolutional) 4 | 5 | ### What is this repo for? 6 | This repo provides the code and datasets used in the paper [Classifying graphs as images with Convolutional Neural Networks](https://arxiv.org/abs/1708.02218) (Tixier, Nikolentzos, Meladianos and Vazirgiannis, 2017). Note that the paper was published at the ICANN 2019 conference under the title *Graph classification with 2D convolutional neural networks*. As its name suggests, the paper introduces a technique to perform graph classification with standard Convolutional Neural Networks for images (2D CNNs). 7 | 8 | ### Idea 9 | We encode graphs as stacks of 2D histograms of their node embeddings, and pass them to a classical 2D CNN architecture designed for images. The *bins* of the histograms can be viewed as *pixels*, and the value of a given pixel is the number of nodes falling into the associated bin. 
10 | 11 | For instance, below are the node embeddings and corresponding bivariate histograms for graph ID #10001 (577 nodes, 1320 edges) of the REDDIT-12K dataset: 12 | ![alt text](https://github.com/Tixierae/graph_2D_CNN/raw/master/image_example_graph_cnn_github.png) 13 | The full image representation of a graph is given by stacking its n_channels bivariate histograms (where n_channels can be 2, 5, ...). Each pixel is thus associated with an n_channels-dimensional vector of counts. 14 | 15 | ### Results 16 | Despite its simplicity, our method proves very competitive with state-of-the-art graph kernels, and even outperforms them by a wide margin on some datasets. 17 | 18 | Average 10-fold CV test set classification accuracy of state-of-the-art graph kernel and graph CNN baselines (top) vs. our 2D CNN approach (bottom): 19 | ![alt text](https://github.com/Tixierae/graph_2D_CNN/raw/master/results_graph_cnn_github.png) 20 | 21 | The results reported in the paper (without data augmentation) are available in the `/datasets/results/` subdirectory, with slight variations due to the stochasticity of the approach. You can read them using the `read_results.py` script.
22 | 23 | ### Advantages over graph kernels + SVM (GK+SVM) 24 | We can summarize the advantages of our approach as follows: 25 | * **better accuracy**: CNNs learn their own features directly from the raw data during training to optimize performance on the downstream task (whereas GKs compute similarity *a priori*) 26 | * **better accuracy**: we compute images of graphs from their node embeddings (obtained via node2vec), so we capture both *local* and *global* information about the networks (whereas most GKs, based on substructures, capture only local information) 27 | * **reduced time complexity at the graph level**: node2vec is linear in the number of nodes (whereas most GKs are polynomial) -> we can process bigger graphs 28 | * **reduced time complexity at the collection level**: the time required to process a graph with a 2D CNN is constant (all images have the same dimensions for a given dataset), and the time required to go through the entire dataset with a 2D CNN grows linearly with the size of the dataset (whereas GKs take quadratic time to compute the kernel matrix, and finding the support vectors is again quadratic) -> we can process bigger datasets 29 | 30 | 31 | ### Use 32 | * `get_node2vec.py` computes the node2vec embeddings of the graphs from their adjacency matrices (parallelized over graphs) 33 | * `get_histograms.py` computes the image representations of the graphs (stacks of 2D histograms) from their node2vec embeddings (parallelized over graphs) 34 | * `main.py` reproduces the experiments in the paper (classification of graphs as images with a 2D CNN architecture, using a 10-fold cross-validation scheme) 35 | * `main_data_augmentation.py` is like `main.py`, but it implements the data augmentation scheme described in the paper (smoothed bootstrap) 36 | 37 | Command line examples and descriptions of the parameters are available within each script.
38 | 39 | ### Setup 40 | Code was developed and tested under Ubuntu 16.04.2 LTS 64-bit operating system and Python 2.7 with [Keras 1.2.2](https://faroit.github.io/keras-docs/1.2.2/) and tensorflow 1.1.0 backend. 41 | 42 | ### Other notable dependencies 43 | * igraph 0.7.1 44 | * scikit-learn 0.18.1 45 | * numpy 1.11.0 46 | * multiprocessing 47 | * functools 48 | * json 49 | * argparse 50 | 51 | ### Correspondence between names of datasets in the paper and in the code (paper -> code) 52 | * IMDB-B -> imdb_action_romance 53 | * COLLAB -> collab 54 | * REDDIT-B -> reddit_iama_askreddit_atheism_trollx 55 | * REDDIT-5K -> reddit_multi_5K 56 | * REDDIT-12K -> reddit_subreddit_10K 57 | 58 | ### Cite 59 | If you use some of the code in this repository in your work, please cite: 60 | 61 | Conference version (ICANN 2019): 62 | ````BibTeX 63 | @inproceedings{tixier2019graph, 64 | title={Graph classification with 2d convolutional neural networks}, 65 | author={Tixier, Antoine J-P and Nikolentzos, Giannis and Meladianos, Polykarpos and Vazirgiannis, Michalis}, 66 | booktitle={International Conference on Artificial Neural Networks}, 67 | pages={578--593}, 68 | year={2019}, 69 | organization={Springer} 70 | } 71 | ```` 72 | 73 | Pre-print version (2017): 74 | ````BibTeX 75 | @article{tixier2017classifying, 76 | title={Classifying Graphs as Images with Convolutional Neural Networks}, 77 | author={Tixier, Antoine Jean-Pierre and Nikolentzos, Giannis and Meladianos, Polykarpos and Vazirgiannis, Michalis}, 78 | journal={arXiv preprint arXiv:1708.02218}, 79 | year={2017} 80 | } 81 | ```` 82 | -------------------------------------------------------------------------------- /code/node2vec/emb/karate.emb: -------------------------------------------------------------------------------- 1 | 35 24 2 | 8 0.118947 0.0115566 0.17038 -0.0427312 -0.0582563 0.17704 -0.121751 -0.101363 0.0840584 0.0233382 -0.00495273 -0.066572 -0.0807031 -0.0511995 -0.160118 -0.0557651 -0.12064 0.0347513 0.0860857 
0.113402 -0.0668617 -0.0117763 -0.00167025 0.0414904 3 | 0 0.413358 -0.0253449 0.648117 -0.158958 -0.225951 0.642511 -0.558545 -0.378568 0.367838 0.0527376 0.0522228 -0.276052 -0.400306 -0.280719 -0.562293 -0.275254 -0.448003 0.161317 0.250854 0.360617 -0.175733 0.043839 0.010377 0.108536 4 | 24 0.0959565 -0.00748697 0.150077 -0.0460109 -0.0574141 0.164684 -0.145206 -0.0918497 0.0492995 0.026819 0.000909779 -0.0530157 -0.0844664 -0.0408044 -0.116677 -0.0510233 -0.115293 0.0339205 0.0384144 0.0967733 -0.0222449 0.00698876 0.0228543 0.0169967 5 | 26 0.117928 -0.0299779 0.287184 -0.0668332 -0.0975112 0.295627 -0.214978 -0.139006 0.110095 0.0362877 0.00930762 -0.110894 -0.16779 -0.106583 -0.201951 -0.0838555 -0.155375 0.0493395 0.0917201 0.11728 -0.0428179 0.0095603 0.0181627 0.049319 6 | 32 0.210492 -0.0269387 0.376019 -0.103568 -0.144015 0.400048 -0.309797 -0.233032 0.184148 0.0190237 0.0207759 -0.118901 -0.224329 -0.13172 -0.298942 -0.115841 -0.22863 0.0792035 0.150131 0.194566 -0.0764207 0.0199491 0.0259204 0.0460807 7 | 5 0.0726876 0.00878614 0.150215 -0.0239939 -0.0317104 0.147181 -0.103859 -0.0949495 0.0622109 0.018549 0.0118674 -0.0527971 -0.0878049 -0.0312997 -0.122763 -0.0262234 -0.084414 0.00925851 0.0486123 0.0881301 -0.0363007 0.0182295 -0.0117973 0.0054165 8 | 11 0.098142 0.00750466 0.182097 -0.0275925 -0.0593766 0.174662 -0.117622 -0.0988937 0.0899922 0.0185472 0.0019109 -0.0909265 -0.118452 -0.0526772 -0.151581 -0.0554015 -0.11983 0.0280502 0.0703484 0.0885612 -0.0524591 0.0191643 0.00370216 0.0377737 9 | 27 0.0669557 -0.00355869 0.144273 -0.0527809 -0.0579692 0.11876 -0.107888 -0.0815976 0.0434697 0.0107044 0.0247682 -0.0654676 -0.0854248 -0.0398348 -0.100165 -0.0649923 -0.0686936 0.0244514 0.0307845 0.073638 -0.0471696 -0.0099186 -0.00733778 0.0212839 10 | 34 0.409867 -0.0402608 0.63757 -0.178174 -0.215166 0.673893 -0.494027 -0.346898 0.366314 0.0812553 0.0653644 -0.255046 -0.389361 -0.243655 -0.578325 -0.224293 -0.421472 0.157386 0.26198 0.376569 
-0.220972 -0.0164087 0.0552008 0.183183 11 | 12 0.0427329 -0.011462 0.0780363 -0.0347101 -0.031618 0.0528913 -0.026396 -0.0498597 0.0448877 0.00515578 -0.0018151 -0.0449939 -0.0263193 -0.0153767 -0.0508971 -0.033773 -0.0549128 -0.000318146 0.0228869 0.0433418 -0.0339232 0.00323771 -0.014071 0.0182955 12 | 28 0.0797013 -0.0231618 0.114635 -0.015386 -0.0431869 0.114719 -0.105886 -0.0769633 0.0546555 -0.00143438 0.0267875 -0.0611881 -0.0623212 -0.03473 -0.116412 -0.0469874 -0.08971 0.0139599 0.054301 0.0555678 -0.0397026 -0.00700309 0.00724302 0.0285325 13 | 15 0.0635056 0.00332796 0.123964 -0.044004 -0.0471534 0.136754 -0.107769 -0.0568234 0.0293294 0.0194793 0.00242147 -0.0427178 -0.0523179 -0.0233603 -0.0952407 -0.0305979 -0.0623338 0.0435218 0.0319879 0.0548051 -0.0196982 0.0109058 -0.00393601 0.030302 14 | 33 0.245962 -0.0510136 0.513719 -0.12997 -0.210704 0.482909 -0.45752 -0.299971 0.241251 0.0274585 0.0373239 -0.170976 -0.25718 -0.196578 -0.363138 -0.188263 -0.320821 0.136559 0.170129 0.270088 -0.0862947 0.00736999 0.0345257 0.0548533 15 | 14 0.112286 -0.012721 0.191944 -0.0379584 -0.0859052 0.220621 -0.147213 -0.106616 0.0824577 0.025531 0.00672393 -0.0749347 -0.0915777 -0.0919624 -0.132947 -0.050148 -0.114429 0.0271393 0.0735144 0.0939301 -0.0560501 0.0229004 0.0137695 0.0149704 16 | 23 0.0280407 0.00875203 0.065563 -0.00528527 -0.0373522 0.0887338 -0.0692577 -0.0293119 0.0499992 0.00775846 -0.01408 -0.0164848 -0.053258 -0.0181916 -0.0472385 -0.0387351 -0.0385936 0.0368352 0.0158394 0.061768 -0.0168318 -0.00437657 0.0005417 0.0216465 17 | 9 0.0716663 -0.000743907 0.156218 -0.018863 -0.0483074 0.141727 -0.0971186 -0.0649113 0.0568612 0.0230879 0.0167915 -0.0691977 -0.0693805 -0.0699063 -0.0919994 -0.0362833 -0.0821156 0.0504948 0.0278097 0.0694654 -0.021136 -0.00425107 0.00314027 0.0438963 18 | 19 0.019699 0.00534746 0.0884439 -0.00774919 -0.0245967 0.086731 -0.0479487 -0.0334427 0.0258971 -0.00226271 0.02613 -0.0320537 -0.0433144 -0.0337618 -0.0758679 
-0.0278388 -0.0357567 0.0218924 0.0413929 0.0280184 -0.0270589 -0.0125202 -0.00263388 -0.0088657 19 | 21 0.0359935 0.00515828 0.0843798 -0.0319058 -0.0345594 0.0709504 -0.0843852 -0.0353921 0.0449449 0.00628934 -0.00567762 -0.0347139 -0.0530895 -0.0331078 -0.081639 -0.0484632 -0.056396 0.0254517 0.0454157 0.0571912 -0.0214015 0.0195467 -0.0111147 0.00340027 20 | 7 0.108468 -0.0235252 0.190167 -0.0428432 -0.0702886 0.204791 -0.1574 -0.112999 0.0690458 0.0180378 0.00561076 -0.0535201 -0.120963 -0.06034 -0.141818 -0.0754123 -0.11501 0.0449745 0.0604494 0.116347 -0.0390616 -0.0122457 0.0290511 0.0315622 21 | 17 0.130792 -0.0169222 0.235971 -0.0474316 -0.0850914 0.224267 -0.176437 -0.152704 0.13194 0.04518 0.00382862 -0.0984689 -0.129833 -0.09495 -0.229675 -0.090041 -0.164439 0.0372863 0.109188 0.148621 -0.0885197 0.00986087 0.0285113 0.0711025 22 | 30 0.0713547 -0.0195912 0.17375 -0.0478923 -0.0777311 0.176459 -0.142162 -0.0926034 0.055808 -0.00373828 0.0190742 -0.0678709 -0.0830271 -0.0541142 -0.111199 -0.0506722 -0.103861 0.0243821 0.0517994 0.0967771 -0.0173839 0.0204399 -0.00682065 0.0178343 23 | 25 0.1012 -0.0112908 0.182377 -0.0554762 -0.0511169 0.182195 -0.135983 -0.0863102 0.0740075 0.0287203 0.0107861 -0.0816619 -0.100232 -0.0595143 -0.139955 -0.0550945 -0.086283 0.0228503 0.0513408 0.0784497 -0.0475498 0.00663531 0.0262551 0.0478823 24 | 31 0.0632546 -0.0115278 0.15983 -0.0328864 -0.0456679 0.145881 -0.114335 -0.0750976 0.0605611 -0.00207562 0.0211625 -0.0529826 -0.107869 -0.0618957 -0.127804 -0.0595611 -0.0882734 0.0491374 0.0670595 0.0967456 -0.0249219 0.0204529 0.0167798 0.0155552 25 | 29 0.0610553 -0.0236958 0.111635 -0.0122139 -0.0430489 0.0891527 -0.0730972 -0.0768495 0.0464722 0.00955864 0.0196678 -0.0414937 -0.0465708 -0.050145 -0.0927333 -0.0345689 -0.0774831 0.0355462 0.0163513 0.0719025 -0.015708 -0.013021 0.0221789 -0.000886422 26 | 13 0.0484411 0.0140812 0.0893033 -0.013338 -0.0300399 0.103117 -0.0529133 -0.05201 0.0611225 0.0148114 -0.0034379 
-0.0476821 -0.0540623 -0.0321806 -0.0956852 -0.0533025 -0.0620502 0.0224071 0.0304266 0.0572541 -0.0502562 -0.00658122 0.00123917 0.01237 27 | 20 0.0380009 0.0111501 0.071708 0.000827945 -0.0191356 0.0606677 -0.0690673 -0.0342003 0.0516233 0.00515997 0.00584095 -0.0338132 -0.021314 -0.0401679 -0.0779817 -0.0183202 -0.045158 0.0246461 0.039868 0.0387305 -0.0131368 0.00379219 0.0189722 0.00574684 28 | 22 0.0637214 0.0122848 0.0941975 -0.00442218 -0.0163503 0.112567 -0.0623572 -0.0402746 0.0581555 0.02138 -0.0138224 -0.02326 -0.0428745 -0.023093 -0.104924 -0.0424019 -0.0784373 0.0294764 0.0417407 0.0669454 -0.0159104 -0.00203617 -0.013239 0.029833 29 | 18 0.0223935 0.0149774 0.0590671 0.00336113 -0.0107314 0.0688089 -0.0509246 -0.0307124 0.0150091 0.0277257 0.00233884 -0.037372 -0.0271454 -0.0209123 -0.0402169 -0.0124461 -0.0610691 0.0274362 0.0390003 0.0398409 -0.0346251 -0.00948166 0.0055367 0.00747267 30 | 4 0.0896401 0.00557873 0.185318 -0.0528846 -0.0421975 0.175384 -0.1418 -0.101392 0.0677894 -0.00631507 0.0191637 -0.0708559 -0.0828436 -0.0492209 -0.112012 -0.070257 -0.120565 0.0468425 0.0511171 0.079687 -0.0386807 -0.0132886 0.0155174 0.048941 31 | 1 0.0820661 0.0044734 0.151223 -0.0264093 -0.0506276 0.160945 -0.133863 -0.107371 0.0737196 0.00327225 0.00425153 -0.0655424 -0.0772875 -0.0712148 -0.134787 -0.0554853 -0.105125 0.0571095 0.0774468 0.0697515 -0.0456046 -0.0101718 -0.0051909 0.0475099 32 | 16 0.0355144 -0.0155761 0.0612251 -0.0224216 -0.0128161 0.054373 -0.0630691 -0.0368117 0.0281261 0.0105568 -0.00639393 -0.0345967 -0.0315763 -0.0490873 -0.058171 -0.0238263 -0.0347844 0.0186218 0.0192487 0.0423311 -0.0138155 0.00154633 0.00320158 0.023479 33 | 3 0.0606583 -0.0261054 0.13175 -0.054149 -0.0612835 0.157766 -0.124911 -0.080694 0.0638854 0.0121434 0.00939311 -0.0568471 -0.0816554 -0.0585825 -0.0943096 -0.0643112 -0.0700524 0.0135982 0.0639154 0.0819225 -0.0350031 0.0224395 0.0126063 0.0259323 34 | 6 0.114629 0.00370078 0.152576 -0.0232581 -0.0671492 
0.163236 -0.116852 -0.0849857 0.0782815 0.0234211 -0.00481475 -0.0427268 -0.0958873 -0.0450822 -0.136027 -0.0430074 -0.11362 0.0526915 0.0728607 0.0801992 -0.0429163 0.00807461 -0.00560012 0.0518726 35 | 10 0.0345248 0.00684733 0.0860396 -0.0267371 -0.016227 0.0581562 -0.0632731 -0.0509194 0.0317745 0.0162948 0.00512926 -0.018467 -0.0274532 -0.0169355 -0.064275 -0.0277926 -0.0465277 0.0322793 0.0157672 0.0385453 -0.0346562 0.0187347 -0.010111 0.0132143 36 | 2 0.0888687 0.00864752 0.195538 -0.0431593 -0.0630316 0.207562 -0.140082 -0.0985806 0.0750078 0.0179878 -0.000249816 -0.0850315 -0.126998 -0.088028 -0.160173 -0.073099 -0.134317 0.0505976 0.0453499 0.0861437 -0.0512437 -0.000149248 -0.0103378 0.0406122 37 | -------------------------------------------------------------------------------- /datasets/classes/imdb_action_romance/imdb_action_romance_classes.txt: -------------------------------------------------------------------------------- 1 | -1 2 | -1 3 | -1 4 | -1 5 | -1 6 | -1 7 | -1 8 | -1 9 | -1 10 | -1 11 | -1 12 | -1 13 | -1 14 | -1 15 | -1 16 | -1 17 | -1 18 | -1 19 | -1 20 | -1 21 | -1 22 | -1 23 | -1 24 | -1 25 | -1 26 | -1 27 | -1 28 | -1 29 | -1 30 | -1 31 | -1 32 | -1 33 | -1 34 | -1 35 | -1 36 | -1 37 | -1 38 | -1 39 | -1 40 | -1 41 | -1 42 | -1 43 | -1 44 | -1 45 | -1 46 | -1 47 | -1 48 | -1 49 | -1 50 | -1 51 | -1 52 | -1 53 | -1 54 | -1 55 | -1 56 | -1 57 | -1 58 | -1 59 | -1 60 | -1 61 | -1 62 | -1 63 | -1 64 | -1 65 | -1 66 | -1 67 | -1 68 | -1 69 | -1 70 | -1 71 | -1 72 | -1 73 | -1 74 | -1 75 | -1 76 | -1 77 | -1 78 | -1 79 | -1 80 | -1 81 | -1 82 | -1 83 | -1 84 | -1 85 | -1 86 | -1 87 | -1 88 | -1 89 | -1 90 | -1 91 | -1 92 | -1 93 | -1 94 | -1 95 | -1 96 | -1 97 | -1 98 | -1 99 | -1 100 | -1 101 | -1 102 | -1 103 | -1 104 | -1 105 | -1 106 | -1 107 | -1 108 | -1 109 | -1 110 | -1 111 | -1 112 | -1 113 | -1 114 | -1 115 | -1 116 | -1 117 | -1 118 | -1 119 | -1 120 | -1 121 | -1 122 | -1 123 | -1 124 | -1 125 | -1 126 | -1 127 | -1 128 | -1 
129 | -1 130 | -1 131 | -1 132 | -1 133 | -1 134 | -1 135 | -1 136 | -1 137 | -1 138 | -1 139 | -1 140 | -1 141 | -1 142 | -1 143 | -1 144 | -1 145 | -1 146 | -1 147 | -1 148 | -1 149 | -1 150 | -1 151 | -1 152 | -1 153 | -1 154 | -1 155 | -1 156 | -1 157 | -1 158 | -1 159 | -1 160 | -1 161 | -1 162 | -1 163 | -1 164 | -1 165 | -1 166 | -1 167 | -1 168 | -1 169 | -1 170 | -1 171 | -1 172 | -1 173 | -1 174 | -1 175 | -1 176 | -1 177 | -1 178 | -1 179 | -1 180 | -1 181 | -1 182 | -1 183 | -1 184 | -1 185 | -1 186 | -1 187 | -1 188 | -1 189 | -1 190 | -1 191 | -1 192 | -1 193 | -1 194 | -1 195 | -1 196 | -1 197 | -1 198 | -1 199 | -1 200 | -1 201 | -1 202 | -1 203 | -1 204 | -1 205 | -1 206 | -1 207 | -1 208 | -1 209 | -1 210 | -1 211 | -1 212 | -1 213 | -1 214 | -1 215 | -1 216 | -1 217 | -1 218 | -1 219 | -1 220 | -1 221 | -1 222 | -1 223 | -1 224 | -1 225 | -1 226 | -1 227 | -1 228 | -1 229 | -1 230 | -1 231 | -1 232 | -1 233 | -1 234 | -1 235 | -1 236 | -1 237 | -1 238 | -1 239 | -1 240 | -1 241 | -1 242 | -1 243 | -1 244 | -1 245 | -1 246 | -1 247 | -1 248 | -1 249 | -1 250 | -1 251 | -1 252 | -1 253 | -1 254 | -1 255 | -1 256 | -1 257 | -1 258 | -1 259 | -1 260 | -1 261 | -1 262 | -1 263 | -1 264 | -1 265 | -1 266 | -1 267 | -1 268 | -1 269 | -1 270 | -1 271 | -1 272 | -1 273 | -1 274 | -1 275 | -1 276 | -1 277 | -1 278 | -1 279 | -1 280 | -1 281 | -1 282 | -1 283 | -1 284 | -1 285 | -1 286 | -1 287 | -1 288 | -1 289 | -1 290 | -1 291 | -1 292 | -1 293 | -1 294 | -1 295 | -1 296 | -1 297 | -1 298 | -1 299 | -1 300 | -1 301 | -1 302 | -1 303 | -1 304 | -1 305 | -1 306 | -1 307 | -1 308 | -1 309 | -1 310 | -1 311 | -1 312 | -1 313 | -1 314 | -1 315 | -1 316 | -1 317 | -1 318 | -1 319 | -1 320 | -1 321 | -1 322 | -1 323 | -1 324 | -1 325 | -1 326 | -1 327 | -1 328 | -1 329 | -1 330 | -1 331 | -1 332 | -1 333 | -1 334 | -1 335 | -1 336 | -1 337 | -1 338 | -1 339 | -1 340 | -1 341 | -1 342 | -1 343 | -1 344 | -1 345 | -1 346 | -1 347 | -1 348 | -1 349 | -1 350 | -1 
351 | -1 352 | -1 353 | -1 354 | -1 355 | -1 356 | -1 357 | -1 358 | -1 359 | -1 360 | -1 361 | -1 362 | -1 363 | -1 364 | -1 365 | -1 366 | -1 367 | -1 368 | -1 369 | -1 370 | -1 371 | -1 372 | -1 373 | -1 374 | -1 375 | -1 376 | -1 377 | -1 378 | -1 379 | -1 380 | -1 381 | -1 382 | -1 383 | -1 384 | -1 385 | -1 386 | -1 387 | -1 388 | -1 389 | -1 390 | -1 391 | -1 392 | -1 393 | -1 394 | -1 395 | -1 396 | -1 397 | -1 398 | -1 399 | -1 400 | -1 401 | -1 402 | -1 403 | -1 404 | -1 405 | -1 406 | -1 407 | -1 408 | -1 409 | -1 410 | -1 411 | -1 412 | -1 413 | -1 414 | -1 415 | -1 416 | -1 417 | -1 418 | -1 419 | -1 420 | -1 421 | -1 422 | -1 423 | -1 424 | -1 425 | -1 426 | -1 427 | -1 428 | -1 429 | -1 430 | -1 431 | -1 432 | -1 433 | -1 434 | -1 435 | -1 436 | -1 437 | -1 438 | -1 439 | -1 440 | -1 441 | -1 442 | -1 443 | -1 444 | -1 445 | -1 446 | -1 447 | -1 448 | -1 449 | -1 450 | -1 451 | -1 452 | -1 453 | -1 454 | -1 455 | -1 456 | -1 457 | -1 458 | -1 459 | -1 460 | -1 461 | -1 462 | -1 463 | -1 464 | -1 465 | -1 466 | -1 467 | -1 468 | -1 469 | -1 470 | -1 471 | -1 472 | -1 473 | -1 474 | -1 475 | -1 476 | -1 477 | -1 478 | -1 479 | -1 480 | -1 481 | -1 482 | -1 483 | -1 484 | -1 485 | -1 486 | -1 487 | -1 488 | -1 489 | -1 490 | -1 491 | -1 492 | -1 493 | -1 494 | -1 495 | -1 496 | -1 497 | -1 498 | -1 499 | -1 500 | -1 501 | 1 502 | 1 503 | 1 504 | 1 505 | 1 506 | 1 507 | 1 508 | 1 509 | 1 510 | 1 511 | 1 512 | 1 513 | 1 514 | 1 515 | 1 516 | 1 517 | 1 518 | 1 519 | 1 520 | 1 521 | 1 522 | 1 523 | 1 524 | 1 525 | 1 526 | 1 527 | 1 528 | 1 529 | 1 530 | 1 531 | 1 532 | 1 533 | 1 534 | 1 535 | 1 536 | 1 537 | 1 538 | 1 539 | 1 540 | 1 541 | 1 542 | 1 543 | 1 544 | 1 545 | 1 546 | 1 547 | 1 548 | 1 549 | 1 550 | 1 551 | 1 552 | 1 553 | 1 554 | 1 555 | 1 556 | 1 557 | 1 558 | 1 559 | 1 560 | 1 561 | 1 562 | 1 563 | 1 564 | 1 565 | 1 566 | 1 567 | 1 568 | 1 569 | 1 570 | 1 571 | 1 572 | 1 573 | 1 574 | 1 575 | 1 576 | 1 577 | 1 578 | 1 579 | 1 580 | 1 581 | 1 
582 | 1 583 | 1 584 | 1 585 | 1 586 | 1 587 | 1 588 | 1 589 | 1 590 | 1 591 | 1 592 | 1 593 | 1 594 | 1 595 | 1 596 | 1 597 | 1 598 | 1 599 | 1 600 | 1 601 | 1 602 | 1 603 | 1 604 | 1 605 | 1 606 | 1 607 | 1 608 | 1 609 | 1 610 | 1 611 | 1 612 | 1 613 | 1 614 | 1 615 | 1 616 | 1 617 | 1 618 | 1 619 | 1 620 | 1 621 | 1 622 | 1 623 | 1 624 | 1 625 | 1 626 | 1 627 | 1 628 | 1 629 | 1 630 | 1 631 | 1 632 | 1 633 | 1 634 | 1 635 | 1 636 | 1 637 | 1 638 | 1 639 | 1 640 | 1 641 | 1 642 | 1 643 | 1 644 | 1 645 | 1 646 | 1 647 | 1 648 | 1 649 | 1 650 | 1 651 | 1 652 | 1 653 | 1 654 | 1 655 | 1 656 | 1 657 | 1 658 | 1 659 | 1 660 | 1 661 | 1 662 | 1 663 | 1 664 | 1 665 | 1 666 | 1 667 | 1 668 | 1 669 | 1 670 | 1 671 | 1 672 | 1 673 | 1 674 | 1 675 | 1 676 | 1 677 | 1 678 | 1 679 | 1 680 | 1 681 | 1 682 | 1 683 | 1 684 | 1 685 | 1 686 | 1 687 | 1 688 | 1 689 | 1 690 | 1 691 | 1 692 | 1 693 | 1 694 | 1 695 | 1 696 | 1 697 | 1 698 | 1 699 | 1 700 | 1 701 | 1 702 | 1 703 | 1 704 | 1 705 | 1 706 | 1 707 | 1 708 | 1 709 | 1 710 | 1 711 | 1 712 | 1 713 | 1 714 | 1 715 | 1 716 | 1 717 | 1 718 | 1 719 | 1 720 | 1 721 | 1 722 | 1 723 | 1 724 | 1 725 | 1 726 | 1 727 | 1 728 | 1 729 | 1 730 | 1 731 | 1 732 | 1 733 | 1 734 | 1 735 | 1 736 | 1 737 | 1 738 | 1 739 | 1 740 | 1 741 | 1 742 | 1 743 | 1 744 | 1 745 | 1 746 | 1 747 | 1 748 | 1 749 | 1 750 | 1 751 | 1 752 | 1 753 | 1 754 | 1 755 | 1 756 | 1 757 | 1 758 | 1 759 | 1 760 | 1 761 | 1 762 | 1 763 | 1 764 | 1 765 | 1 766 | 1 767 | 1 768 | 1 769 | 1 770 | 1 771 | 1 772 | 1 773 | 1 774 | 1 775 | 1 776 | 1 777 | 1 778 | 1 779 | 1 780 | 1 781 | 1 782 | 1 783 | 1 784 | 1 785 | 1 786 | 1 787 | 1 788 | 1 789 | 1 790 | 1 791 | 1 792 | 1 793 | 1 794 | 1 795 | 1 796 | 1 797 | 1 798 | 1 799 | 1 800 | 1 801 | 1 802 | 1 803 | 1 804 | 1 805 | 1 806 | 1 807 | 1 808 | 1 809 | 1 810 | 1 811 | 1 812 | 1 813 | 1 814 | 1 815 | 1 816 | 1 817 | 1 818 | 1 819 | 1 820 | 1 821 | 1 822 | 1 823 | 1 824 | 1 825 | 1 826 | 1 827 | 1 828 | 1 829 | 1 830 | 1 831 | 1 
832 | 1 833 | 1 834 | 1 835 | 1 836 | 1 837 | 1 838 | 1 839 | 1 840 | 1 841 | 1 842 | 1 843 | 1 844 | 1 845 | 1 846 | 1 847 | 1 848 | 1 849 | 1 850 | 1 851 | 1 852 | 1 853 | 1 854 | 1 855 | 1 856 | 1 857 | 1 858 | 1 859 | 1 860 | 1 861 | 1 862 | 1 863 | 1 864 | 1 865 | 1 866 | 1 867 | 1 868 | 1 869 | 1 870 | 1 871 | 1 872 | 1 873 | 1 874 | 1 875 | 1 876 | 1 877 | 1 878 | 1 879 | 1 880 | 1 881 | 1 882 | 1 883 | 1 884 | 1 885 | 1 886 | 1 887 | 1 888 | 1 889 | 1 890 | 1 891 | 1 892 | 1 893 | 1 894 | 1 895 | 1 896 | 1 897 | 1 898 | 1 899 | 1 900 | 1 901 | 1 902 | 1 903 | 1 904 | 1 905 | 1 906 | 1 907 | 1 908 | 1 909 | 1 910 | 1 911 | 1 912 | 1 913 | 1 914 | 1 915 | 1 916 | 1 917 | 1 918 | 1 919 | 1 920 | 1 921 | 1 922 | 1 923 | 1 924 | 1 925 | 1 926 | 1 927 | 1 928 | 1 929 | 1 930 | 1 931 | 1 932 | 1 933 | 1 934 | 1 935 | 1 936 | 1 937 | 1 938 | 1 939 | 1 940 | 1 941 | 1 942 | 1 943 | 1 944 | 1 945 | 1 946 | 1 947 | 1 948 | 1 949 | 1 950 | 1 951 | 1 952 | 1 953 | 1 954 | 1 955 | 1 956 | 1 957 | 1 958 | 1 959 | 1 960 | 1 961 | 1 962 | 1 963 | 1 964 | 1 965 | 1 966 | 1 967 | 1 968 | 1 969 | 1 970 | 1 971 | 1 972 | 1 973 | 1 974 | 1 975 | 1 976 | 1 977 | 1 978 | 1 979 | 1 980 | 1 981 | 1 982 | 1 983 | 1 984 | 1 985 | 1 986 | 1 987 | 1 988 | 1 989 | 1 990 | 1 991 | 1 992 | 1 993 | 1 994 | 1 995 | 1 996 | 1 997 | 1 998 | 1 999 | 1 1000 | 1 1001 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import math 4 | import numpy as np 5 | import os 6 | import random 7 | import json 8 | import re 9 | import datetime 10 | import time 11 | 12 | from keras.layers import Dense, Dropout, Flatten, Input, Convolution2D, MaxPooling2D, Merge 13 | from keras.utils import np_utils 14 | from keras.models import Model 15 | from keras import backend as K 16 | 17 | from keras.callbacks import EarlyStopping 18 | 19 | # 
============================================================================= 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | # positional arguments (required) 24 | parser.add_argument('path_root', type=str, help="path to 'datasets' directory") 25 | parser.add_argument('dataset', type=str, help='name of the dataset. Must correspond to a valid value that matches names of files in tensors/*dataset*/node2vec_hist/ folder') 26 | parser.add_argument('p', type=str, help='p parameter of node2vec. Must correspond to a valid value that matches names of files in tensors/*dataset*/node2vec_hist/ folder') 27 | parser.add_argument('q', type=str, help='q parameter of node2vec. Must correspond to a valid value that matches names of files in tensors/*dataset*/node2vec_hist/ folder') 28 | parser.add_argument('definition', type=int, help='definition. E.g., 14 for 14:1. Must correspond to a valid value that matches names of files in tensors/*dataset*/node2vec_hist/ folder') 29 | parser.add_argument('n_channels', type=int, help='number of channels. 
Must not exceed half the depth of the tensors in tensors/*dataset*/node2vec_hist/ folder') 30 | 31 | # optional arguments 32 | parser.add_argument('--n_folds', type=int, default=10, choices=[2,3,4,5,6,7,8,9,10], help='number of folds for cross-validation') 33 | parser.add_argument('--n_repeats', type=int, default=3, choices=[1,2,3,4,5], help='number of times each fold should be repeated') 34 | parser.add_argument('--batch_size', type=int, default=32, choices=[32,64,128], help='batch size') 35 | parser.add_argument('--nb_epochs', type=int, default=50, help='maximum number of epochs') 36 | parser.add_argument('--patience', type=int, default=5, help='patience for early stopping strategy') 37 | parser.add_argument('--drop_rate', type=float, default=0.3, help='dropout rate') 38 | 39 | args = parser.parse_args() 40 | 41 | # convert command line arguments 42 | path_root = args.path_root 43 | dataset = args.dataset 44 | p = args.p 45 | q = args.q 46 | definition = args.definition 47 | n_channels = args.n_channels 48 | 49 | n_folds = args.n_folds 50 | n_repeats = args.n_repeats 51 | batch_size = args.batch_size 52 | nb_epochs = args.nb_epochs 53 | my_patience = args.patience 54 | drop_rate = args.drop_rate 55 | 56 | dim_ordering = 'th' # channels first 57 | my_optimizer = 'adam' 58 | 59 | # command line examples: python main.py /home/antoine/Desktop/graph_2D_CNN/datasets/ imdb_action_romance 1 1 14 5 60 | # python main.py /home/antoine/Desktop/graph_2D_CNN/datasets/ imdb_action_romance 1 1 14 5 --n_folds 10 --n_repeats 1 --nb_epochs 20 --patience 3 61 | 62 | # ============================================================================= 63 | 64 | def atoi(text): 65 | return int(text) if text.isdigit() else text 66 | 67 | def natural_keys(text): 68 | return [atoi(c) for c in re.split('(\d+)', text)] 69 | 70 | # ============================================================================= 71 | 72 | def main(): 73 | 74 | my_date_time = 
'_'.join(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S").split()) 75 | 76 | parameters = {'path_root':path_root, 77 | 'dataset':dataset, 78 | 'p':p, 79 | 'q':q, 80 | 'definition':definition, 81 | 'n_channels':n_channels, 82 | 'n_folds':n_folds, 83 | 'n_repeats':n_repeats, 84 | 'batch_size':batch_size, 85 | 'nb_epochs':nb_epochs, 86 | 'my_patience':my_patience, 87 | 'drop_rate':drop_rate, 88 | 'dim_ordering':dim_ordering, 89 | 'my_optimizer':my_optimizer 90 | } 91 | 92 | name_save = path_root + '/results/' + dataset + '_augmentation_' + my_date_time 93 | 94 | with open(name_save + '_parameters.json', 'w') as my_file: 95 | json.dump(parameters, my_file, sort_keys=True, indent=4) 96 | 97 | print '========== parameters defined and saved to disk ==========' 98 | 99 | regexp_p = re.compile('p=' + p) 100 | regexp_q = re.compile('q=' + q) 101 | 102 | print '========== loading labels ==========' 103 | 104 | with open(path_root + 'classes/' + dataset + '/' + dataset + '_classes.txt', 'r') as f: 105 | ys = f.read().splitlines() 106 | ys = [int(elt) for elt in ys] 107 | 108 | num_classes = len(list(set(ys))) 109 | 110 | print 'classes:', list(set(ys)) 111 | 112 | print 'converting to 0-based index' 113 | 114 | if 0 not in list(set(ys)): 115 | if -1 not in list(set(ys)): 116 | ys = [y-1 for y in ys] 117 | else: 118 | ys = [1 if y==1 else 0 for y in ys] 119 | 120 | print 'classes:', list(set(ys)) 121 | 122 | print '========== loading tensors ==========' 123 | 124 | path_read = path_root + 'tensors/' + dataset + '/node2vec_hist/' 125 | file_names = [elt for elt in os.listdir(path_read) if (str(definition)+':1' in elt and regexp_p.search(elt) and regexp_q.search(elt) and elt.count('p=')==1 and elt.count('q=')==1 and elt.split('_')[-1:][0][0].isdigit())] # make sure the right files are selected 126 | file_names.sort(key=natural_keys) 127 | print len(file_names) 128 | print file_names[:5] 129 | print file_names[-5:] 130 | 131 | print 'ensuring tensor-label matching' 132 | 
kept_idxs = [int(elt.split('_')[-1].split('.')[0]) for elt in file_names] 133 | print len(kept_idxs) 134 | print kept_idxs[:5] 135 | print kept_idxs[-5:] 136 | print 'removing', len(ys) - len(kept_idxs), 'labels' 137 | ys = [y for idx,y in enumerate(ys) if idx in kept_idxs] 138 | 139 | print len(file_names) == len(ys) 140 | 141 | print 'converting labels to array' 142 | ys = np.array(ys) 143 | 144 | print 'transforming integer labels into one-hot vectors' 145 | ys = np_utils.to_categorical(ys, num_classes) 146 | 147 | tensors = [] 148 | for name in file_names: 149 | tensor = np.load(path_read + name) 150 | tensors.append(tensor[:n_channels,:,:]) 151 | 152 | tensors = np.array(tensors) 153 | tensors = tensors.astype('float32') 154 | 155 | print 'tensors shape:', tensors.shape 156 | 157 | print '========== getting image dimensions ==========' 158 | 159 | img_rows, img_cols = int(tensors.shape[2]), int(tensors.shape[3]) 160 | input_shape = (int(tensors.shape[1]), img_rows, img_cols) 161 | 162 | print 'input shape:', input_shape 163 | 164 | print '========== shuffling data ==========' 165 | 166 | shuffled_idxs = random.sample(range(tensors.shape[0]), int(tensors.shape[0])) # sample w/o replct 167 | tensors = tensors[shuffled_idxs] 168 | ys = ys[shuffled_idxs] 169 | 170 | print '========== conducting', n_folds ,'fold cross validation =========='; print 'repeating each fold:', n_repeats, 'times' 171 | 172 | folds = np.array_split(tensors,n_folds,axis=0) 173 | 174 | print 'fold sizes:', [len(fold) for fold in folds] 175 | 176 | folds_labels = np.array_split(ys,n_folds,axis=0) 177 | 178 | outputs = [] 179 | histories = [] 180 | 181 | for i in range(n_folds): 182 | 183 | t = time.time() 184 | 185 | x_train = np.concatenate([fold for j,fold in enumerate(folds) if j!=i],axis=0) 186 | x_test = [fold for j,fold in enumerate(folds) if j==i] 187 | 188 | y_train = np.concatenate([y for j,y in enumerate(folds_labels) if j!=i],axis=0) 189 | y_test = [y for j,y in 
enumerate(folds_labels) if j==i] 190 | 191 | for repeating in range(n_repeats): 192 | 193 | print 'clearing Keras session' 194 | K.clear_session() 195 | 196 | my_input = Input(shape=input_shape, dtype='float32') 197 | 198 | conv_1 = Convolution2D(64, 199 | 3, 200 | 3, 201 | border_mode='valid', 202 | activation='relu', 203 | dim_ordering=dim_ordering 204 | )(my_input) 205 | 206 | pooled_conv_1 = MaxPooling2D(pool_size=(2,2), 207 | dim_ordering=dim_ordering 208 | )(conv_1) 209 | 210 | pooled_conv_1_dropped = Dropout(drop_rate)(pooled_conv_1) 211 | 212 | conv_11 = Convolution2D(96, 213 | 3, 214 | 3, 215 | border_mode='valid', 216 | activation='relu', 217 | dim_ordering=dim_ordering 218 | )(pooled_conv_1_dropped) 219 | 220 | pooled_conv_11 = MaxPooling2D(pool_size=(2,2), 221 | dim_ordering=dim_ordering 222 | )(conv_11) 223 | 224 | pooled_conv_11_dropped = Dropout(drop_rate)(pooled_conv_11) 225 | pooled_conv_11_dropped_flat = Flatten()(pooled_conv_11_dropped) 226 | 227 | conv_2 = Convolution2D(64, 228 | 4, 229 | 4, 230 | border_mode='valid', 231 | activation='relu', 232 | dim_ordering=dim_ordering 233 | )(my_input) 234 | 235 | pooled_conv_2 = MaxPooling2D(pool_size=(2,2),dim_ordering=dim_ordering)(conv_2) 236 | pooled_conv_2_dropped = Dropout(drop_rate)(pooled_conv_2) 237 | 238 | conv_22 = Convolution2D(96, 239 | 4, 240 | 4, 241 | border_mode='valid', 242 | activation='relu', 243 | dim_ordering=dim_ordering, 244 | )(pooled_conv_2_dropped) 245 | 246 | pooled_conv_22 = MaxPooling2D(pool_size=(2,2),dim_ordering=dim_ordering)(conv_22) 247 | pooled_conv_22_dropped = Dropout(drop_rate)(pooled_conv_22) 248 | pooled_conv_22_dropped_flat = Flatten()(pooled_conv_22_dropped) 249 | 250 | conv_3 = Convolution2D(64, 251 | 5, 252 | 5, 253 | border_mode='valid', 254 | activation='relu', 255 | dim_ordering=dim_ordering 256 | )(my_input) 257 | 258 | pooled_conv_3 = MaxPooling2D(pool_size=(2,2),dim_ordering=dim_ordering)(conv_3) 259 | pooled_conv_3_dropped = 
Dropout(drop_rate)(pooled_conv_3) 260 | 261 | conv_33 = Convolution2D(96, 262 | 5, 263 | 5, 264 | border_mode='valid', 265 | activation='relu', 266 | dim_ordering=dim_ordering 267 | )(pooled_conv_3_dropped) 268 | 269 | pooled_conv_33 = MaxPooling2D(pool_size=(2,2),dim_ordering=dim_ordering)(conv_33) 270 | pooled_conv_33_dropped = Dropout(drop_rate)(pooled_conv_33) 271 | pooled_conv_33_dropped_flat = Flatten()(pooled_conv_33_dropped) 272 | 273 | conv_4 = Convolution2D(64, 274 | 6, 275 | 6, 276 | border_mode='valid', 277 | activation='relu', 278 | dim_ordering=dim_ordering 279 | )(my_input) 280 | 281 | pooled_conv_4 = MaxPooling2D(pool_size=(2,2),dim_ordering=dim_ordering)(conv_4) 282 | pooled_conv_4_dropped = Dropout(drop_rate)(pooled_conv_4) 283 | 284 | conv_44 = Convolution2D(96, 285 | 6, 286 | 6, 287 | border_mode='valid', 288 | activation='relu', 289 | dim_ordering=dim_ordering 290 | )(pooled_conv_4_dropped) 291 | 292 | pooled_conv_44 = MaxPooling2D(pool_size=(2,2),dim_ordering=dim_ordering) (conv_44) 293 | pooled_conv_44_dropped = Dropout(drop_rate) (pooled_conv_44) 294 | pooled_conv_44_dropped_flat = Flatten()(pooled_conv_44_dropped) 295 | 296 | merge = Merge(mode='concat')([pooled_conv_11_dropped_flat, 297 | pooled_conv_22_dropped_flat, 298 | pooled_conv_33_dropped_flat, 299 | pooled_conv_44_dropped_flat]) 300 | 301 | merge_dropped = Dropout(drop_rate)(merge) 302 | 303 | dense = Dense(128, 304 | activation='relu' 305 | )(merge_dropped) 306 | 307 | dense_dropped = Dropout(drop_rate)(dense) 308 | 309 | prob = Dense(output_dim=num_classes, 310 | activation='softmax' 311 | )(dense_dropped) 312 | 313 | # instantiate model 314 | model = Model(my_input,prob) 315 | 316 | # configure model for training 317 | model.compile(loss='categorical_crossentropy', 318 | optimizer=my_optimizer, 319 | metrics=['accuracy']) 320 | 321 | print 'model compiled' 322 | 323 | early_stopping = EarlyStopping(monitor='val_acc', # go through epochs as long as acc on validation set increases 
324 | patience=my_patience, 325 | mode='max') 326 | 327 | history = model.fit(x_train, 328 | y_train, 329 | batch_size=batch_size, 330 | nb_epoch=nb_epochs, 331 | validation_data=(x_test, y_test), 332 | callbacks=[early_stopping]) 333 | 334 | # save [min loss,max acc] on test set 335 | max_acc = max(model.history.history['val_acc']) 336 | max_idx = model.history.history['val_acc'].index(max_acc) 337 | output = [model.history.history['val_loss'][max_idx],max_acc] 338 | outputs.append(output) 339 | 340 | # also save full history for sanity checking 341 | histories.append(model.history.history) 342 | 343 | print '**** fold', i+1 ,'done in ' + str(math.ceil(time.time() - t)) + ' second(s) ****' 344 | 345 | # save results to disk 346 | with open(name_save + '_results.json', 'w') as my_file: 347 | json.dump({'outputs':outputs,'histories':histories}, my_file, sort_keys=False, indent=4) 348 | 349 | print '========== results saved to disk ==========' 350 | 351 | if __name__ == "__main__": 352 | main() 353 | -------------------------------------------------------------------------------- /datasets/classes/reddit_iama_askreddit_atheism_trollx/reddit_iama_askreddit_atheism_trollx_classes.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 1 4 | 1 5 | 1 6 | 1 7 | 1 8 | 1 9 | 1 10 | 1 11 | 1 12 | 1 13 | 1 14 | 1 15 | 1 16 | 1 17 | 1 18 | 1 19 | 1 20 | 1 21 | 1 22 | 1 23 | 1 24 | 1 25 | 1 26 | 1 27 | 1 28 | 1 29 | 1 30 | 1 31 | 1 32 | 1 33 | 1 34 | 1 35 | 1 36 | 1 37 | 1 38 | 1 39 | 1 40 | 1 41 | 1 42 | 1 43 | 1 44 | 1 45 | 1 46 | 1 47 | 1 48 | 1 49 | 1 50 | 1 51 | 1 52 | 1 53 | 1 54 | 1 55 | 1 56 | 1 57 | 1 58 | 1 59 | 1 60 | 1 61 | 1 62 | 1 63 | 1 64 | 1 65 | 1 66 | 1 67 | 1 68 | 1 69 | 1 70 | 1 71 | 1 72 | 1 73 | 1 74 | 1 75 | 1 76 | 1 77 | 1 78 | 1 79 | 1 80 | 1 81 | 1 82 | 1 83 | 1 84 | 1 85 | 1 86 | 1 87 | 1 88 | 1 89 | 1 90 | 1 91 | 1 92 | 1 93 | 1 94 | 1 95 | 1 96 | 1 97 | 1 98 | 1 99 | 1 100 | 1 101 | 1 102 | 1 103 | 1 
104 | 1 105 | 1 106 | 1 107 | 1 108 | 1 109 | 1 110 | 1 111 | 1 112 | 1 113 | 1 114 | 1 115 | 1 116 | 1 117 | 1 118 | 1 119 | 1 120 | 1 121 | 1 122 | 1 123 | 1 124 | 1 125 | 1 126 | 1 127 | 1 128 | 1 129 | 1 130 | 1 131 | 1 132 | 1 133 | 1 134 | 1 135 | 1 136 | 1 137 | 1 138 | 1 139 | 1 140 | 1 141 | 1 142 | 1 143 | 1 144 | 1 145 | 1 146 | 1 147 | 1 148 | 1 149 | 1 150 | 1 151 | 1 152 | 1 153 | 1 154 | 1 155 | 1 156 | 1 157 | 1 158 | 1 159 | 1 160 | 1 161 | 1 162 | 1 163 | 1 164 | 1 165 | 1 166 | 1 167 | 1 168 | 1 169 | 1 170 | 1 171 | 1 172 | 1 173 | 1 174 | 1 175 | 1 176 | 1 177 | 1 178 | 1 179 | 1 180 | 1 181 | 1 182 | 1 183 | 1 184 | 1 185 | 1 186 | 1 187 | 1 188 | 1 189 | 1 190 | 1 191 | 1 192 | 1 193 | 1 194 | 1 195 | 1 196 | 1 197 | 1 198 | 1 199 | 1 200 | 1 201 | 1 202 | 1 203 | 1 204 | 1 205 | 1 206 | 1 207 | 1 208 | 1 209 | 1 210 | 1 211 | 1 212 | 1 213 | 1 214 | 1 215 | 1 216 | 1 217 | 1 218 | 1 219 | 1 220 | 1 221 | 1 222 | 1 223 | 1 224 | 1 225 | 1 226 | 1 227 | 1 228 | 1 229 | 1 230 | 1 231 | 1 232 | 1 233 | 1 234 | 1 235 | 1 236 | 1 237 | 1 238 | 1 239 | 1 240 | 1 241 | 1 242 | 1 243 | 1 244 | 1 245 | 1 246 | 1 247 | 1 248 | 1 249 | 1 250 | 1 251 | 1 252 | 1 253 | 1 254 | 1 255 | 1 256 | 1 257 | 1 258 | 1 259 | 1 260 | 1 261 | 1 262 | 1 263 | 1 264 | 1 265 | 1 266 | 1 267 | 1 268 | 1 269 | 1 270 | 1 271 | 1 272 | 1 273 | 1 274 | 1 275 | 1 276 | 1 277 | 1 278 | 1 279 | 1 280 | 1 281 | 1 282 | 1 283 | 1 284 | 1 285 | 1 286 | 1 287 | 1 288 | 1 289 | 1 290 | 1 291 | 1 292 | 1 293 | 1 294 | 1 295 | 1 296 | 1 297 | 1 298 | 1 299 | 1 300 | 1 301 | 1 302 | 1 303 | 1 304 | 1 305 | 1 306 | 1 307 | 1 308 | 1 309 | 1 310 | 1 311 | 1 312 | 1 313 | 1 314 | 1 315 | 1 316 | 1 317 | 1 318 | 1 319 | 1 320 | 1 321 | 1 322 | 1 323 | 1 324 | 1 325 | 1 326 | 1 327 | 1 328 | 1 329 | 1 330 | 1 331 | 1 332 | 1 333 | 1 334 | 1 335 | 1 336 | 1 337 | 1 338 | 1 339 | 1 340 | 1 341 | 1 342 | 1 343 | 1 344 | 1 345 | 1 346 | 1 347 | 1 348 | 1 349 | 1 350 | 1 351 | 1 352 | 1 353 | 1 
354 | 1 355 | 1 356 | 1 357 | 1 358 | 1 359 | 1 360 | 1 361 | 1 362 | 1 363 | 1 364 | 1 365 | 1 366 | 1 367 | 1 368 | 1 369 | 1 370 | 1 371 | 1 372 | 1 373 | 1 374 | 1 375 | 1 376 | 1 377 | 1 378 | 1 379 | 1 380 | 1 381 | 1 382 | 1 383 | 1 384 | 1 385 | 1 386 | 1 387 | 1 388 | 1 389 | 1 390 | 1 391 | 1 392 | 1 393 | 1 394 | 1 395 | 1 396 | 1 397 | 1 398 | 1 399 | 1 400 | 1 401 | 1 402 | 1 403 | 1 404 | 1 405 | 1 406 | 1 407 | 1 408 | 1 409 | 1 410 | 1 411 | 1 412 | 1 413 | 1 414 | 1 415 | 1 416 | 1 417 | 1 418 | 1 419 | 1 420 | 1 421 | 1 422 | 1 423 | 1 424 | 1 425 | 1 426 | 1 427 | 1 428 | 1 429 | 1 430 | 1 431 | 1 432 | 1 433 | 1 434 | 1 435 | 1 436 | 1 437 | 1 438 | 1 439 | 1 440 | 1 441 | 1 442 | 1 443 | 1 444 | 1 445 | 1 446 | 1 447 | 1 448 | 1 449 | 1 450 | 1 451 | 1 452 | 1 453 | 1 454 | 1 455 | 1 456 | 1 457 | 1 458 | 1 459 | 1 460 | 1 461 | 1 462 | 1 463 | 1 464 | 1 465 | 1 466 | 1 467 | 1 468 | 1 469 | 1 470 | 1 471 | 1 472 | 1 473 | 1 474 | 1 475 | 1 476 | 1 477 | 1 478 | 1 479 | 1 480 | 1 481 | 1 482 | 1 483 | 1 484 | 1 485 | 1 486 | 1 487 | 1 488 | 1 489 | 1 490 | 1 491 | 1 492 | 1 493 | 1 494 | 1 495 | 1 496 | 1 497 | 1 498 | 1 499 | 1 500 | 1 501 | -1 502 | -1 503 | -1 504 | -1 505 | -1 506 | -1 507 | -1 508 | -1 509 | -1 510 | -1 511 | -1 512 | -1 513 | -1 514 | -1 515 | -1 516 | -1 517 | -1 518 | -1 519 | -1 520 | -1 521 | -1 522 | -1 523 | -1 524 | -1 525 | -1 526 | -1 527 | -1 528 | -1 529 | -1 530 | -1 531 | -1 532 | -1 533 | -1 534 | -1 535 | -1 536 | -1 537 | -1 538 | -1 539 | -1 540 | -1 541 | -1 542 | -1 543 | -1 544 | -1 545 | -1 546 | -1 547 | -1 548 | -1 549 | -1 550 | -1 551 | -1 552 | -1 553 | -1 554 | -1 555 | -1 556 | -1 557 | -1 558 | -1 559 | -1 560 | -1 561 | -1 562 | -1 563 | -1 564 | -1 565 | -1 566 | -1 567 | -1 568 | -1 569 | -1 570 | -1 571 | -1 572 | -1 573 | -1 574 | -1 575 | -1 576 | -1 577 | -1 578 | -1 579 | -1 580 | -1 581 | -1 582 | -1 583 | -1 584 | -1 585 | -1 586 | -1 587 | -1 588 | -1 589 | -1 590 | -1 591 | -1 592 
| -1 593 | -1 594 | -1 595 | -1 596 | -1 597 | -1 598 | -1 599 | -1 600 | -1 601 | -1 602 | -1 603 | -1 604 | -1 605 | -1 606 | -1 607 | -1 608 | -1 609 | -1 610 | -1 611 | -1 612 | -1 613 | -1 614 | -1 615 | -1 616 | -1 617 | -1 618 | -1 619 | -1 620 | -1 621 | -1 622 | -1 623 | -1 624 | -1 625 | -1 626 | -1 627 | -1 628 | -1 629 | -1 630 | -1 631 | -1 632 | -1 633 | -1 634 | -1 635 | -1 636 | -1 637 | -1 638 | -1 639 | -1 640 | -1 641 | -1 642 | -1 643 | -1 644 | -1 645 | -1 646 | -1 647 | -1 648 | -1 649 | -1 650 | -1 651 | -1 652 | -1 653 | -1 654 | -1 655 | -1 656 | -1 657 | -1 658 | -1 659 | -1 660 | -1 661 | -1 662 | -1 663 | -1 664 | -1 665 | -1 666 | -1 667 | -1 668 | -1 669 | -1 670 | -1 671 | -1 672 | -1 673 | -1 674 | -1 675 | -1 676 | -1 677 | -1 678 | -1 679 | -1 680 | -1 681 | -1 682 | -1 683 | -1 684 | -1 685 | -1 686 | -1 687 | -1 688 | -1 689 | -1 690 | -1 691 | -1 692 | -1 693 | -1 694 | -1 695 | -1 696 | -1 697 | -1 698 | -1 699 | -1 700 | -1 701 | -1 702 | -1 703 | -1 704 | -1 705 | -1 706 | -1 707 | -1 708 | -1 709 | -1 710 | -1 711 | -1 712 | -1 713 | -1 714 | -1 715 | -1 716 | -1 717 | -1 718 | -1 719 | -1 720 | -1 721 | -1 722 | -1 723 | -1 724 | -1 725 | -1 726 | -1 727 | -1 728 | -1 729 | -1 730 | -1 731 | -1 732 | -1 733 | -1 734 | -1 735 | -1 736 | -1 737 | -1 738 | -1 739 | -1 740 | -1 741 | -1 742 | -1 743 | -1 744 | -1 745 | -1 746 | -1 747 | -1 748 | -1 749 | -1 750 | -1 751 | -1 752 | -1 753 | -1 754 | -1 755 | -1 756 | -1 757 | -1 758 | -1 759 | -1 760 | -1 761 | -1 762 | -1 763 | -1 764 | -1 765 | -1 766 | -1 767 | -1 768 | -1 769 | -1 770 | -1 771 | -1 772 | -1 773 | -1 774 | -1 775 | -1 776 | -1 777 | -1 778 | -1 779 | -1 780 | -1 781 | -1 782 | -1 783 | -1 784 | -1 785 | -1 786 | -1 787 | -1 788 | -1 789 | -1 790 | -1 791 | -1 792 | -1 793 | -1 794 | -1 795 | -1 796 | -1 797 | -1 798 | -1 799 | -1 800 | -1 801 | -1 802 | -1 803 | -1 804 | -1 805 | -1 806 | -1 807 | -1 808 | -1 809 | -1 810 | -1 811 | -1 812 | -1 813 | -1 814 | 
-1 815 | -1 816 | -1 817 | -1 818 | -1 819 | -1 820 | -1 821 | -1 822 | -1 823 | -1 824 | -1 825 | -1 826 | -1 827 | -1 828 | -1 829 | -1 830 | -1 831 | -1 832 | -1 833 | -1 834 | -1 835 | -1 836 | -1 837 | -1 838 | -1 839 | -1 840 | -1 841 | -1 842 | -1 843 | -1 844 | -1 845 | -1 846 | -1 847 | -1 848 | -1 849 | -1 850 | -1 851 | -1 852 | -1 853 | -1 854 | -1 855 | -1 856 | -1 857 | -1 858 | -1 859 | -1 860 | -1 861 | -1 862 | -1 863 | -1 864 | -1 865 | -1 866 | -1 867 | -1 868 | -1 869 | -1 870 | -1 871 | -1 872 | -1 873 | -1 874 | -1 875 | -1 876 | -1 877 | -1 878 | -1 879 | -1 880 | -1 881 | -1 882 | -1 883 | -1 884 | -1 885 | -1 886 | -1 887 | -1 888 | -1 889 | -1 890 | -1 891 | -1 892 | -1 893 | -1 894 | -1 895 | -1 896 | -1 897 | -1 898 | -1 899 | -1 900 | -1 901 | -1 902 | -1 903 | -1 904 | -1 905 | -1 906 | -1 907 | -1 908 | -1 909 | -1 910 | -1 911 | -1 912 | -1 913 | -1 914 | -1 915 | -1 916 | -1 917 | -1 918 | -1 919 | -1 920 | -1 921 | -1 922 | -1 923 | -1 924 | -1 925 | -1 926 | -1 927 | -1 928 | -1 929 | -1 930 | -1 931 | -1 932 | -1 933 | -1 934 | -1 935 | -1 936 | -1 937 | -1 938 | -1 939 | -1 940 | -1 941 | -1 942 | -1 943 | -1 944 | -1 945 | -1 946 | -1 947 | -1 948 | -1 949 | -1 950 | -1 951 | -1 952 | -1 953 | -1 954 | -1 955 | -1 956 | -1 957 | -1 958 | -1 959 | -1 960 | -1 961 | -1 962 | -1 963 | -1 964 | -1 965 | -1 966 | -1 967 | -1 968 | -1 969 | -1 970 | -1 971 | -1 972 | -1 973 | -1 974 | -1 975 | -1 976 | -1 977 | -1 978 | -1 979 | -1 980 | -1 981 | -1 982 | -1 983 | -1 984 | -1 985 | -1 986 | -1 987 | -1 988 | -1 989 | -1 990 | -1 991 | -1 992 | -1 993 | -1 994 | -1 995 | -1 996 | -1 997 | -1 998 | -1 999 | -1 1000 | -1 1001 | -1 1002 | -1 1003 | -1 1004 | -1 1005 | -1 1006 | -1 1007 | -1 1008 | -1 1009 | -1 1010 | -1 1011 | -1 1012 | -1 1013 | -1 1014 | -1 1015 | -1 1016 | -1 1017 | -1 1018 | -1 1019 | -1 1020 | -1 1021 | -1 1022 | -1 1023 | -1 1024 | -1 1025 | -1 1026 | -1 1027 | -1 1028 | -1 1029 | -1 1030 | -1 1031 | -1 1032 | -1 
1033 | -1 1034 | -1 1035 | -1 1036 | -1 1037 | -1 1038 | -1 1039 | -1 1040 | -1 1041 | -1 1042 | -1 1043 | -1 1044 | -1 1045 | -1 1046 | -1 1047 | -1 1048 | -1 1049 | -1 1050 | -1 1051 | -1 1052 | -1 1053 | -1 1054 | -1 1055 | -1 1056 | -1 1057 | -1 1058 | -1 1059 | -1 1060 | -1 1061 | -1 1062 | -1 1063 | -1 1064 | -1 1065 | -1 1066 | -1 1067 | -1 1068 | -1 1069 | -1 1070 | -1 1071 | -1 1072 | -1 1073 | -1 1074 | -1 1075 | -1 1076 | -1 1077 | -1 1078 | -1 1079 | -1 1080 | -1 1081 | -1 1082 | -1 1083 | -1 1084 | -1 1085 | -1 1086 | -1 1087 | -1 1088 | -1 1089 | -1 1090 | -1 1091 | -1 1092 | -1 1093 | -1 1094 | -1 1095 | -1 1096 | -1 1097 | -1 1098 | -1 1099 | -1 1100 | -1 1101 | -1 1102 | -1 1103 | -1 1104 | -1 1105 | -1 1106 | -1 1107 | -1 1108 | -1 1109 | -1 1110 | -1 1111 | -1 1112 | -1 1113 | -1 1114 | -1 1115 | -1 1116 | -1 1117 | -1 1118 | -1 1119 | -1 1120 | -1 1121 | -1 1122 | -1 1123 | -1 1124 | -1 1125 | -1 1126 | -1 1127 | -1 1128 | -1 1129 | -1 1130 | -1 1131 | -1 1132 | -1 1133 | -1 1134 | -1 1135 | -1 1136 | -1 1137 | -1 1138 | -1 1139 | -1 1140 | -1 1141 | -1 1142 | -1 1143 | -1 1144 | -1 1145 | -1 1146 | -1 1147 | -1 1148 | -1 1149 | -1 1150 | -1 1151 | -1 1152 | -1 1153 | -1 1154 | -1 1155 | -1 1156 | -1 1157 | -1 1158 | -1 1159 | -1 1160 | -1 1161 | -1 1162 | -1 1163 | -1 1164 | -1 1165 | -1 1166 | -1 1167 | -1 1168 | -1 1169 | -1 1170 | -1 1171 | -1 1172 | -1 1173 | -1 1174 | -1 1175 | -1 1176 | -1 1177 | -1 1178 | -1 1179 | -1 1180 | -1 1181 | -1 1182 | -1 1183 | -1 1184 | -1 1185 | -1 1186 | -1 1187 | -1 1188 | -1 1189 | -1 1190 | -1 1191 | -1 1192 | -1 1193 | -1 1194 | -1 1195 | -1 1196 | -1 1197 | -1 1198 | -1 1199 | -1 1200 | -1 1201 | -1 1202 | -1 1203 | -1 1204 | -1 1205 | -1 1206 | -1 1207 | -1 1208 | -1 1209 | -1 1210 | -1 1211 | -1 1212 | -1 1213 | -1 1214 | -1 1215 | -1 1216 | -1 1217 | -1 1218 | -1 1219 | -1 1220 | -1 1221 | -1 1222 | -1 1223 | -1 1224 | -1 1225 | -1 1226 | -1 1227 | -1 1228 | -1 1229 | -1 1230 | -1 1231 | -1 1232 | -1 
1233 | -1 1234 | -1 1235 | -1 1236 | -1 1237 | -1 1238 | -1 1239 | -1 1240 | -1 1241 | -1 1242 | -1 1243 | -1 1244 | -1 1245 | -1 1246 | -1 1247 | -1 1248 | -1 1249 | -1 1250 | -1 1251 | -1 1252 | -1 1253 | -1 1254 | -1 1255 | -1 1256 | -1 1257 | -1 1258 | -1 1259 | -1 1260 | -1 1261 | -1 1262 | -1 1263 | -1 1264 | -1 1265 | -1 1266 | -1 1267 | -1 1268 | -1 1269 | -1 1270 | -1 1271 | -1 1272 | -1 1273 | -1 1274 | -1 1275 | -1 1276 | -1 1277 | -1 1278 | -1 1279 | -1 1280 | -1 1281 | -1 1282 | -1 1283 | -1 1284 | -1 1285 | -1 1286 | -1 1287 | -1 1288 | -1 1289 | -1 1290 | -1 1291 | -1 1292 | -1 1293 | -1 1294 | -1 1295 | -1 1296 | -1 1297 | -1 1298 | -1 1299 | -1 1300 | -1 1301 | -1 1302 | -1 1303 | -1 1304 | -1 1305 | -1 1306 | -1 1307 | -1 1308 | -1 1309 | -1 1310 | -1 1311 | -1 1312 | -1 1313 | -1 1314 | -1 1315 | -1 1316 | -1 1317 | -1 1318 | -1 1319 | -1 1320 | -1 1321 | -1 1322 | -1 1323 | -1 1324 | -1 1325 | -1 1326 | -1 1327 | -1 1328 | -1 1329 | -1 1330 | -1 1331 | -1 1332 | -1 1333 | -1 1334 | -1 1335 | -1 1336 | -1 1337 | -1 1338 | -1 1339 | -1 1340 | -1 1341 | -1 1342 | -1 1343 | -1 1344 | -1 1345 | -1 1346 | -1 1347 | -1 1348 | -1 1349 | -1 1350 | -1 1351 | -1 1352 | -1 1353 | -1 1354 | -1 1355 | -1 1356 | -1 1357 | -1 1358 | -1 1359 | -1 1360 | -1 1361 | -1 1362 | -1 1363 | -1 1364 | -1 1365 | -1 1366 | -1 1367 | -1 1368 | -1 1369 | -1 1370 | -1 1371 | -1 1372 | -1 1373 | -1 1374 | -1 1375 | -1 1376 | -1 1377 | -1 1378 | -1 1379 | -1 1380 | -1 1381 | -1 1382 | -1 1383 | -1 1384 | -1 1385 | -1 1386 | -1 1387 | -1 1388 | -1 1389 | -1 1390 | -1 1391 | -1 1392 | -1 1393 | -1 1394 | -1 1395 | -1 1396 | -1 1397 | -1 1398 | -1 1399 | -1 1400 | -1 1401 | -1 1402 | -1 1403 | -1 1404 | -1 1405 | -1 1406 | -1 1407 | -1 1408 | -1 1409 | -1 1410 | -1 1411 | -1 1412 | -1 1413 | -1 1414 | -1 1415 | -1 1416 | -1 1417 | -1 1418 | -1 1419 | -1 1420 | -1 1421 | -1 1422 | -1 1423 | -1 1424 | -1 1425 | -1 1426 | -1 1427 | -1 1428 | -1 1429 | -1 1430 | -1 1431 | -1 1432 | -1 
1433 | -1 1434 | -1 1435 | -1 1436 | -1 1437 | -1 1438 | -1 1439 | -1 1440 | -1 1441 | -1 1442 | -1 1443 | -1 1444 | -1 1445 | -1 1446 | -1 1447 | -1 1448 | -1 1449 | -1 1450 | -1 1451 | -1 1452 | -1 1453 | -1 1454 | -1 1455 | -1 1456 | -1 1457 | -1 1458 | -1 1459 | -1 1460 | -1 1461 | -1 1462 | -1 1463 | -1 1464 | -1 1465 | -1 1466 | -1 1467 | -1 1468 | -1 1469 | -1 1470 | -1 1471 | -1 1472 | -1 1473 | -1 1474 | -1 1475 | -1 1476 | -1 1477 | -1 1478 | -1 1479 | -1 1480 | -1 1481 | -1 1482 | -1 1483 | -1 1484 | -1 1485 | -1 1486 | -1 1487 | -1 1488 | -1 1489 | -1 1490 | -1 1491 | -1 1492 | -1 1493 | -1 1494 | -1 1495 | -1 1496 | -1 1497 | -1 1498 | -1 1499 | -1 1500 | -1 1501 | 1 1502 | 1 1503 | 1 1504 | 1 1505 | 1 1506 | 1 1507 | 1 1508 | 1 1509 | 1 1510 | 1 1511 | 1 1512 | 1 1513 | 1 1514 | 1 1515 | 1 1516 | 1 1517 | 1 1518 | 1 1519 | 1 1520 | 1 1521 | 1 1522 | 1 1523 | 1 1524 | 1 1525 | 1 1526 | 1 1527 | 1 1528 | 1 1529 | 1 1530 | 1 1531 | 1 1532 | 1 1533 | 1 1534 | 1 1535 | 1 1536 | 1 1537 | 1 1538 | 1 1539 | 1 1540 | 1 1541 | 1 1542 | 1 1543 | 1 1544 | 1 1545 | 1 1546 | 1 1547 | 1 1548 | 1 1549 | 1 1550 | 1 1551 | 1 1552 | 1 1553 | 1 1554 | 1 1555 | 1 1556 | 1 1557 | 1 1558 | 1 1559 | 1 1560 | 1 1561 | 1 1562 | 1 1563 | 1 1564 | 1 1565 | 1 1566 | 1 1567 | 1 1568 | 1 1569 | 1 1570 | 1 1571 | 1 1572 | 1 1573 | 1 1574 | 1 1575 | 1 1576 | 1 1577 | 1 1578 | 1 1579 | 1 1580 | 1 1581 | 1 1582 | 1 1583 | 1 1584 | 1 1585 | 1 1586 | 1 1587 | 1 1588 | 1 1589 | 1 1590 | 1 1591 | 1 1592 | 1 1593 | 1 1594 | 1 1595 | 1 1596 | 1 1597 | 1 1598 | 1 1599 | 1 1600 | 1 1601 | 1 1602 | 1 1603 | 1 1604 | 1 1605 | 1 1606 | 1 1607 | 1 1608 | 1 1609 | 1 1610 | 1 1611 | 1 1612 | 1 1613 | 1 1614 | 1 1615 | 1 1616 | 1 1617 | 1 1618 | 1 1619 | 1 1620 | 1 1621 | 1 1622 | 1 1623 | 1 1624 | 1 1625 | 1 1626 | 1 1627 | 1 1628 | 1 1629 | 1 1630 | 1 1631 | 1 1632 | 1 1633 | 1 1634 | 1 1635 | 1 1636 | 1 1637 | 1 1638 | 1 1639 | 1 1640 | 1 1641 | 1 1642 | 1 1643 | 1 1644 | 1 1645 | 1 1646 | 1 1647 
| 1 1648 | 1 1649 | 1 1650 | 1 1651 | 1 1652 | 1 1653 | 1 1654 | 1 1655 | 1 1656 | 1 1657 | 1 1658 | 1 1659 | 1 1660 | 1 1661 | 1 1662 | 1 1663 | 1 1664 | 1 1665 | 1 1666 | 1 1667 | 1 1668 | 1 1669 | 1 1670 | 1 1671 | 1 1672 | 1 1673 | 1 1674 | 1 1675 | 1 1676 | 1 1677 | 1 1678 | 1 1679 | 1 1680 | 1 1681 | 1 1682 | 1 1683 | 1 1684 | 1 1685 | 1 1686 | 1 1687 | 1 1688 | 1 1689 | 1 1690 | 1 1691 | 1 1692 | 1 1693 | 1 1694 | 1 1695 | 1 1696 | 1 1697 | 1 1698 | 1 1699 | 1 1700 | 1 1701 | 1 1702 | 1 1703 | 1 1704 | 1 1705 | 1 1706 | 1 1707 | 1 1708 | 1 1709 | 1 1710 | 1 1711 | 1 1712 | 1 1713 | 1 1714 | 1 1715 | 1 1716 | 1 1717 | 1 1718 | 1 1719 | 1 1720 | 1 1721 | 1 1722 | 1 1723 | 1 1724 | 1 1725 | 1 1726 | 1 1727 | 1 1728 | 1 1729 | 1 1730 | 1 1731 | 1 1732 | 1 1733 | 1 1734 | 1 1735 | 1 1736 | 1 1737 | 1 1738 | 1 1739 | 1 1740 | 1 1741 | 1 1742 | 1 1743 | 1 1744 | 1 1745 | 1 1746 | 1 1747 | 1 1748 | 1 1749 | 1 1750 | 1 1751 | 1 1752 | 1 1753 | 1 1754 | 1 1755 | 1 1756 | 1 1757 | 1 1758 | 1 1759 | 1 1760 | 1 1761 | 1 1762 | 1 1763 | 1 1764 | 1 1765 | 1 1766 | 1 1767 | 1 1768 | 1 1769 | 1 1770 | 1 1771 | 1 1772 | 1 1773 | 1 1774 | 1 1775 | 1 1776 | 1 1777 | 1 1778 | 1 1779 | 1 1780 | 1 1781 | 1 1782 | 1 1783 | 1 1784 | 1 1785 | 1 1786 | 1 1787 | 1 1788 | 1 1789 | 1 1790 | 1 1791 | 1 1792 | 1 1793 | 1 1794 | 1 1795 | 1 1796 | 1 1797 | 1 1798 | 1 1799 | 1 1800 | 1 1801 | 1 1802 | 1 1803 | 1 1804 | 1 1805 | 1 1806 | 1 1807 | 1 1808 | 1 1809 | 1 1810 | 1 1811 | 1 1812 | 1 1813 | 1 1814 | 1 1815 | 1 1816 | 1 1817 | 1 1818 | 1 1819 | 1 1820 | 1 1821 | 1 1822 | 1 1823 | 1 1824 | 1 1825 | 1 1826 | 1 1827 | 1 1828 | 1 1829 | 1 1830 | 1 1831 | 1 1832 | 1 1833 | 1 1834 | 1 1835 | 1 1836 | 1 1837 | 1 1838 | 1 1839 | 1 1840 | 1 1841 | 1 1842 | 1 1843 | 1 1844 | 1 1845 | 1 1846 | 1 1847 | 1 1848 | 1 1849 | 1 1850 | 1 1851 | 1 1852 | 1 1853 | 1 1854 | 1 1855 | 1 1856 | 1 1857 | 1 1858 | 1 1859 | 1 1860 | 1 1861 | 1 1862 | 1 1863 | 1 1864 | 1 1865 | 1 1866 | 1 1867 | 1 1868 | 1 1869 | 
1 1870 | 1 1871 | 1 1872 | 1 1873 | 1 1874 | 1 1875 | 1 1876 | 1 1877 | 1 1878 | 1 1879 | 1 1880 | 1 1881 | 1 1882 | 1 1883 | 1 1884 | 1 1885 | 1 1886 | 1 1887 | 1 1888 | 1 1889 | 1 1890 | 1 1891 | 1 1892 | 1 1893 | 1 1894 | 1 1895 | 1 1896 | 1 1897 | 1 1898 | 1 1899 | 1 1900 | 1 1901 | 1 1902 | 1 1903 | 1 1904 | 1 1905 | 1 1906 | 1 1907 | 1 1908 | 1 1909 | 1 1910 | 1 1911 | 1 1912 | 1 1913 | 1 1914 | 1 1915 | 1 1916 | 1 1917 | 1 1918 | 1 1919 | 1 1920 | 1 1921 | 1 1922 | 1 1923 | 1 1924 | 1 1925 | 1 1926 | 1 1927 | 1 1928 | 1 1929 | 1 1930 | 1 1931 | 1 1932 | 1 1933 | 1 1934 | 1 1935 | 1 1936 | 1 1937 | 1 1938 | 1 1939 | 1 1940 | 1 1941 | 1 1942 | 1 1943 | 1 1944 | 1 1945 | 1 1946 | 1 1947 | 1 1948 | 1 1949 | 1 1950 | 1 1951 | 1 1952 | 1 1953 | 1 1954 | 1 1955 | 1 1956 | 1 1957 | 1 1958 | 1 1959 | 1 1960 | 1 1961 | 1 1962 | 1 1963 | 1 1964 | 1 1965 | 1 1966 | 1 1967 | 1 1968 | 1 1969 | 1 1970 | 1 1971 | 1 1972 | 1 1973 | 1 1974 | 1 1975 | 1 1976 | 1 1977 | 1 1978 | 1 1979 | 1 1980 | 1 1981 | 1 1982 | 1 1983 | 1 1984 | 1 1985 | 1 1986 | 1 1987 | 1 1988 | 1 1989 | 1 1990 | 1 1991 | 1 1992 | 1 1993 | 1 1994 | 1 1995 | 1 1996 | 1 1997 | 1 1998 | 1 1999 | 1 2000 | 1 2001 | -------------------------------------------------------------------------------- /code/main_data_augmentation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import math 4 | import numpy as np 5 | import os 6 | import random 7 | import json 8 | import re 9 | import datetime 10 | import time 11 | from multiprocessing import Pool, cpu_count 12 | from functools import partial 13 | 14 | from sklearn.model_selection import GridSearchCV 15 | from sklearn.neighbors import KernelDensity 16 | 17 | from keras.layers import Dense, Dropout, Flatten, Input, Convolution2D, MaxPooling2D, Merge 18 | from keras.utils import np_utils 19 | from keras.models import Model 20 | from keras import backend as K 21 | from keras.callbacks import EarlyStopping 22 
| 23 | # ============================================================================= 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | # positional arguments (required) 28 | parser.add_argument('path_root', type=str, help="path to 'datasets' directory") 29 | parser.add_argument('dataset', type=str, help='name of the dataset. Must correspond to a valid value that matches names of files in node2vec folder') 30 | parser.add_argument('p', type=str, help='p parameter of node2vec. Must correspond to a valid value that matches names of files in node2vec folder') 31 | parser.add_argument('q', type=str, help='q parameter of node2vec. Must correspond to a valid value that matches names of files in node2vec folder') 32 | parser.add_argument('definition', type=int, help='definition. E.g., 14 for 14:1. Must correspond to a valid value that matches names of files in node2vec folder') 33 | parser.add_argument('n_channels', type=int, help='number of channels that we will be passed to the network. Must not exceed half the depth of the tensors in node2vec folder') 34 | parser.add_argument('n_bootstrap', type=float, help='augmentation ratio. 
Must be strictly between 0 and 1') 35 | 36 | # optional arguments 37 | parser.add_argument('--n_folds', type=int, default=10, choices=[2,3,4,5,6,7,8,9,10], help='number of folds for cross-validation') 38 | parser.add_argument('--n_repeats', type=int, default=3, choices=[1,2,3,4,5], help='number of times each fold should be repeated') 39 | parser.add_argument('--batch_size', type=int, default=32, choices=[32,64,128], help='batch size') 40 | parser.add_argument('--nb_epochs', type=int, default=50, help='maximum number of epochs') 41 | parser.add_argument('--patience', type=int, default=5, help='patience for early stopping strategy') 42 | parser.add_argument('--drop_rate', type=float, default=0.3, help='dropout rate') 43 | 44 | args = parser.parse_args() 45 | 46 | # convert command line arguments 47 | path_root = args.path_root 48 | dataset = args.dataset 49 | p = args.p 50 | q = args.q 51 | definition = args.definition 52 | n_channels = args.n_channels 53 | n_bootstrap = args.n_bootstrap 54 | 55 | n_folds = args.n_folds 56 | n_repeats = args.n_repeats 57 | batch_size = args.batch_size 58 | nb_epochs = args.nb_epochs 59 | my_patience = args.patience 60 | drop_rate = args.drop_rate 61 | 62 | dim_ordering = 'th' # channels first 63 | my_optimizer = 'adam' 64 | params = {'bandwidth': np.logspace(-1, -0.5, 10)} # for the kernel bandwidth grid search 65 | 66 | # command line example: python main_data_augmentation.py /home/antoine/Desktop/graph_2D_CNN/datasets/ imdb_action_romance 1 1 14 5 0.1 67 | 68 | # ============================================================================= 69 | 70 | def atoi(text): 71 | return int(text) if text.isdigit() else text 72 | 73 | def natural_keys(text): 74 | return [atoi(c) for c in re.split('(\d+)', text)] 75 | 76 | def get_bw_cv(x,params): 77 | grid = GridSearchCV(KernelDensity(), params,cv=2,n_jobs=1) 78 | grid.fit(x[:,None]) 79 | bw = grid.best_estimator_.bandwidth 80 | return bw 81 | 82 | def 
get_hist_node2vec(emb,d,my_min,my_max,definition): 83 | # d should be an even integer 84 | img_dim = int(np.arange(my_min, my_max+0.05,(my_max+0.05-my_min)/float(definition*(my_max+0.05-my_min))).shape[0]-1) 85 | my_bins = np.linspace(my_min,my_max,img_dim) # to have middle bin centered on zero 86 | Hs = [] 87 | for i in range(0,d,2): 88 | H, xedges, yedges = np.histogram2d(x=emb[:,i],y=emb[:,i+1],bins=my_bins, normed=False) 89 | Hs.append(H) 90 | Hs = np.array(Hs) 91 | return Hs 92 | 93 | def smoothed_bootstrap(my_array,params): 94 | # compute mean and variance along each dimension 95 | my_means = np.mean(my_array,0) 96 | my_vars = np.var(my_array,0) 97 | 98 | # to save time, estimate bandwidth from at most the first 100 nodes 99 | my_bws = np.apply_along_axis(get_bw_cv,0,my_array[:min(100,my_array.shape[0]),:],params) 100 | 101 | all_new_coords = [] 102 | for jj in range(int(np.random.normal(my_array.shape[0],scale=my_array.shape[0]/5))): 103 | rand_row_idx = random.randint(0,my_array.shape[0]-1) # select a row index (i.e., a node) at random 104 | my_coords = my_array[rand_row_idx,:].tolist() 105 | new_coords = [0]*len(my_coords) 106 | for kk in range(len(new_coords)): # for each dim 107 | new_coords[kk] = my_means[kk] + (my_coords[kk] - my_means[kk] + np.random.normal(0,scale=(my_bws[kk])**0.5))/((1+my_bws[kk]/my_vars[kk])**(.5)) 108 | all_new_coords.append(new_coords) 109 | all_new_coords = np.array(all_new_coords) 110 | return all_new_coords 111 | 112 | # ============================================================================= 113 | 114 | def main(): 115 | 116 | my_date_time = '_'.join(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S").split()) 117 | 118 | parameters = {'path_root':path_root, 119 | 'dataset':dataset, 120 | 'p':p, 121 | 'q':q, 122 | 'definition':definition, 123 | 'n_channels':n_channels, 124 | 'n_bootstrap':n_bootstrap, 125 | 'n_folds':n_folds, 126 | 'n_repeats':n_repeats, 127 | 'batch_size':batch_size, 128 | 'nb_epochs':nb_epochs, 129 
| 'my_patience':my_patience, 130 | 'drop_rate':drop_rate, 131 | 'dim_ordering':dim_ordering, 132 | 'my_optimizer':my_optimizer 133 | } 134 | 135 | name_save = path_root + '/results/' + dataset + '_augmentation_' + my_date_time 136 | 137 | with open(name_save + '_parameters.json', 'w') as my_file: 138 | json.dump(parameters, my_file, sort_keys=True, indent=4) 139 | 140 | print '========== parameters defined and saved to disk ==========' 141 | 142 | regexp_p = re.compile('p=' + p) 143 | regexp_q = re.compile('q=' + q) 144 | n_dim = 2*n_channels 145 | inverse_n_b = int(round(1/n_bootstrap)) 146 | smoothed_bootstrap_partial = partial(smoothed_bootstrap, params=params) # for parallelization 147 | n_jobs = cpu_count() 148 | 149 | print '========== loading labels ==========' 150 | 151 | with open(path_root + 'classes/' + dataset + '/' + dataset + '_classes.txt', 'r') as f: 152 | ys = f.read().splitlines() 153 | ys = [int(elt) for elt in ys] 154 | 155 | num_classes = len(list(set(ys))) 156 | 157 | print 'classes:', list(set(ys)) 158 | 159 | print 'converting to 0-based index' 160 | 161 | if 0 not in list(set(ys)): 162 | if -1 not in list(set(ys)): 163 | ys = [y-1 for y in ys] 164 | else: 165 | ys = [1 if y==1 else 0 for y in ys] 166 | 167 | print 'classes:', list(set(ys)) 168 | 169 | print '========== loading node2vec embeddings ==========' 170 | 171 | all_file_names = os.listdir(path_root + '/raw_node2vec/' + dataset + '/') 172 | print '===== total number of files in folder: =====', len(all_file_names) 173 | 174 | file_names_filtered = [elt for elt in all_file_names if (dataset in elt and regexp_p.search(elt) and regexp_q.search(elt) and elt.count('p=')==1 and elt.count('q=')==1 and elt.split('_')[-1:][0][0].isdigit())] 175 | file_names_filtered.sort(key=natural_keys) 176 | 177 | print 'number of files after filtering:', len(file_names_filtered) 178 | print '*** head ***' 179 | print file_names_filtered[:5] 180 | print '*** tail ***' 181 | print file_names_filtered[-5:] 
182 | 183 | # load tensors 184 | raw_emb = [] 185 | excluded_idxs = [] 186 | for idx, name in enumerate(file_names_filtered): 187 | emb = np.load(path_root + '/raw_node2vec/' + dataset + '/' + name) 188 | if emb.shape[1]