├── .gitignore
├── LICENSE
├── README.md
├── convertor
│   └── mutag_convertor.py
├── figure
│   ├── diamnet.png
│   ├── overview.png
│   └── representation.png
├── generator
│   ├── graph_checker.py
│   ├── graph_generator.py
│   ├── mutag_generator.py
│   ├── pattern_checker.py
│   ├── pattern_generator.py
│   ├── run.py
│   └── utils.py
├── requirements.txt
└── src
    ├── basemodel.py
    ├── cnn.py
    ├── dataset.py
    ├── embedding.py
    ├── evaluate.py
    ├── filternet.py
    ├── finetune.py
    ├── finetune_mutag.py
    ├── predictnet.py
    ├── rgcn.py
    ├── rgin.py
    ├── rnn.py
    ├── train.py
    ├── train_mutag.py
    ├── txl.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
.vscode
data/*
dumps/*
figs/*
.ipynb_checkpoints
__pycache__
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 HKUST-KnowComp

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NeuralSubgraphCounting

This repository is an official implementation of the paper [Neural Subgraph Isomorphism Counting](https://arxiv.org/abs/1912.11589).

## Introduction

We propose a learning framework that augments different representation learning architectures and iteratively attends over pattern and target data graphs to memorize subgraph isomorphisms for global counting.

![Overview](figure/overview.png)

![Representation](figure/representation.png)

We can use the **minimum code** (with the minimum lexicographic order) defined by Xifeng Yan to convert a graph into a sequence and then apply sequence models, e.g., CNN, LSTM, and Transformer-XL. A more direct approach is to use graph convolutional networks to learn representations, e.g., RGCN and RGIN.

As for the interaction module, simple pooling is obviously not enough. We design the Memory Attention Predict Network (MemAttnPredictNet) and the Dynamic Intermedium Attention Memory (DIAMNet); you can try them in the reproduction part below.
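To make the sequence view concrete, here is a rough, hypothetical sketch. It is *not* Yan's exact minimum DFS code and not this repository's encoder; it only illustrates the idea of flattening a labeled digraph into one canonically ordered edge sequence that a sequence model can consume (the function name and tuple layout are illustrative assumptions).

```python
# Hypothetical sketch: flatten a small labeled digraph into an ordered
# sequence of edge tuples. The true minimum DFS code instead minimizes
# over all DFS traversals; plain sorting is used here only for illustration.

def graph_to_sequence(vertex_labels, edges, edge_labels):
    """vertex_labels: list[int]; edges: list[(u, v)]; edge_labels: list[int].
    Returns a canonically sorted list of (u, v, label(u), label(e), label(v))."""
    code = [
        (u, v, vertex_labels[u], edge_labels[i], vertex_labels[v])
        for i, (u, v) in enumerate(edges)
    ]
    # One fixed, deterministic order so the same graph always yields
    # the same token sequence.
    return sorted(code)

seq = graph_to_sequence([0, 1, 1], [(0, 1), (1, 2), (0, 2)], [0, 1, 0])
print(seq)  # [(0, 1, 0, 0, 1), (0, 2, 0, 0, 1), (1, 2, 1, 1, 1)]
```

Each tuple can then be embedded and fed to a CNN, LSTM, or Transformer-XL exactly like a token sequence.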

![DIAMNet](figure/diamnet.png)

## Reproduction

### Package Dependencies

* tqdm
* numpy
* pandas
* scipy
* tensorboardX
* python-igraph == 0.9.11
* torch >= 1.3.0
* dgl == 0.4.3post2


### Data Generation
The data used in the KDD paper are available at [OneDrive](https://hkustconnect-my.sharepoint.com/:f:/g/personal/xliucr_connect_ust_hk/EqEONJuKHLVGo7ky759-ZvEB4WjWe2bKG2A725AGSD6G9g?e=HhIFeb).

You can also generate data by modifying `run.py` to set `CONFIG` and running
```bash
cd generator
python run.py
```

For the *MUTAG* data, you can use `mutag_convertor.py` to generate the raw graphs.
```bash
cd convertor
python mutag_convertor.py
```

You can use `generator/mutag_generator.py` to generate patterns, but be careful of duplicates.

### Model Training/Finetuning

For the *small* dataset, just run
```bash
cd src
python train.py --model RGIN --predict_net SumPredictNet \
    --gpu_id 0 --batch_size 512 \
    --max_npv 8 --max_npe 8 --max_npvl 8 --max_npel 8 \
    --max_ngv 64 --max_nge 256 --max_ngvl 16 --max_ngel 16 \
    --pattern_dir ../data/small/patterns \
    --graph_dir ../data/small/graphs \
    --metadata_dir ../data/small/metadata \
    --save_data_dir ../data/small \
    --save_model_dir ../dumps/small/RGIN-SumPredictNet
```

```bash
cd src
python train.py --model RGIN --predict_net DIAMNet \
    --predict_net_mem_init mean --predict_net_mem_len 4 --predict_net_recurrent_steps 3 \
    --gpu_id 0 --batch_size 512 \
    --max_npv 8 --max_npe 8 --max_npvl 8 --max_npel 8 \
    --max_ngv 64 --max_nge 256 --max_ngvl 16 --max_ngel 16 \
    --pattern_dir ../data/small/patterns \
    --graph_dir ../data/small/graphs \
    --metadata_dir ../data/small/metadata \
    --save_data_dir ../data/small \
    --save_model_dir
 ../dumps/small/RGIN-DIAMNet
```

We find that reusing the encoder module from RGIN-SumPredictNet leads to faster convergence of RGIN-DIAMNet.

```bash
cd src
python finetune.py --model RGIN --predict_net DIAMNet \
    --predict_net_mem_init mean --predict_net_mem_len 4 --predict_net_recurrent_steps 3 \
    --gpu_id 0 --batch_size 512 \
    --max_npv 8 --max_npe 8 --max_npvl 8 --max_npel 8 \
    --max_ngv 64 --max_nge 256 --max_ngvl 16 --max_ngel 16 \
    --pattern_dir ../data/small/patterns \
    --graph_dir ../data/small/graphs \
    --metadata_dir ../data/small/metadata \
    --save_data_dir ../data/small \
    --save_model_dir ../dumps/small/RGIN-DIAMNet \
    --load_model_dir ../dumps/small/RGIN-SumPredictNet
```


For the *large* dataset, just run
```bash
cd src
python finetune.py --model RGIN --predict_net SumPredictNet \
    --gpu_id 0 --batch_size 128 --update_every 4 \
    --max_npv 16 --max_npe 16 --max_npvl 16 --max_npel 16 \
    --max_ngv 512 --max_nge 2048 --max_ngvl 64 --max_ngel 64 \
    --pattern_dir ../data/large/patterns \
    --graph_dir ../data/large/graphs \
    --metadata_dir ../data/large/metadata \
    --save_data_dir ../data/large \
    --save_model_dir ../dumps/large/RGIN-SumPredictNet \
    --load_model_dir ../dumps/small/RGIN-SumPredictNet
```

```bash
cd src
python finetune.py --model RGIN --predict_net DIAMNet \
    --predict_net_mem_init mean --predict_net_mem_len 4 --predict_net_recurrent_steps 3 \
    --gpu_id 0 --batch_size 128 --update_every 4 \
    --max_npv 16 --max_npe 16 --max_npvl 16 --max_npel 16 \
    --max_ngv 512 --max_nge 2048 --max_ngvl 64 --max_ngel 64 \
    --pattern_dir ../data/large/patterns \
    --graph_dir ../data/large/graphs \
    --metadata_dir ../data/large/metadata \
    --save_data_dir ../data/large \
    --save_model_dir ../dumps/large/RGIN-DIAMNet \
    --load_model_dir ../dumps/small/RGIN-DIAMNet
```


For the *MUTAG* dataset, you need to set `train_ratio` manually:
```bash
cd src
python train_mutag.py --model RGIN --predict_net SumPredictNet \
    --gpu_id 0 --batch_size 64 \
    --max_npv 4 --max_npe 3 --max_npvl 2 --max_npel 2 \
    --max_ngv 28 --max_nge 66 --max_ngvl 7 --max_ngel 4 \
    --pattern_dir ../data/MUTAG/patterns \
    --graph_dir ../data/MUTAG/raw \
    --metadata_dir ../data/MUTAG/metadata \
    --save_data_dir ../data/MUTAG/RGIN-SumPredictNet-0.4 \
    --save_model_dir ../dumps/MUTAG \
    --train_ratio 0.4
```

Transfer learning can improve performance when the amount of training data is limited.

```bash
cd src
python finetune_mutag.py --model RGIN --predict_net SumPredictNet \
    --gpu_id 0 --batch_size 64 \
    --max_npv 8 --max_npe 8 --max_npvl 8 --max_npel 8 \
    --max_ngv 64 --max_nge 256 --max_ngvl 16 --max_ngel 16 \
    --pattern_dir ../data/MUTAG/patterns \
    --graph_dir ../data/MUTAG/raw \
    --metadata_dir ../data/MUTAG/metadata \
    --save_data_dir ../data/MUTAG \
    --save_model_dir ../dumps/MUTAG/Transfer-RGIN-SumPredictNet-0.4 \
    --train_ratio 0.4 \
    --load_model_dir ../dumps/small/RGIN-SumPredictNet
```

RGIN-DIAMNet is difficult to train to convergence on *MUTAG*, so we load RGIN-SumPredictNet and replace the interaction module for both MeanMemAttnPredictNet and DIAMNet.
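The module replacement above can be sketched as follows. This is a hedged illustration only: checkpoints are modeled as plain dicts, and the `encoder.*` / `predict_net.*` names are illustrative assumptions, not the repository's actual parameter names (in the repo this step is driven by `--load_model_dir` inside `finetune_mutag.py`).

```python
# Hypothetical sketch: reuse the encoder from a pretrained checkpoint and
# drop the old interaction head, so a new head (e.g., DIAMNet) starts from
# its fresh random initialization.

def transfer_encoder(pretrained_state, new_state):
    """Copy encoder.* entries from the pretrained state dict; keep the new
    model's freshly initialized predict_net.* parameters untouched."""
    merged = dict(new_state)
    for name, value in pretrained_state.items():
        if name.startswith("encoder."):
            merged[name] = value
    return merged

old = {"encoder.w": [1, 2], "predict_net.w": [3]}          # pretrained SumPredictNet-like model
new = {"encoder.w": [0, 0], "predict_net.mem": [0, 0, 0]}  # new DIAMNet-like model
print(transfer_encoder(old, new))
# {'encoder.w': [1, 2], 'predict_net.mem': [0, 0, 0]}
```

With PyTorch models the same effect is typically achieved by filtering a `state_dict` and loading it with `strict=False`, so the missing head parameters are simply skipped.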

```bash
cd src
python finetune_mutag.py --model RGIN --predict_net DIAMNet \
    --predict_net_mem_init mean --predict_net_mem_len 4 --predict_net_recurrent_steps 1 \
    --gpu_id 0 --batch_size 64 \
    --max_npv 4 --max_npe 3 --max_npvl 2 --max_npel 2 \
    --max_ngv 28 --max_nge 66 --max_ngvl 7 --max_ngel 4 \
    --pattern_dir ../data/MUTAG/patterns \
    --graph_dir ../data/MUTAG/raw \
    --metadata_dir ../data/MUTAG/metadata \
    --save_data_dir ../data/MUTAG \
    --save_model_dir ../dumps/MUTAG/RGIN-DIAMNet-0.4 \
    --train_ratio 0.4 \
    --load_model_dir ../dumps/MUTAG/RGIN-SumPredictNet-0.4
```

```bash
cd src
python finetune_mutag.py --model RGIN --predict_net DIAMNet \
    --predict_net_mem_init mean --predict_net_mem_len 4 --predict_net_recurrent_steps 1 \
    --gpu_id 0 --batch_size 64 \
    --max_npv 8 --max_npe 8 --max_npvl 8 --max_npel 8 \
    --max_ngv 64 --max_nge 256 --max_ngvl 16 --max_ngel 16 \
    --pattern_dir ../data/MUTAG/patterns \
    --graph_dir ../data/MUTAG/raw \
    --metadata_dir ../data/MUTAG/metadata \
    --save_data_dir ../data/MUTAG \
    --save_model_dir ../dumps/MUTAG/Transfer-RGIN-DIAMNet-0.4 \
    --train_ratio 0.4 \
    --load_model_dir ../dumps/MUTAG/Transfer-RGIN-SumPredictNet-0.4
```

### Model Evaluation
```bash
cd src
python evaluate.py ../dumps/small/RGIN-DIAMNet
```

### Citation

The details of this pipeline are described in the following paper. If you use this code in your work, please kindly cite it.

```bibtex
@inproceedings{liu2020neuralsubgraphcounting,
  author    = {Xin Liu and Haojie Pan and Mutian He and Yangqiu Song and Xin Jiang and Lifeng Shang},
  title     = {Neural Subgraph Isomorphism Counting},
  booktitle = {ACM SIGKDD Conference on Knowledge Discovery and Data Mining, {KDD} 2020, August 23-27, 2020, San Diego, United States}
}
```

### Miscellaneous

Please send any questions about the code and/or the algorithm to .
--------------------------------------------------------------------------------
/convertor/mutag_convertor.py:
--------------------------------------------------------------------------------
import igraph as ig
import os
import sys
from tqdm import tqdm

if __name__ == "__main__":
    assert len(sys.argv) == 2
    mutag_data_path = sys.argv[1]

    nodes = list()
    node2graph = list()
    with open(os.path.join(mutag_data_path, "MUTAG_graph_indicator.txt"), "r") as f:
        for n_id, g_id in enumerate(f):
            g_id = int(g_id)-1
            node2graph.append(g_id)
            if g_id == len(nodes):
                nodes.append(list())
            nodes[-1].append(n_id)

    nodelabels = [list() for _ in range(len(nodes))]
    with open(os.path.join(mutag_data_path, "MUTAG_node_labels.txt"), "r") as f:
        _nodelabels = list()
        for nl in f:
            nl = int(nl)
            _nodelabels.append(nl)
        n_idx = 0
        for g_idx in range(len(nodes)):
            for _ in range(len(nodes[g_idx])):
                nodelabels[g_idx].append(_nodelabels[n_idx])
                n_idx += 1
        del _nodelabels

    edges = [list() for _ in range(len(nodes))]
    with open(os.path.join(mutag_data_path, "MUTAG_A.txt"), "r") as f:
        for e in f:
            e = [int(v)-1 for v in e.split(",")]
            g_id = node2graph[e[0]]
            edges[g_id].append((e[0]-nodes[g_id][0], e[1]-nodes[g_id][0]))

    edgelabels = [list() for _ in range(len(nodes))]
    with open(os.path.join(mutag_data_path, "MUTAG_edge_labels.txt"), "r") as f:
        _edgelabels = list()
        for el in f:
            el = int(el)
            _edgelabels.append(el)
        e_idx = 0
        for g_idx in range(len(edges)):
            for _ in range(len(edges[g_idx])):
                edgelabels[g_idx].append(_edgelabels[e_idx])
                e_idx += 1
        del _edgelabels

    os.makedirs(os.path.join(mutag_data_path, "raw"), exist_ok=True)
    for g_id in tqdm(range(len(nodes))):
        graph = ig.Graph(directed=True)
        vcount = len(nodes[g_id])
        vlabels = nodelabels[g_id]
        elabels = edgelabels[g_id]

        graph.add_vertices(vcount)
        graph.add_edges(edges[g_id])
        graph.vs["label"] = vlabels
        graph.es["label"] = elabels
        graph.es["key"] = [0] * len(edges[g_id])

        graph_id = "G_N%d_E%d_NL%d_EL%d_%d" % (
            vcount, len(edges[g_id]), max(vlabels)+1, max(elabels)+1, g_id)
        filename = os.path.join(mutag_data_path, "raw", graph_id)
        # nx.nx_pydot.write_dot(pattern, filename + ".dot")
        graph.write(filename + ".gml")
--------------------------------------------------------------------------------
/figure/diamnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-KnowComp/NeuralSubgraphCounting/22df62d59e112716a22f80db408fcddfc95da5c8/figure/diamnet.png
--------------------------------------------------------------------------------
/figure/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-KnowComp/NeuralSubgraphCounting/22df62d59e112716a22f80db408fcddfc95da5c8/figure/overview.png
--------------------------------------------------------------------------------
/figure/representation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-KnowComp/NeuralSubgraphCounting/22df62d59e112716a22f80db408fcddfc95da5c8/figure/representation.png
--------------------------------------------------------------------------------
/generator/graph_checker.py:
--------------------------------------------------------------------------------
import igraph as ig
import numpy as np
import argparse
import os
import math
import json
import shutil
from collections import Counter, defaultdict
from utils import generate_labels, generate_tree, get_direction, powerset, sample_element, str2bool, retrieve_multiple_edges
from pattern_checker import PatternChecker
from time import time
from functools import partial
from tqdm import tqdm
from multiprocessing import Pool

def get_subisomorphisms(pattern, graphs):
    results = dict()
    pattern_checker = PatternChecker()
    for gid, graph in graphs.items():
        subisomorphisms = pattern_checker.get_subisomorphisms(graph, pattern)
        results[gid] = subisomorphisms
    return results

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pattern_dir", type=str, default="patterns")
    parser.add_argument("--raw_graph_dir", type=str, default="raw_graphs")
    parser.add_argument("--save_graph_dir", type=str, default="graphs")
    parser.add_argument("--save_metadata_dir", type=str, default="metadata")
    parser.add_argument("--num_workers", type=int, default=20)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    np.random.seed(args.seed)

    os.makedirs(os.path.join(args.save_graph_dir), exist_ok=True)
    os.makedirs(os.path.join(args.save_metadata_dir), exist_ok=True)
    with Pool(args.num_workers) as pool:
        pool_results = list()
        patterns = dict()
        graphs = dict()
        for pid in os.listdir(args.pattern_dir):
            if not pid.endswith(".gml"):
                continue
            pattern = ig.read(os.path.join(args.pattern_dir, pid))
            patterns[os.path.splitext(pid)[0]] = pattern
            # shutil.copytree(args.raw_graph_dir, os.path.join(args.save_graph_dir, os.path.splitext(pid)[0]))
            # os.system("ln -s %s %s" % (args.raw_graph_dir, os.path.join(args.save_graph_dir, os.path.splitext(pid)[0])))

        for gid in os.listdir(args.raw_graph_dir):
            if not gid.endswith(".gml"):
                continue
            graph = ig.read(os.path.join(args.raw_graph_dir, gid))
            graphs[os.path.splitext(gid)[0]] = graph

        for pid, pattern in patterns.items():
            pool_results.append((pid, pool.apply_async(get_subisomorphisms, args=(pattern, graphs))))
        for x in tqdm(pool_results):
            pid, x = x
            x = x.get()
            os.makedirs(os.path.join(args.save_metadata_dir, pid), exist_ok=True)

            mae = 0
            mse = 0
            for gid, subisomorphisms in x.items():
                mae += len(subisomorphisms)
                mse += len(subisomorphisms) * len(subisomorphisms)
                with open(os.path.join(args.save_metadata_dir, pid, gid+".meta"), "w") as f:
                    json.dump({"counts": len(subisomorphisms), "subisomorphisms": subisomorphisms}, f)
            print(pid, "mae", mae/len(x), "mse", mse/len(x))

--------------------------------------------------------------------------------
/generator/mutag_generator.py:
--------------------------------------------------------------------------------
import igraph as ig
import numpy as np
import argparse
import os
import math
import json
from collections import Counter, defaultdict
from utils import generate_labels, generate_tree, get_direction, powerset, sample_element, str2bool, retrieve_multiple_edges
from pattern_checker import PatternChecker
from pattern_generator import generate_patterns
from graph_generator import GraphGenerator
from time import time

def generate_graphs(pattern, min_number_of_vertices, max_number_of_vertices, min_number_of_edges, max_number_of_edges, \
                    min_number_of_vertex_labels, max_number_of_vertex_labels, min_number_of_edge_labels,
 max_number_of_edge_labels, \
                    alpha, max_pattern_counts, max_subgraph, return_subisomorphisms, number_of_graphs):
    graph_generator = GraphGenerator(pattern)
    results = list()
    vl1, vl2 = math.log2(min_number_of_vertex_labels), math.log2(max_number_of_vertex_labels)
    el1, el2 = math.log2(min_number_of_edge_labels), math.log2(max_number_of_edge_labels)
    for g in range(number_of_graphs):
        number_of_vertices = np.random.randint(min_number_of_vertices, max_number_of_vertices+1)
        number_of_edges = np.random.randint(max(min_number_of_edges, number_of_vertices), max_number_of_edges+1)
        number_of_vertex_labels = math.floor(math.pow(np.random.rand()*(vl2-vl1)+vl1, 2))
        number_of_edge_labels = math.floor(math.pow(np.random.rand()*(el2-el1)+el1, 2))
        graph, metadata = graph_generator.generate(
            number_of_vertices, number_of_edges, number_of_vertex_labels, number_of_edge_labels,
            alpha, max_pattern_counts=max_pattern_counts, max_subgraph=max_subgraph,
            return_subisomorphisms=return_subisomorphisms)
        print("%d/%d" % (g+1, number_of_graphs), "number of subisomorphisms: %d" % (metadata["counts"]))
        results.append((graph, metadata))
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--min_number_of_vertices", type=int, default=10)
    parser.add_argument("--max_number_of_vertices", type=int, default=28)
    parser.add_argument("--min_number_of_edges", type=int, default=20//2)
    parser.add_argument("--max_number_of_edges", type=int, default=66//2)
    parser.add_argument("--min_number_of_vertex_labels", type=int, default=3)
    parser.add_argument("--max_number_of_vertex_labels", type=int, default=7)
    parser.add_argument("--min_number_of_edge_labels", type=int, default=3)
    parser.add_argument("--max_number_of_edge_labels", type=int, default=4)
    parser.add_argument("--alpha", type=float, default=0.95)
    parser.add_argument("--max_pattern_counts", type=float, default=128)
    parser.add_argument("--return_subisomorphisms", type=str2bool, default=True)
    parser.add_argument("--max_subgraph", type=int, default=512)
    parser.add_argument("--number_of_graphs", type=int, default=56)
    parser.add_argument("--pattern_path", type=str, default=r"patterns/P_N3_E3_NL2_EL2_0.gml")
    parser.add_argument("--save_graph_dir", type=str, default="graphs")
    parser.add_argument("--save_metadata_dir", type=str, default="metadata")
    parser.add_argument("--save_png", type=str2bool, default=False)
    parser.add_argument("--show_img", type=str2bool, default=False)
    args = parser.parse_args()

    np.random.seed(args.seed)

    try:
        pattern = ig.read(args.pattern_path)
        pattern.vs["label"] = [int(x) for x in pattern.vs["label"]]
        pattern.es["label"] = [int(x) for x in pattern.es["label"]]
        pattern.es["key"] = [int(x) for x in pattern.es["key"]]
    except BaseException as e:
        print(e)
        pattern = ig.Graph(directed=True)
        pattern.vs["label"] = []
        pattern.es["label"] = []
        pattern.es["key"] = []

    results = generate_graphs(pattern,
                              args.min_number_of_vertices, args.max_number_of_vertices,
                              args.min_number_of_edges, args.max_number_of_edges,
                              args.min_number_of_vertex_labels, args.max_number_of_vertex_labels,
                              args.min_number_of_edge_labels, args.max_number_of_edge_labels,
                              args.alpha, args.max_pattern_counts, args.max_subgraph,
                              args.return_subisomorphisms, args.number_of_graphs)

    if args.save_graph_dir:
        os.makedirs(args.save_graph_dir, exist_ok=True)
        save_graph_dir = os.path.join(args.save_graph_dir, os.path.splitext(os.path.basename(args.pattern_path))[0])
        os.makedirs(save_graph_dir, exist_ok=True)
    if args.save_metadata_dir:
        os.makedirs(args.save_metadata_dir, exist_ok=True)
        save_metadata_dir = os.path.join(args.save_metadata_dir, os.path.splitext(os.path.basename(args.pattern_path))[0])
        os.makedirs(save_metadata_dir, exist_ok=True)
    for g, (graph, metadata) in enumerate(results):
        graph_id = "G_N%d_E%d_NL%d_EL%d_%d" % (
            graph.vcount(), graph.ecount(), max(graph.vs["label"])+1, max(graph.es["label"])+1, g)
        graph_filename = os.path.join(save_graph_dir, graph_id)
        graph.write(graph_filename + ".gml")
        if args.save_metadata_dir:
            metadata_filename = os.path.join(save_metadata_dir, graph_id)
            with open(metadata_filename + ".meta", "w") as f:
                json.dump(metadata, f)
        if args.save_png:
            ig.plot(graph, graph_filename + ".png")
        if args.show_img:
            draw(graph, pattern, metadata["subisomorphisms"])
--------------------------------------------------------------------------------
/generator/pattern_checker.py:
--------------------------------------------------------------------------------
import numpy as np
import igraph as ig
from collections import Counter
from utils import retrieve_multiple_edges

INF = float("inf")

class PatternChecker(object):
    def __init__(self):
        pass

    @classmethod
    def node_compat_fn(cls, g1, g2, v1, v2):
        vertex1 = g1.vs[v1]
        vertex2 = g2.vs[v2]
        return vertex1["label"] == vertex2["label"]

    @classmethod
    def edge_compat_fn(cls, g1, g2, e1, e2):
        edge1 = g1.es[e1]
        edge2 = g2.es[e2]
        if edge1.is_loop() != edge2.is_loop():
            return False
        # for multiedges
        edges1 = retrieve_multiple_edges(g1, edge1.source, edge1.target)
        edges2 = retrieve_multiple_edges(g2, edge2.source, edge2.target)
        if len(edges1) < len(edges2):
            return False
        edge1_labels = set(edges1["label"])
        for el in edges2["label"]:
            if el not in edge1_labels:
                return False
        return True

    @classmethod
    def get_vertex_color_vectors(cls, g1, g2, seed_v1=-1,
 seed_v2=-1):
        N1 = g1.vcount()
        N2 = g2.vcount()
        color_vectors = list()
        if seed_v1 == -1 and seed_v2 == -1:
            color_vectors.append((None, None))
        elif seed_v1 == -1 and seed_v2 != -1:
            vertex = g2.vs[seed_v2]
            seed_label = vertex["label"]
            for seed_v1, vertex in enumerate(g1.vs):
                if vertex["label"] == seed_label:
                    color1 = [0] * N1
                    color1[seed_v1] = 1
                    color2 = [0] * N2
                    color2[seed_v2] = 1
                    color_vectors.append((color1, color2))
        elif seed_v1 != -1 and seed_v2 == -1:
            seed_label = g1.vs[seed_v1]["label"]
            for seed_v2, vertex in enumerate(g2.vs):
                if vertex["label"] == seed_label:
                    color1 = [0] * N1
                    color1[seed_v1] = 1
                    color2 = [0] * N2
                    color2[seed_v2] = 1
                    color_vectors.append((color1, color2))
        else:  # seed_v1 != -1 and seed_v2 != -1
            if g1.vs[seed_v1]["label"] == g2.vs[seed_v2]["label"]:
                color1 = [0] * N1
                color1[seed_v1] = 1
                color2 = [0] * N2
                color2[seed_v2] = 1
                color_vectors.append((color1, color2))
        return color_vectors

    @classmethod
    def get_edge_color_vectors(cls, g1, g2, seed_e1=-1, seed_e2=-1):
        E1 = len(g1.es)
        E2 = len(g2.es)
        edge_color_vectors = list()
        if seed_e1 == -1 and seed_e2 == -1:
            edge_color_vectors.append((None, None))
        elif seed_e1 == -1 and seed_e2 != -1:
            edge = g2.es[seed_e2]
            color2 = [0] * E2
            color2[seed_e2] = 1
            seed_label = edge["label"]
            is_loop = edge.is_loop()
            for seed_e1, edge in enumerate(g1.es):
                if edge["label"] == seed_label and is_loop == edge.is_loop():
                    color1 = [0] * E1
                    color1[seed_e1] = 1
                    edge_color_vectors.append((color1, color2))
        elif seed_e1 != -1 and seed_e2 == -1:
            edge = g1.es[seed_e1]
            color1 = [0] * E1
            color1[seed_e1] = 1
            seed_label = edge["label"]
            is_loop = edge.is_loop()
            for seed_e2, edge in enumerate(g2.es):
                if edge["label"] == seed_label and is_loop ==
 edge.is_loop():
                    color2 = [0] * E2
                    color2[seed_e2] = 1
                    edge_color_vectors.append((color1, color2))
        else:  # seed_e1 != -1 and seed_e2 != -1
            edge1 = g1.es[seed_e1]
            edge2 = g2.es[seed_e2]
            color1 = [0] * E1
            color1[seed_e1] = 1
            color2 = [0] * E2
            color2[seed_e2] = 1
            if edge1["label"] == edge2["label"] and edge1.is_loop() == edge2.is_loop():
                edge_color_vectors.append((color1, color2))
        return edge_color_vectors

    def check(self, graph, pattern, **kw):
        # valid or not
        if graph.vcount() < pattern.vcount():
            return False
        if graph.ecount() < pattern.ecount():
            return False

        graph_vlabels = Counter(graph.vs["label"])
        pattern_vlabels = Counter(pattern.vs["label"])
        if len(graph_vlabels) < len(pattern_vlabels):
            return False
        for vertex_label, pv_cnt in pattern_vlabels.most_common():
            diff = graph_vlabels[vertex_label] - pv_cnt
            if diff < 0:
                return False

        graph_elabels = Counter(graph.es["label"])
        pattern_elabels = Counter(pattern.es["label"])
        if len(graph_elabels) < len(pattern_elabels):
            return False
        for edge_label, pe_cnt in pattern_elabels.most_common():
            diff = graph_elabels[edge_label] - pe_cnt
            if diff < 0:
                return False
        return True

    def get_subisomorphisms(self, graph, pattern, **kw):
        if not self.check(graph, pattern):
            return list()

        seed_v1 = kw.get("seed_v1", -1)
        seed_v2 = kw.get("seed_v2", -1)
        seed_e1 = kw.get("seed_e1", -1)
        seed_e2 = kw.get("seed_e2", -1)

        vertex_color_vectors = PatternChecker.get_vertex_color_vectors(graph, pattern, seed_v1=seed_v1, seed_v2=seed_v2)
        edge_color_vectors = PatternChecker.get_edge_color_vectors(graph, pattern, seed_e1=seed_e1, seed_e2=seed_e2)

        vertices_in_graph = list()
        if seed_v1 != -1:
            vertices_in_graph.append(seed_v1)
        if
 seed_e1 != -1:
            vertices_in_graph.extend(graph.es[seed_e1].tuple)
        subisomorphisms = list()  # [(component, mapping), ...]
        for vertex_colors in vertex_color_vectors:
            for edge_colors in edge_color_vectors:
                for subisomorphism in graph.get_subisomorphisms_vf2(pattern,
                        color1=vertex_colors[0], color2=vertex_colors[1],
                        edge_color1=edge_colors[0], edge_color2=edge_colors[1],
                        node_compat_fn=PatternChecker.node_compat_fn,
                        edge_compat_fn=PatternChecker.edge_compat_fn):
                    if len(vertices_in_graph) == 0 or all([v in subisomorphism for v in vertices_in_graph]):
                        subisomorphisms.append(subisomorphism)
        return subisomorphisms

    def count_subisomorphisms(self, graph, pattern, **kw):
        if not self.check(graph, pattern):
            return 0

        seed_v1 = kw.get("seed_v1", -1)
        seed_v2 = kw.get("seed_v2", -1)
        seed_e1 = kw.get("seed_e1", -1)
        seed_e2 = kw.get("seed_e2", -1)

        vertex_color_vectors = PatternChecker.get_vertex_color_vectors(graph, pattern, seed_v1=seed_v1, seed_v2=seed_v2)
        edge_color_vectors = PatternChecker.get_edge_color_vectors(graph, pattern, seed_e1=seed_e1, seed_e2=seed_e2)

        vertices_in_graph = list()
        if seed_v1 != -1:
            vertices_in_graph.append(seed_v1)
        if seed_e1 != -1:
            vertices_in_graph.extend(graph.es[seed_e1].tuple)
        if len(vertices_in_graph) == 0:
            counts = 0
            for vertex_colors in vertex_color_vectors:
                for edge_colors in edge_color_vectors:
                    counts += graph.count_subisomorphisms_vf2(pattern,
                        color1=vertex_colors[0], color2=vertex_colors[1],
                        edge_color1=edge_colors[0], edge_color2=edge_colors[1],
                        node_compat_fn=PatternChecker.node_compat_fn,
                        edge_compat_fn=PatternChecker.edge_compat_fn)
            return counts
        else:
            counts = 0
            for vertex_colors in vertex_color_vectors:
                for edge_colors in edge_color_vectors:
                    for subisomorphism in
 graph.get_subisomorphisms_vf2(pattern,
                            color1=vertex_colors[0], color2=vertex_colors[1],
                            edge_color1=edge_colors[0], edge_color2=edge_colors[1],
                            node_compat_fn=PatternChecker.node_compat_fn,
                            edge_compat_fn=PatternChecker.edge_compat_fn):
                        if all([v in subisomorphism for v in vertices_in_graph]):
                            counts += 1
            return counts


if __name__ == "__main__":

    graph = ig.read(r"D:\Workspace\GraphPatternMatching\generator\graphs\P$N10$E20$NL10$EL10$0\G$N100$E200$NL10$EL10$0.gml")
    pattern = ig.read(r"D:\Workspace\GraphPatternMatching\generator\patterns\P$N10$E20$NL10$EL10$0.gml")
    ground_truth = graph.count_subisomorphisms_vf2(pattern,
        node_compat_fn=PatternChecker.node_compat_fn,
        edge_compat_fn=PatternChecker.edge_compat_fn)

    pc = PatternChecker()
    print(len(pc.get_subisomorphisms(graph, pattern)), ground_truth)
--------------------------------------------------------------------------------
/generator/pattern_generator.py:
--------------------------------------------------------------------------------
# import networkx as nx
import igraph as ig
import argparse
import numpy as np
import os
from utils import generate_labels, generate_tree, get_direction, str2bool
from collections import Counter, defaultdict
from time import time

def generate_patterns(number_of_vertices, number_of_edges, number_of_vertex_labels, number_of_edge_labels, number_of_patterns):
    patterns = []

    for p in range(number_of_patterns):
        start = time()

        pattern = ig.Graph(directed=True)

        # vertex labels
        vertex_labels = generate_labels(number_of_vertices, number_of_vertex_labels)
        # edge labels
        edge_labels = generate_labels(number_of_edges, number_of_edge_labels)

        # first, generate a tree
        pattern = generate_tree(number_of_vertices, directed=True)
        edge_label_mapping = defaultdict(set)
for e, edge in enumerate(pattern.es): 27 | edge_label_mapping[edge.tuple].add(edge_labels[e]) 28 | edge_keys = [0] * (number_of_vertices-1) 29 | 30 | # second, random add edges 31 | ecount = pattern.ecount() 32 | new_edges = list() 33 | while ecount < number_of_edges: 34 | u = np.random.randint(0, number_of_vertices) 35 | v = np.random.randint(0, number_of_vertices) 36 | src_tgt = (u, v) 37 | edge_label = edge_labels[ecount] 38 | # # we do not generate edges between two same vertices with same labels 39 | if edge_label in edge_label_mapping[src_tgt]: 40 | continue 41 | new_edges.append(src_tgt) 42 | edge_keys.append(len(edge_label_mapping[src_tgt])) 43 | edge_label_mapping[src_tgt].add(edge_label) 44 | ecount += 1 45 | pattern.add_edges(new_edges) 46 | pattern.vs["label"] = vertex_labels 47 | pattern.es["label"] = edge_labels 48 | pattern.es["key"] = edge_keys 49 | 50 | patterns.append(pattern) 51 | return patterns 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--seed", type=int, default=0) 56 | parser.add_argument("--number_of_vertices", type=int, default=3) 57 | parser.add_argument("--number_of_edges", type=int, default=3) 58 | parser.add_argument("--number_of_vertex_labels", type=int, default=2) 59 | parser.add_argument("--number_of_edge_labels", type=int, default=2) 60 | parser.add_argument("--number_of_patterns", type=int, default=1) 61 | parser.add_argument("--save_dir", type=str, default="patterns") 62 | parser.add_argument("--save_png", type=str2bool, default=False) 63 | args = parser.parse_args() 64 | 65 | np.random.seed(args.seed) 66 | 67 | patterns = generate_patterns(args.number_of_vertices, args.number_of_edges, 68 | args.number_of_vertex_labels, args.number_of_edge_labels, 69 | args.number_of_patterns) 70 | 71 | if args.save_dir: 72 | os.makedirs(args.save_dir, exist_ok=True) 73 | for p, pattern in enumerate(patterns): 74 | pattern_id = "P_N%d_E%d_NL%d_EL%d_%d" % ( 75 | args.number_of_vertices, 
args.number_of_edges, args.number_of_vertex_labels, args.number_of_edge_labels, p) 76 | filename = os.path.join(args.save_dir, pattern_id) 77 | # nx.nx_pydot.write_dot(pattern, filename + ".dot") 78 | pattern.write(filename + ".gml") 79 | if args.save_png: 80 | ig.plot(pattern, filename + ".png") 81 | 82 | -------------------------------------------------------------------------------- /generator/run.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import os 4 | import igraph as ig 5 | import json 6 | from multiprocessing import Pool 7 | from utils import generate_labels, get_direction 8 | from pattern_checker import PatternChecker 9 | from graph_generator import GraphGenerator 10 | from pattern_generator import generate_patterns 11 | from time import sleep 12 | from tqdm import tqdm 13 | 14 | 15 | DEBUG_CONFIG = { 16 | "max_subgraph": 512, 17 | 18 | "alphas": [0.5], 19 | 20 | "number_of_patterns": 1, 21 | "number_of_pattern_vertices": [3, 4], 22 | "number_of_pattern_edges": [2, 4], 23 | "number_of_pattern_vertex_labels": [2, 4], 24 | "number_of_pattern_edge_labels": [2, 4], 25 | 26 | "number_of_graphs": 10, # train:dev:test = 8:1:1 27 | "number_of_graph_vertices": [16, 64], 28 | "number_of_graph_edges": [16, 64, 256], 29 | "number_of_graph_vertex_labels": [4, 8], 30 | "number_of_graph_edge_labels": [4, 8], 31 | 32 | "max_ratio_of_edges_vertices": 4, 33 | "max_pattern_counts": 1024, 34 | 35 | "save_data_dir": r"../data/debug", 36 | "num_workers": 16 37 | } 38 | 39 | SMALL_CONFIG = { 40 | "max_subgraph": 512, 41 | 42 | "alphas": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], 43 | 44 | "number_of_patterns": 3, 45 | "number_of_pattern_vertices": [3, 4, 8], 46 | "number_of_pattern_edges": [2, 4, 8], 47 | "number_of_pattern_vertex_labels": [2, 4, 8], 48 | "number_of_pattern_edge_labels": [2, 4, 8], 49 | 50 | "number_of_graphs": 10, # train:dev:test = 8:1:1 51 | "number_of_graph_vertices": [8, 16, 
32, 64], 52 | "number_of_graph_edges": [8, 16, 32, 64, 128, 256], 53 | "number_of_graph_vertex_labels": [4, 8, 16], 54 | "number_of_graph_edge_labels": [4, 8, 16], 55 | 56 | "max_ratio_of_edges_vertices": 4, 57 | "max_pattern_counts": 1024, 58 | 59 | "save_data_dir": r"/data/xliucr/SubIsoCnt/small", 60 | "num_workers": 16 61 | } 62 | 63 | LARGE_CONFIG = { 64 | "max_subgraph": 512, 65 | 66 | "alphas": [0.05, 0.1, 0.15], 67 | 68 | "number_of_patterns": 2, 69 | "number_of_pattern_vertices": [3, 4, 8, 16], 70 | "number_of_pattern_edges": [2, 4, 8, 16], 71 | "number_of_pattern_vertex_labels": [2, 4, 8, 16], 72 | "number_of_pattern_edge_labels": [2, 4, 8, 16], 73 | 74 | "number_of_graphs": 10, # train:dev:test = 8:1:1 75 | "number_of_graph_vertices": [64, 128, 256, 512], 76 | "number_of_graph_edges": [64, 128, 256, 512, 1024, 2048], 77 | "number_of_graph_vertex_labels": [16, 32, 64], 78 | "number_of_graph_edge_labels": [16, 32, 64], 79 | 80 | "max_ratio_of_edges_vertices": 4, 81 | "max_pattern_counts": 4096, 82 | 83 | "save_data_dir": r"/data/xliucr/SubIsoCnt/large", 84 | "num_workers": 16 85 | } 86 | 87 | CONFIG = DEBUG_CONFIG 88 | 89 | def generate_graphs(graph_generator, number_of_graph_vertices, number_of_graph_edges, number_of_graph_vertex_labels, number_of_graph_edge_labels, 90 | alpha, max_pattern_counts, max_subgraph, number_of_graphs, save_graph_dir, save_metadata_dir): 91 | graphs_id = "G_N%d_E%d_NL%d_EL%d_A%.2f" % ( 92 | number_of_graph_vertices, number_of_graph_edges, number_of_graph_vertex_labels, number_of_graph_edge_labels, alpha) 93 | # print(graphs_id) 94 | for g in range(number_of_graphs): 95 | graph, metadata = graph_generator.generate( 96 | number_of_graph_vertices, number_of_graph_edges, number_of_graph_vertex_labels, number_of_graph_edge_labels, 97 | alpha, max_pattern_counts=max_pattern_counts, max_subgraph=max_subgraph, return_subisomorphisms=True) 98 | graph.write(os.path.join(save_graph_dir, graphs_id + "_%d.gml" % (g))) 99 | with 
open(os.path.join(save_metadata_dir, graphs_id + "_%d.meta" % (g)), "w") as f: 100 | json.dump(metadata, f) 101 | return graphs_id 102 | 103 | if __name__ == "__main__": 104 | save_pattern_dir = os.path.join(CONFIG["save_data_dir"], "patterns") 105 | save_graph_dir = os.path.join(CONFIG["save_data_dir"], "graphs") 106 | save_metadata_dir = os.path.join(CONFIG["save_data_dir"], "metadata") 107 | os.makedirs(CONFIG["save_data_dir"], exist_ok=True) 108 | os.makedirs(save_pattern_dir, exist_ok=True) 109 | os.makedirs(save_graph_dir, exist_ok=True) 110 | os.makedirs(save_metadata_dir, exist_ok=True) 111 | 112 | np.random.seed(0) 113 | 114 | pattern_cnt = 0 115 | for number_of_pattern_vertices in CONFIG["number_of_pattern_vertices"]: 116 | for number_of_pattern_vertex_labels in CONFIG["number_of_pattern_vertex_labels"]: 117 | if number_of_pattern_vertex_labels > number_of_pattern_vertices: 118 | continue 119 | for number_of_pattern_edges in CONFIG["number_of_pattern_edges"]: 120 | if number_of_pattern_edges < number_of_pattern_vertices - 1: # not connected 121 | continue 122 | if number_of_pattern_edges > CONFIG["max_ratio_of_edges_vertices"] * number_of_pattern_vertices: # too dense 123 | continue 124 | for number_of_pattern_edge_labels in CONFIG["number_of_pattern_edge_labels"]: 125 | if number_of_pattern_edge_labels > number_of_pattern_edges: 126 | continue 127 | patterns_id = "P_N%d_E%d_NL%d_EL%d" % ( 128 | number_of_pattern_vertices, number_of_pattern_edges, number_of_pattern_vertex_labels, number_of_pattern_edge_labels) 129 | for p, pattern in enumerate(generate_patterns( 130 | number_of_pattern_vertices, number_of_pattern_edges, number_of_pattern_vertex_labels, number_of_pattern_edge_labels, 131 | CONFIG["number_of_patterns"])): 132 | pattern.write(os.path.join(save_pattern_dir, patterns_id + "_%d.gml" % (p))) 133 | pattern_cnt += CONFIG["number_of_patterns"] 134 | print("patterns_id", patterns_id) 135 | print("%d patterns generation finished!" 
% (pattern_cnt)) 136 | 137 | graph_cnt = 0 138 | pool = Pool(CONFIG["num_workers"]) 139 | results = list() 140 | for number_of_pattern_vertices in CONFIG["number_of_pattern_vertices"]: 141 | for number_of_pattern_vertex_labels in CONFIG["number_of_pattern_vertex_labels"]: 142 | if number_of_pattern_vertex_labels > number_of_pattern_vertices: 143 | continue 144 | for number_of_pattern_edges in CONFIG["number_of_pattern_edges"]: 145 | if number_of_pattern_edges < number_of_pattern_vertices - 1: # not connected 146 | continue 147 | if number_of_pattern_edges > CONFIG["max_ratio_of_edges_vertices"] * number_of_pattern_vertices: # too dense 148 | continue 149 | for number_of_pattern_edge_labels in CONFIG["number_of_pattern_edge_labels"]: 150 | if number_of_pattern_edge_labels > number_of_pattern_edges: 151 | continue 152 | patterns_id = "P_N%d_E%d_NL%d_EL%d" % ( 153 | number_of_pattern_vertices, number_of_pattern_edges, number_of_pattern_vertex_labels, number_of_pattern_edge_labels) 154 | graph_generators = list() 155 | for p in range(CONFIG["number_of_patterns"]): 156 | pattern = ig.read(os.path.join(save_pattern_dir, patterns_id + "_%d.gml" % (p))) 157 | pattern.vs["label"] = [int(x) for x in pattern.vs["label"]] 158 | pattern.es["label"] = [int(x) for x in pattern.es["label"]] 159 | pattern.es["key"] = [int(x) for x in pattern.es["key"]] 160 | graph_generators.append(GraphGenerator(pattern)) 161 | for alpha in CONFIG["alphas"]: 162 | for number_of_graph_vertices in CONFIG["number_of_graph_vertices"]: 163 | if number_of_graph_vertices < number_of_pattern_vertices: 164 | continue 165 | for number_of_graph_vertex_labels in CONFIG["number_of_graph_vertex_labels"]: 166 | if number_of_graph_vertex_labels > number_of_graph_vertices: 167 | continue 168 | if number_of_graph_vertex_labels < number_of_pattern_vertex_labels: 169 | continue 170 | for number_of_graph_edges in CONFIG["number_of_graph_edges"]: 171 | if number_of_graph_edges < number_of_graph_vertices - 1: # not 
connected 172 | continue 173 | if number_of_graph_edges > CONFIG["max_ratio_of_edges_vertices"] * number_of_graph_vertices: # too dense 174 | continue 175 | if number_of_graph_edges < number_of_pattern_edges: 176 | continue 177 | for number_of_graph_edge_labels in CONFIG["number_of_graph_edge_labels"]: 178 | if number_of_graph_edge_labels > number_of_graph_edges: 179 | continue 180 | if number_of_graph_edge_labels < number_of_pattern_edge_labels: 181 | continue 182 | for p, graph_generator in enumerate(graph_generators): 183 | save_graph_dir_p = os.path.join(save_graph_dir, patterns_id + "_%d" % (p)) 184 | save_metadata_dir_p = os.path.join(save_metadata_dir, patterns_id + "_%d" % (p)) 185 | if not os.path.isdir(save_graph_dir_p): 186 | os.mkdir(save_graph_dir_p) 187 | if not os.path.isdir(save_metadata_dir_p): 188 | os.mkdir(save_metadata_dir_p) 189 | results.append( 190 | pool.apply_async(generate_graphs, args=( 191 | graph_generator, number_of_graph_vertices, number_of_graph_edges, 192 | number_of_graph_vertex_labels, number_of_graph_edge_labels, 193 | alpha, CONFIG["max_pattern_counts"], CONFIG["max_subgraph"], 194 | CONFIG["number_of_graphs"], save_graph_dir_p, save_metadata_dir_p))) 195 | # generate_graphs( 196 | # graph_generator, number_of_graph_vertices, number_of_graph_edges, 197 | # number_of_graph_vertex_labels, number_of_graph_edge_labels, 198 | # alpha, CONFIG["max_pattern_counts"], CONFIG["max_subgraph"], 199 | # CONFIG["number_of_graphs"], save_graph_dir_p, save_metadata_dir_p) 200 | graph_cnt += CONFIG["number_of_graphs"] 201 | pool.close() 202 | # pool.join() 203 | for x in tqdm(results): 204 | x.get() 205 | print("%d graphs generation finished!" 
% (graph_cnt)) 206 | -------------------------------------------------------------------------------- /generator/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import igraph as ig 3 | import json 4 | from itertools import chain, combinations 5 | import os  # required by generate_png 6 | def generate_png(dot_filename, png_filename=None, prog="neato"): 7 | if png_filename is None: 8 | png_filename = dot_filename.replace(".dot", ".png") 9 | os.system("%s.exe -T png %s > %s" % (prog, dot_filename, png_filename)) 10 | 11 | def generate_labels(number_of_items, number_of_labels): 12 | labels = list(range(number_of_labels)) 13 | if number_of_items < number_of_labels: 14 | np.random.shuffle(labels) 15 | labels = labels[:number_of_items] 16 | else: 17 | for i in range(number_of_labels, number_of_items): 18 | labels.append(np.random.randint(number_of_labels)) 19 | np.random.shuffle(labels) 20 | return labels 21 | 22 | def generate_tree(number_of_vertices, directed=True): 23 | # Alexey S. Rodionov and Hyunseung Choo, On Generating Random Network Structures: Trees, ICCS 2003, LNCS 2658, pp. 879-887, 2003.
24 | # [connected vertices] + [unconnected vertices] 25 | shuffle_vertices = list(range(number_of_vertices)) 26 | np.random.shuffle(shuffle_vertices) 27 | # randomly choose one vertex from the connected vertex set 28 | # randomly choose one vertex from the unconnected vertex set 29 | # connect them by one edge 30 | # add the latter vertex in the connected vertex set 31 | edges = list() 32 | for v in range(1, number_of_vertices): 33 | u = shuffle_vertices[np.random.randint(0, v)] 34 | v = shuffle_vertices[v] 35 | if get_direction(): 36 | src_tgt = (u, v) 37 | else: 38 | src_tgt = (v, u) 39 | edges.append(src_tgt) 40 | tree = ig.Graph(directed=directed) 41 | tree.add_vertices(number_of_vertices) 42 | tree.add_edges(edges) 43 | return tree 44 | 45 | def get_direction(): 46 | return np.random.randint(0, 2) 47 | 48 | def retrieve_multiple_edges(graph, source=-1, target=-1): 49 | if source != -1: 50 | e = graph.incident(source, mode=ig.OUT) 51 | if target != -1: 52 | e = set(e).intersection(graph.incident(target, mode=ig.IN)) 53 | return ig.EdgeSeq(graph, e) 54 | else: 55 | if target != -1: 56 | e = graph.incident(target, mode=ig.IN) 57 | else: 58 | e = list() 59 | return ig.EdgeSeq(graph, e) 60 | 61 | def str2bool(x): 62 | x = x.lower() 63 | return x == "true" or x == "yes" or x == "t" 64 | 65 | def sample_element(s): 66 | index = np.random.randint(0, len(s)) 67 | return s[index] 68 | 69 | def powerset(iterable, min_size=0, max_size=-1): 70 | "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" 71 | s = sorted(iterable) 72 | if max_size == -1: 73 | max_size = len(s) 74 | return chain.from_iterable(combinations(s, r) for r in range(min_size, max_size+1)) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | pandas 4 | scipy 5 | tensorboardX 6 | torch==1.4.0 7 | dgl==0.4.3post2 8 | 9 | 
-------------------------------------------------------------------------------- /src/basemodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import copy 6 | import numpy as np 7 | from utils import int2onehot 8 | from utils import get_enc_len, split_and_batchify_graph_feats, batch_convert_len_to_mask 9 | from embedding import OrthogonalEmbedding, NormalEmbedding, EquivariantEmbedding 10 | from filternet import MaxGatedFilterNet 11 | from predictnet import MeanPredictNet, SumPredictNet, MaxPredictNet, \ 12 | MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, \ 13 | MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, \ 14 | DIAMNet 15 | 16 | class BaseModel(nn.Module): 17 | def __init__(self, config): 18 | super(BaseModel, self).__init__() 19 | 20 | self.act_func = config["activation_function"] 21 | self.init_emb = config["init_emb"] 22 | self.share_emb = config["share_emb"] 23 | self.share_arch = config["share_arch"] 24 | self.base = config["base"] 25 | self.max_ngv = config["max_ngv"] 26 | self.max_ngvl = config["max_ngvl"] 27 | self.max_nge = config["max_nge"] 28 | self.max_ngel = config["max_ngel"] 29 | self.max_npv = config["max_npv"] 30 | self.max_npvl = config["max_npvl"] 31 | self.max_npe = config["max_npe"] 32 | self.max_npel = config["max_npel"] 33 | 34 | self.emb_dim = config["emb_dim"] 35 | self.dropout = config["dropout"] 36 | self.dropatt = config["dropatt"] 37 | self.add_enc = config["predict_net_add_enc"] 38 | 39 | # create encoding layer 40 | # create filter layers 41 | # create embedding layers 42 | # create networks 43 | self.p_net, self.g_net = None, None 44 | # create predict layers 45 | self.predict_net = None 46 | 47 | def get_emb_dim(self): 48 | if self.init_emb == "None": 49 | return self.get_enc_dim() 50 | else: 51 | return self.emb_dim, self.emb_dim 52 | 53 | def get_enc(self, 
pattern, pattern_len, graph, graph_len): 54 | raise NotImplementedError 55 | 56 | def get_emb(self, pattern, pattern_len, graph, graph_len): 57 | raise NotImplementedError 58 | 59 | def get_filter_gate(self, pattern, pattern_len, graph, graph_len): 60 | raise NotImplementedError 61 | 62 | def create_filter(self, filter_type): 63 | if filter_type == "None": 64 | filter_net = None 65 | elif filter_type == "MaxGatedFilterNet": 66 | filter_net = MaxGatedFilterNet() 67 | else: 68 | raise NotImplementedError("Currently, %s is not supported!" % (filter_type)) 69 | return filter_net 70 | 71 | def create_enc(self, max_n, base): 72 | enc_len = get_enc_len(max_n-1, base) 73 | enc_dim = enc_len * base 74 | enc = nn.Embedding(max_n, enc_dim) 75 | enc.weight.data.copy_(torch.from_numpy(int2onehot(np.arange(0, max_n), enc_len, base))) 76 | enc.weight.requires_grad=False 77 | return enc 78 | 79 | def create_emb(self, input_dim, emb_dim, init_emb="Orthogonal"): 80 | if init_emb == "None": 81 | emb = None 82 | elif init_emb == "Orthogonal": 83 | emb = OrthogonalEmbedding(input_dim, emb_dim) 84 | elif init_emb == "Normal": 85 | emb = NormalEmbedding(input_dim, emb_dim) 86 | elif init_emb == "Equivariant": 87 | emb = EquivariantEmbedding(input_dim, emb_dim) 88 | else: 89 | raise NotImplementedError 90 | return emb 91 | 92 | def create_net(self, name, input_dim, **kw): 93 | raise NotImplementedError 94 | 95 | def create_predict_net(self, predict_type, pattern_dim, graph_dim, **kw): 96 | if predict_type == "None": 97 | predict_net = None 98 | elif predict_type == "MeanPredictNet": 99 | hidden_dim = kw.get("hidden_dim", 64) 100 | predict_net = MeanPredictNet(pattern_dim, graph_dim, hidden_dim, 101 | act_func=self.act_func, dropout=self.dropout) 102 | elif predict_type == "SumPredictNet": 103 | hidden_dim = kw.get("hidden_dim", 64) 104 | predict_net = SumPredictNet(pattern_dim, graph_dim, hidden_dim, 105 | act_func=self.act_func, dropout=self.dropout) 106 | elif predict_type == 
"MaxPredictNet": 107 | hidden_dim = kw.get("hidden_dim", 64) 108 | predict_net = MaxPredictNet(pattern_dim, graph_dim, hidden_dim, 109 | act_func=self.act_func, dropout=self.dropout) 110 | elif predict_type == "MeanAttnPredictNet": 111 | hidden_dim = kw.get("hidden_dim", 64) 112 | recurrent_steps = kw.get("recurrent_steps", 1) 113 | num_heads = kw.get("num_heads", 1) 114 | predict_net = MeanAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 115 | act_func=self.act_func, 116 | num_heads=num_heads, recurrent_steps=recurrent_steps, 117 | dropout=self.dropout, dropatt=self.dropatt) 118 | elif predict_type == "SumAttnPredictNet": 119 | hidden_dim = kw.get("hidden_dim", 64) 120 | recurrent_steps = kw.get("recurrent_steps", 1) 121 | num_heads = kw.get("num_heads", 1) 122 | predict_net = SumAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 123 | act_func=self.act_func, 124 | num_heads=num_heads, recurrent_steps=recurrent_steps, 125 | dropout=self.dropout, dropatt=self.dropatt) 126 | elif predict_type == "MaxAttnPredictNet": 127 | hidden_dim = kw.get("hidden_dim", 64) 128 | recurrent_steps = kw.get("recurrent_steps", 1) 129 | num_heads = kw.get("num_heads", 1) 130 | predict_net = MaxAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 131 | act_func=self.act_func, 132 | num_heads=num_heads, recurrent_steps=recurrent_steps, 133 | dropout=self.dropout, dropatt=self.dropatt) 134 | elif predict_type == "MeanMemAttnPredictNet": 135 | hidden_dim = kw.get("hidden_dim", 64) 136 | recurrent_steps = kw.get("recurrent_steps", 1) 137 | num_heads = kw.get("num_heads", 1) 138 | mem_len = kw.get("mem_len", 4) 139 | predict_net = MeanMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 140 | act_func=self.act_func, 141 | num_heads=num_heads, recurrent_steps=recurrent_steps, 142 | mem_len=mem_len, 143 | dropout=self.dropout, dropatt=self.dropatt) 144 | elif predict_type == "SumMemAttnPredictNet": 145 | hidden_dim = kw.get("hidden_dim", 64) 146 | recurrent_steps = kw.get("recurrent_steps", 
1) 147 | num_heads = kw.get("num_heads", 1) 148 | mem_len = kw.get("mem_len", 4) 149 | predict_net = SumMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 150 | act_func=self.act_func, 151 | num_heads=num_heads, recurrent_steps=recurrent_steps, 152 | mem_len=mem_len, 153 | dropout=self.dropout, dropatt=self.dropatt) 154 | elif predict_type == "MaxMemAttnPredictNet": 155 | hidden_dim = kw.get("hidden_dim", 64) 156 | recurrent_steps = kw.get("recurrent_steps", 1) 157 | num_heads = kw.get("num_heads", 1) 158 | mem_len = kw.get("mem_len", 4) 159 | predict_net = MaxMemAttnPredictNet(pattern_dim, graph_dim, hidden_dim, 160 | act_func=self.act_func, 161 | num_heads=num_heads, recurrent_steps=recurrent_steps, 162 | mem_len=mem_len, 163 | dropout=self.dropout, dropatt=self.dropatt) 164 | elif predict_type == "DIAMNet": 165 | hidden_dim = kw.get("hidden_dim", 64) 166 | recurrent_steps = kw.get("recurrent_steps", 1) 167 | num_heads = kw.get("num_heads", 1) 168 | mem_len = kw.get("mem_len", 4) 169 | mem_init = kw.get("mem_init", "mean") 170 | predict_net = DIAMNet(pattern_dim, graph_dim, hidden_dim, 171 | act_func=self.act_func, 172 | num_heads=num_heads, recurrent_steps=recurrent_steps, 173 | mem_len=mem_len, mem_init=mem_init, 174 | dropout=self.dropout, dropatt=self.dropatt) 175 | else: 176 | raise NotImplementedError("Currently, %s is not supported!" 
% (predict_type)) 177 | return predict_net 178 | 179 | def increase_input_size(self, config): 180 | assert config["base"] == self.base 181 | assert config["max_npv"] >= self.max_npv 182 | assert config["max_npvl"] >= self.max_npvl 183 | assert config["max_npe"] >= self.max_npe 184 | assert config["max_npel"] >= self.max_npel 185 | assert config["max_ngv"] >= self.max_ngv 186 | assert config["max_ngvl"] >= self.max_ngvl 187 | assert config["max_nge"] >= self.max_nge 188 | assert config["max_ngel"] >= self.max_ngel 189 | assert config["predict_net_add_enc"] or not self.add_enc 190 | assert config["predict_net_add_degree"] or not self.add_degree 191 | 192 | # create encoding layers 193 | # increase embedding layers 194 | # increase predict network 195 | # set new parameters 196 | 197 | def increase_net(self, config): 198 | raise NotImplementedError 199 | 200 | 201 | class EdgeSeqModel(BaseModel): 202 | def __init__(self, config): 203 | super(EdgeSeqModel, self).__init__(config) 204 | # create encoding layer 205 | self.g_v_enc, self.g_vl_enc, self.g_el_enc = \ 206 | [self.create_enc(max_n, self.base) for max_n in [self.max_ngv, self.max_ngvl, self.max_ngel]] 207 | self.g_u_enc, self.g_ul_enc = self.g_v_enc, self.g_vl_enc 208 | if self.share_emb: 209 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = \ 210 | self.g_v_enc, self.g_vl_enc, self.g_el_enc 211 | else: 212 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = \ 213 | [self.create_enc(max_n, self.base) for max_n in [self.max_npv, self.max_npvl, self.max_npel]] 214 | self.p_u_enc, self.p_ul_enc = self.p_v_enc, self.p_vl_enc 215 | 216 | # create filter layers 217 | self.ul_flt, self.el_flt, self.vl_flt = [self.create_filter(config["filter_net"]) for _ in range(3)] 218 | 219 | # create embedding layers 220 | self.g_u_emb, self.g_v_emb, self.g_ul_emb, self.g_el_emb, self.g_vl_emb = \ 221 | [self.create_emb(enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) \ 222 | for enc in [self.g_u_enc, self.g_v_enc, self.g_ul_enc, 
self.g_el_enc, self.g_vl_enc]] 223 | if self.share_emb: 224 | self.p_u_emb, self.p_v_emb, self.p_ul_emb, self.p_el_emb, self.p_vl_emb = \ 225 | self.g_u_emb, self.g_v_emb, self.g_ul_emb, self.g_el_emb, self.g_vl_emb 226 | else: 227 | self.p_u_emb, self.p_v_emb, self.p_ul_emb, self.p_el_emb, self.p_vl_emb = \ 228 | [self.create_emb(enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) \ 229 | for enc in [self.p_u_enc, self.p_v_enc, self.p_ul_enc, self.p_el_enc, self.p_vl_enc]] 230 | 231 | # create networks 232 | # create predict layers 233 | 234 | def get_enc_dim(self): 235 | g_dim = self.base * (get_enc_len(self.max_ngv-1, self.base) * 2 + \ 236 | get_enc_len(self.max_ngvl-1, self.base) * 2 + \ 237 | get_enc_len(self.max_ngel-1, self.base)) 238 | if self.share_emb: 239 | return g_dim, g_dim 240 | else: 241 | p_dim = self.base * (get_enc_len(self.max_npv-1, self.base) * 2 + \ 242 | get_enc_len(self.max_npvl-1, self.base) * 2 + \ 243 | get_enc_len(self.max_npel-1, self.base)) 244 | return p_dim, g_dim 245 | 246 | def get_emb_dim(self): 247 | if self.init_emb == "None": 248 | return self.get_enc_dim() 249 | else: 250 | return self.emb_dim, self.emb_dim 251 | 252 | def get_enc(self, pattern, pattern_len, graph, graph_len): 253 | pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl = \ 254 | self.p_u_enc(pattern.u), self.p_v_enc(pattern.v), self.p_ul_enc(pattern.ul), self.p_el_enc(pattern.el), self.p_vl_enc(pattern.vl) 255 | graph_u, graph_v, graph_ul, graph_el, graph_vl = \ 256 | self.g_u_enc(graph.u), self.g_v_enc(graph.v), self.g_ul_enc(graph.ul), self.g_el_enc(graph.el), self.g_vl_enc(graph.vl) 257 | 258 | p_enc = torch.cat([ 259 | pattern_u, 260 | pattern_v, 261 | pattern_ul, 262 | pattern_el, 263 | pattern_vl], dim=2) 264 | g_enc = torch.cat([ 265 | graph_u, 266 | graph_v, 267 | graph_ul, 268 | graph_el, 269 | graph_vl], dim=2) 270 | return p_enc, g_enc 271 | 272 | def get_emb(self, pattern, pattern_len, graph, graph_len): 273 | bsz = pattern_len.size(0) 
274 | pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl = \ 275 | self.p_u_enc(pattern.u), self.p_v_enc(pattern.v), self.p_ul_enc(pattern.ul), self.p_el_enc(pattern.el), self.p_vl_enc(pattern.vl) 276 | graph_u, graph_v, graph_ul, graph_el, graph_vl = \ 277 | self.g_u_enc(graph.u), self.g_v_enc(graph.v), self.g_ul_enc(graph.ul), self.g_el_enc(graph.el), self.g_vl_enc(graph.vl) 278 | 279 | if self.init_emb == "None": 280 | p_emb = torch.cat([pattern_u, pattern_v, pattern_ul, pattern_el, pattern_vl], dim=2) 281 | g_emb = torch.cat([graph_u, graph_v, graph_ul, graph_el, graph_vl], dim=2) 282 | else: 283 | p_emb = self.p_u_emb(pattern_u) + \ 284 | self.p_v_emb(pattern_v) + \ 285 | self.p_ul_emb(pattern_ul) + \ 286 | self.p_el_emb(pattern_el) + \ 287 | self.p_vl_emb(pattern_vl) 288 | g_emb = self.g_u_emb(graph_u) + \ 289 | self.g_v_emb(graph_v) + \ 290 | self.g_ul_emb(graph_ul) + \ 291 | self.g_el_emb(graph_el) + \ 292 | self.g_vl_emb(graph_vl) 293 | return p_emb, g_emb 294 | 295 | def get_filter_gate(self, pattern, pattern_len, graph, graph_len): 296 | gate = None 297 | if self.ul_flt is not None: 298 | if gate is not None: 299 | gate &= self.ul_flt(pattern.ul, graph.ul) 300 | else: 301 | gate = self.ul_flt(pattern.ul, graph.ul) 302 | if self.el_flt is not None: 303 | if gate is not None: 304 | gate &= self.el_flt(pattern.el, graph.el) 305 | else: 306 | gate = self.el_flt(pattern.el, graph.el) 307 | if self.vl_flt is not None: 308 | if gate is not None: 309 | gate &= self.vl_flt(pattern.vl, graph.vl) 310 | else: 311 | gate = self.vl_flt(pattern.vl, graph.vl) 312 | return gate 313 | 314 | def increase_input_size(self, config): 315 | super(EdgeSeqModel, self).increase_input_size(config) 316 | 317 | # create encoding layers 318 | new_g_v_enc, new_g_vl_enc, new_g_el_enc = \ 319 | [self.create_enc(max_n, self.base) for max_n in [config["max_ngv"], config["max_ngvl"], config["max_ngel"]]] 320 | if self.share_emb: 321 | new_p_v_enc, new_p_vl_enc, new_p_el_enc = \ 322 | 
new_g_v_enc, new_g_vl_enc, new_g_el_enc 323 | else: 324 | new_p_v_enc, new_p_vl_enc, new_p_el_enc = \ 325 | [self.create_enc(max_n, self.base) for max_n in [config["max_npv"], config["max_npvl"], config["max_npel"]]] 326 | del self.g_v_enc, self.g_vl_enc, self.g_el_enc, self.g_u_enc, self.g_ul_enc 327 | del self.p_v_enc, self.p_vl_enc, self.p_el_enc, self.p_u_enc, self.p_ul_enc 328 | self.g_v_enc, self.g_vl_enc, self.g_el_enc = new_g_v_enc, new_g_vl_enc, new_g_el_enc 329 | self.g_u_enc, self.g_ul_enc = self.g_v_enc, self.g_vl_enc 330 | self.p_v_enc, self.p_vl_enc, self.p_el_enc = new_p_v_enc, new_p_vl_enc, new_p_el_enc 331 | self.p_u_enc, self.p_ul_enc = self.p_v_enc, self.p_vl_enc 332 | 333 | # increase embedding layers 334 | self.g_u_emb.increase_input_size(self.g_u_enc.embedding_dim) 335 | self.g_v_emb.increase_input_size(self.g_v_enc.embedding_dim) 336 | self.g_ul_emb.increase_input_size(self.g_ul_enc.embedding_dim) 337 | self.g_vl_emb.increase_input_size(self.g_vl_enc.embedding_dim) 338 | self.g_el_emb.increase_input_size(self.g_el_enc.embedding_dim) 339 | if not self.share_emb: 340 | self.p_u_emb.increase_input_size(self.p_u_enc.embedding_dim) 341 | self.p_v_emb.increase_input_size(self.p_v_enc.embedding_dim) 342 | self.p_ul_emb.increase_input_size(self.p_ul_enc.embedding_dim) 343 | self.p_vl_emb.increase_input_size(self.p_vl_enc.embedding_dim) 344 | self.p_el_emb.increase_input_size(self.p_el_enc.embedding_dim) 345 | 346 | # increase predict network 347 | 348 | # set new parameters 349 | self.max_npv = config["max_npv"] 350 | self.max_npvl = config["max_npvl"] 351 | self.max_npe = config["max_npe"] 352 | self.max_npel = config["max_npel"] 353 | self.max_ngv = config["max_ngv"] 354 | self.max_ngvl = config["max_ngvl"] 355 | self.max_nge = config["max_nge"] 356 | self.max_ngel = config["max_ngel"] 357 | 358 | 359 | 360 | class GraphAdjModel(BaseModel): 361 | def __init__(self, config): 362 | super(GraphAdjModel, self).__init__(config) 363 | 364 | 
self.add_degree = config["predict_net_add_degree"] 365 | 366 | # create encoding layer 367 | self.g_v_enc, self.g_vl_enc = \ 368 | [self.create_enc(max_n, self.base) for max_n in [self.max_ngv, self.max_ngvl]] 369 | if self.share_emb: 370 | self.p_v_enc, self.p_vl_enc = \ 371 | self.g_v_enc, self.g_vl_enc 372 | else: 373 | self.p_v_enc, self.p_vl_enc = \ 374 | [self.create_enc(max_n, self.base) for max_n in [self.max_npv, self.max_npvl]] 375 | 376 | # create filter layers 377 | self.vl_flt = self.create_filter(config["filter_net"]) 378 | 379 | # create embedding layers 380 | self.g_vl_emb = self.create_emb(self.g_vl_enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) 381 | if self.share_emb: 382 | self.p_vl_emb = self.g_vl_emb 383 | else: 384 | self.p_vl_emb = self.create_emb(self.p_vl_enc.embedding_dim, self.emb_dim, init_emb=self.init_emb) 385 | 386 | # create networks 387 | # create predict layers 388 | 389 | def get_enc_dim(self): 390 | g_dim = self.base * (get_enc_len(self.max_ngv-1, self.base) + get_enc_len(self.max_ngvl-1, self.base)) 391 | if self.share_emb: 392 | return g_dim, g_dim 393 | else: 394 | p_dim = self.base * (get_enc_len(self.max_npv-1, self.base) + get_enc_len(self.max_npvl-1, self.base)) 395 | return p_dim, g_dim 396 | 397 | def get_enc(self, pattern, pattern_len, graph, graph_len): 398 | bsz = pattern_len.size(0) 399 | 400 | pattern_v, pattern_vl = self.p_v_enc(pattern.ndata["id"]), self.p_vl_enc(pattern.ndata["label"]) 401 | graph_v, graph_vl = self.g_v_enc(graph.ndata["id"]), self.g_vl_enc(graph.ndata["label"]) 402 | 403 | p_enc = torch.cat([pattern_v, pattern_vl], dim=1) 404 | g_enc = torch.cat([graph_v, graph_vl], dim=1) 405 | return p_enc, g_enc 406 | 407 | def get_emb(self, pattern, pattern_len, graph, graph_len): 408 | bsz = pattern_len.size(0) 409 | 410 | pattern_v, pattern_vl = self.p_v_enc(pattern.ndata["id"]), self.p_vl_enc(pattern.ndata["label"]) 411 | graph_v, graph_vl = self.g_v_enc(graph.ndata["id"]), 
self.g_vl_enc(graph.ndata["label"]) 412 | 413 | if self.init_emb == "None": 414 | p_emb = pattern_vl 415 | g_emb = graph_vl 416 | else: 417 | p_emb = self.p_vl_emb(pattern_vl) 418 | g_emb = self.g_vl_emb(graph_vl) 419 | return p_emb, g_emb 420 | 421 | def get_filter_gate(self, pattern, pattern_len, graph, graph_len): 422 | gate = None 423 | if self.vl_flt is not None: 424 | gate = self.vl_flt( 425 | split_and_batchify_graph_feats(pattern.ndata["label"].unsqueeze(-1), pattern_len)[0], 426 | split_and_batchify_graph_feats(graph.ndata["label"].unsqueeze(-1), graph_len)[0]) 427 | if gate is not None: 428 | bsz = graph_len.size(0) 429 | max_g_len = graph_len.max() 430 | if bsz * max_g_len != graph.number_of_nodes(): 431 | graph_mask = batch_convert_len_to_mask(graph_len) # bsz x max_len 432 | gate = gate.masked_select(graph_mask.unsqueeze(-1)).view(-1, 1) 433 | else: 434 | gate = gate.view(-1, 1) 435 | return gate 436 | 437 | def increase_input_size(self, config): 438 | super(GraphAdjModel, self).increase_input_size(config) 439 | 440 | # create encoding layers 441 | new_g_v_enc, new_g_vl_enc = \ 442 | [self.create_enc(max_n, self.base) for max_n in [config["max_ngv"], config["max_ngvl"]]] 443 | if self.share_emb: 444 | new_p_v_enc, new_p_vl_enc = \ 445 | new_g_v_enc, new_g_vl_enc 446 | else: 447 | new_p_v_enc, new_p_vl_enc = \ 448 | [self.create_enc(max_n, self.base) for max_n in [config["max_npv"], config["max_npvl"]]] 449 | del self.g_v_enc, self.g_vl_enc 450 | del self.p_v_enc, self.p_vl_enc 451 | self.g_v_enc, self.g_vl_enc = new_g_v_enc, new_g_vl_enc 452 | self.p_v_enc, self.p_vl_enc = new_p_v_enc, new_p_vl_enc 453 | 454 | # increase embedding layers 455 | self.g_vl_emb.increase_input_size(self.g_vl_enc.embedding_dim) 456 | if not self.share_emb: 457 | self.p_vl_emb.increase_input_size(self.p_vl_enc.embedding_dim) 458 | 459 | # increase networks 460 | 461 | # increase predict network 462 | 463 | # set new parameters 464 | self.max_npv = config["max_npv"] 465 | 
self.max_npvl = config["max_npvl"] 466 | self.max_npe = config["max_npe"] 467 | self.max_npel = config["max_npel"] 468 | self.max_ngv = config["max_ngv"] 469 | self.max_ngvl = config["max_ngvl"] 470 | self.max_nge = config["max_nge"] 471 | self.max_ngel = config["max_ngel"] 472 | -------------------------------------------------------------------------------- /src/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from collections import OrderedDict 6 | from basemodel import EdgeSeqModel 7 | from utils import map_activation_str_to_layer 8 | 9 | class CNN(EdgeSeqModel): 10 | def __init__(self, config): 11 | super(CNN, self).__init__(config) 12 | 13 | if len(config["cnn_conv_kernel_sizes"]) != len(config["cnn_pool_kernel_sizes"]): 14 | raise ValueError("Error: the size of cnn_conv_kernel_sizes is not equal to that of cnn_pool_kernel_sizes.") 15 | if len(config["cnn_conv_strides"]) != len(config["cnn_pool_strides"]): 16 | raise ValueError("Error: the size of cnn_conv_strides is not equal to that of cnn_pool_strides.") 17 | if len(config["cnn_conv_kernel_sizes"]) != len(config["cnn_conv_strides"]): 18 | raise ValueError("Error: the size of cnn_conv_kernel_sizes is not equal to that of cnn_conv_strides.") 19 | 20 | # create networks 21 | p_emb_dim, g_emb_dim = self.get_emb_dim() 22 | self.g_net, g_dim = self.create_net( 23 | name="graph", input_dim=g_emb_dim, 24 | conv_channels=config["cnn_conv_channels"], 25 | conv_kernel_sizes=config["cnn_conv_kernel_sizes"], conv_paddings=config["cnn_conv_paddings"], 26 | conv_strides=config["cnn_conv_strides"], 27 | pool_kernel_sizes=config["cnn_pool_kernel_sizes"], pool_paddings=config["cnn_pool_paddings"], 28 | pool_strides=config["cnn_pool_strides"], 29 | act_func=self.act_func, dropout=self.dropout) 30 | self.p_net, p_dim = (self.g_net, g_dim) if self.share_arch else self.create_net( 31 | 
name="pattern", input_dim=p_emb_dim, 32 | conv_channels=config["cnn_conv_channels"], 33 | conv_kernel_sizes=config["cnn_conv_kernel_sizes"], conv_paddings=config["cnn_conv_paddings"], 34 | conv_strides=config["cnn_conv_strides"], 35 | pool_kernel_sizes=config["cnn_pool_kernel_sizes"], pool_paddings=config["cnn_pool_paddings"], 36 | pool_strides=config["cnn_pool_strides"], 37 | act_func=self.act_func, dropout=self.dropout) 38 | 39 | # create predict layers 40 | self.predict_net = self.create_predict_net(config["predict_net"], 41 | pattern_dim=p_dim, graph_dim=g_dim, hidden_dim=config["predict_net_hidden_dim"], 42 | num_heads=config["predict_net_num_heads"], recurrent_steps=config["predict_net_recurrent_steps"], 43 | mem_len=config["predict_net_mem_len"], mem_init=config["predict_net_mem_init"]) 44 | 45 | def create_net(self, name, input_dim, **kw): 46 | conv_kernel_sizes = kw.get("conv_kernel_sizes", (1,2,3)) 47 | conv_paddings = kw.get("conv_paddings", (-1,-1,-1)) 48 | conv_channels = kw.get("conv_channels", (64,64,64)) 49 | conv_strides = kw.get("conv_strides", (1,1,1)) 50 | pool_kernel_sizes = kw.get("pool_kernel_sizes", (2,3,4)) 51 | pool_strides = kw.get("pool_strides", (1,1,1)) 52 | pool_paddings = kw.get("pool_paddings", (-1,-1,-1)) 53 | act_func = kw.get("act_func", "relu") 54 | dropout = kw.get("dropout", 0.0) 55 | 56 | cnns = nn.ModuleList() 57 | for i, conv_kernel_size in enumerate(conv_kernel_sizes): 58 | conv_stride = conv_strides[i] 59 | conv_padding = conv_paddings[i] 60 | if conv_padding == -1: 61 | conv_padding = conv_kernel_size//2 62 | 63 | pool_kernel_size = pool_kernel_sizes[i] 64 | pool_padding = pool_paddings[i] 65 | pool_stride = pool_strides[i] 66 | if pool_padding == -1: 67 | pool_padding = pool_kernel_size//2 68 | 69 | cnn = nn.Sequential(OrderedDict([ 70 | ("conv", nn.Conv1d(conv_channels[i-1] if i > 0 else input_dim, conv_channels[i], 71 | kernel_size=conv_kernel_size, stride=conv_stride, padding=conv_padding)), 72 | ("act", 
map_activation_str_to_layer(act_func)), 73 | ("pool", nn.MaxPool1d( 74 | kernel_size=pool_kernel_size, stride=pool_stride, padding=pool_padding)), 75 | # ("norm", nn.BatchNorm1d(conv_channels[i])), 76 | ("drop", nn.Dropout(dropout))])) 77 | cnns.add_module("%s_cnn%d" % (name, i), cnn) 78 | num_features = conv_channels[i] 79 | 80 | # init 81 | for m in cnns.modules(): 82 | if isinstance(m, nn.Conv1d): 83 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity=act_func) 84 | nn.init.zeros_(m.bias) 85 | elif isinstance(m, nn.BatchNorm1d): 86 | nn.init.ones_(m.weight) 87 | nn.init.zeros_(m.bias) 88 | return cnns, num_features 89 | 90 | def increase_input_size(self, config): 91 | super(CNN, self).increase_input_size(config) 92 | 93 | def increase_net(self, config): 94 | p_emb_dim, g_emb_dim = self.get_emb_dim() 95 | g_net, g_dim = self.create_net( 96 | name="graph", input_dim=g_emb_dim, 97 | conv_channels=config["cnn_conv_channels"], 98 | conv_kernel_sizes=config["cnn_conv_kernel_sizes"], conv_paddings=config["cnn_conv_paddings"], 99 | conv_strides=config["cnn_conv_strides"], 100 | pool_kernel_sizes=config["cnn_pool_kernel_sizes"], pool_paddings=config["cnn_pool_paddings"], 101 | pool_strides=config["cnn_pool_strides"], 102 | act_func=self.act_func, dropout=self.dropout) 103 | assert len(g_net) >= len(self.g_net) 104 | with torch.no_grad(): 105 | for old_g_cnn, new_g_cnn in zip(self.g_net, g_net): 106 | new_g_cnn.load_state_dict(old_g_cnn.state_dict()) 107 | del self.g_net 108 | self.g_net = g_net 109 | 110 | if self.share_arch: 111 | self.p_net = self.g_net 112 | else: 113 | p_net, p_dim = self.create_net( 114 | name="pattern", input_dim=p_emb_dim, 115 | conv_channels=config["cnn_conv_channels"], 116 | conv_kernel_sizes=config["cnn_conv_kernel_sizes"], conv_paddings=config["cnn_conv_paddings"], 117 | conv_strides=config["cnn_conv_strides"], 118 | pool_kernel_sizes=config["cnn_pool_kernel_sizes"], pool_paddings=config["cnn_pool_paddings"], 119 | 
pool_strides=config["cnn_pool_strides"], 120 | act_func=self.act_func, dropout=self.dropout) 121 | assert len(p_net) >= len(self.p_net) 122 | with torch.no_grad(): 123 | for old_p_cnn, new_p_cnn in zip(self.p_net, p_net): 124 | new_p_cnn.load_state_dict(old_p_cnn.state_dict()) 125 | del self.p_net 126 | self.p_net = p_net 127 | 128 | def forward(self, pattern, pattern_len, graph, graph_len): 129 | bsz = pattern_len.size(0) 130 | 131 | gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len) 132 | zero_mask = (gate == 0).unsqueeze(-1) if gate is not None else None 133 | pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph, graph_len) 134 | if zero_mask is not None: 135 | graph_emb.masked_fill_(zero_mask, 0.0) 136 | 137 | pattern_output = pattern_emb.transpose(1, 2) 138 | for p_cnn in self.p_net: 139 | o = p_cnn(pattern_output) 140 | if o.size() == pattern_output.size(): 141 | pattern_output = o + pattern_output 142 | else: 143 | pattern_output = o 144 | pattern_output = pattern_output.transpose(1, 2) 145 | 146 | graph_output = graph_emb.transpose(1, 2) 147 | for g_cnn in self.g_net: 148 | o = g_cnn(graph_output) 149 | if o.size() == graph_output.size(): 150 | graph_output = o + graph_output 151 | else: 152 | graph_output = o 153 | graph_output = graph_output.transpose(1, 2) 154 | 155 | pred = self.predict_net(pattern_output, pattern_len, graph_output, graph_len) 156 | 157 | return pred 158 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import dgl 4 | import os 5 | import math 6 | import pickle 7 | import json 8 | import copy 9 | import torch.utils.data as data 10 | from collections import defaultdict, Counter 11 | from tqdm import tqdm 12 | from utils import get_enc_len, int2onehot, \ 13 | batch_convert_tensor_to_tensor, batch_convert_array_to_array 14 | 15 | INF = 
float("inf") 16 | 17 | ############################################## 18 | ################ Sampler Part ################ 19 | ############################################## 20 | class Sampler(data.Sampler): 21 | _type_map = { 22 | int: np.int32, 23 | float: np.float32} 24 | 25 | def __init__(self, dataset, group_by, batch_size, shuffle, drop_last): 26 | super(Sampler, self).__init__(dataset) 27 | if isinstance(group_by, str): 28 | group_by = [group_by] 29 | for attr in group_by: 30 | setattr(self, attr, list()) 31 | self.data_size = len(dataset.data) 32 | for x in dataset.data: 33 | for attr in group_by: 34 | value = x[attr] 35 | if isinstance(value, dgl.DGLGraph): 36 | getattr(self, attr).append(value.number_of_nodes()) 37 | elif hasattr(value, "__len__"): 38 | getattr(self, attr).append(len(value)) 39 | else: 40 | getattr(self, attr).append(value) 41 | self.order = copy.copy(group_by) 42 | self.order.append("rand") 43 | self.batch_size = batch_size 44 | self.shuffle = shuffle 45 | self.drop_last = drop_last 46 | 47 | def make_array(self): 48 | self.rand = np.random.rand(self.data_size).astype(np.float32) 49 | if self.data_size == 0: 50 | types = [np.float32] * len(self.order) 51 | else: 52 | types = [type(getattr(self, attr)[0]) for attr in self.order] 53 | types = [Sampler._type_map.get(t, t) for t in types] 54 | dtype = list(zip(self.order, types)) 55 | array = np.array( 56 | list(zip(*[getattr(self, attr) for attr in self.order])), 57 | dtype=dtype) 58 | return array 59 | 60 | def __iter__(self): 61 | array = self.make_array() 62 | indices = np.argsort(array, axis=0, order=self.order) 63 | batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)] 64 | if self.shuffle: 65 | np.random.shuffle(batches) 66 | batch_idx = 0 67 | while batch_idx < len(batches)-1: 68 | yield batches[batch_idx] 69 | batch_idx += 1 70 | if len(batches) > 0 and (len(batches[batch_idx]) == self.batch_size or not self.drop_last): 71 | yield 
batches[batch_idx] 72 | 73 | def __len__(self): 74 | if self.drop_last: 75 | return math.floor(self.data_size/self.batch_size) 76 | else: 77 | return math.ceil(self.data_size/self.batch_size) 78 | 79 | 80 | ############################################## 81 | ############# EdgeSeq Data Part ############## 82 | ############################################## 83 | class EdgeSeq: 84 | def __init__(self, code): 85 | self.u = code[:,0] 86 | self.v = code[:,1] 87 | self.ul = code[:,2] 88 | self.el = code[:,3] 89 | self.vl = code[:,4] 90 | 91 | def __len__(self): 92 | if len(self.u.shape) == 1: # single code (columns are 1-D) 93 | return self.u.shape[0] 94 | else: # batch code (columns are bsz x max_len) 95 | return self.u.shape[0] * self.u.shape[1] 96 | 97 | @staticmethod 98 | def batch(data): 99 | b = EdgeSeq(torch.empty((0,5), dtype=torch.long)) 100 | b.u = batch_convert_tensor_to_tensor([x.u for x in data]) 101 | b.v = batch_convert_tensor_to_tensor([x.v for x in data]) 102 | b.ul = batch_convert_tensor_to_tensor([x.ul for x in data]) 103 | b.el = batch_convert_tensor_to_tensor([x.el for x in data]) 104 | b.vl = batch_convert_tensor_to_tensor([x.vl for x in data]) 105 | return b 106 | 107 | def to(self, device): 108 | self.u = self.u.to(device) 109 | self.v = self.v.to(device) 110 | self.ul = self.ul.to(device) 111 | self.el = self.el.to(device) 112 | self.vl = self.vl.to(device) 113 | 114 | 115 | ############################################## 116 | ############ EdgeSeq Dataset Part ############ 117 | ############################################## 118 | class EdgeSeqDataset(data.Dataset): 119 | def __init__(self, data=None): 120 | super(EdgeSeqDataset, self).__init__() 121 | 122 | if data: 123 | self.data = EdgeSeqDataset.preprocess_batch(data, use_tqdm=True) 124 | else: 125 | self.data = list() 126 | self._to_tensor() 127 | 128 | def _to_tensor(self): 129 | for x in self.data: 130 | for k in ["pattern", "graph", "subisomorphisms"]: 131 | if isinstance(x[k], np.ndarray): 132 | x[k] = torch.from_numpy(x[k]) 133 | 
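The Sampler above buckets samples by a group-by key (e.g., graph size) with a random tiebreaker, so every minibatch contains similarly sized items and padding waste stays low. The idea can be sketched in isolation; `bucketed_batches` below is a hypothetical stand-alone helper for illustration, not part of the repository:

```python
import random

def bucketed_batches(lengths, batch_size, shuffle=True, drop_last=False, seed=0):
    """Group sample indices by length so each batch needs little padding.

    Mirrors the Sampler's idea: sort indices by the group-by key with a
    random tiebreaker, slice into fixed-size batches, optionally drop the
    last partial batch, then shuffle whole batches (not individual items).
    """
    rng = random.Random(seed)
    # equal-length items are shuffled among themselves via the random key
    order = sorted(range(len(lengths)), key=lambda i: (lengths[i], rng.random()))
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    if shuffle:
        rng.shuffle(batches)
    return batches
```

Each resulting index batch would then be collated by a function like `EdgeSeqDataset.batchify`, which pads every sequence in the batch to the batch maximum.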
134 | def __len__(self): 135 | return len(self.data) 136 | 137 | def __getitem__(self, idx): 138 | return self.data[idx] 139 | 140 | def save(self, filename): 141 | cache = defaultdict(list) 142 | for x in self.data: 143 | for k in list(x.keys()): 144 | if k.startswith("_"): 145 | cache[k].append(x.pop(k)) 146 | with open(filename, "wb") as f: 147 | torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) 148 | if len(cache) > 0: 149 | keys = cache.keys() 150 | for i in range(len(self.data)): 151 | for k in keys: 152 | self.data[i][k] = cache[k][i] 153 | 154 | def load(self, filename): 155 | with open(filename, "rb") as f: 156 | data = torch.load(f) 157 | del self.data 158 | self.data = data 159 | 160 | return self 161 | 162 | @staticmethod 163 | def graph2edgeseq(graph): 164 | labels = graph.vs["label"] 165 | graph_code = list() 166 | 167 | for edge in graph.es: 168 | v, u = edge.tuple 169 | graph_code.append((v, u, labels[v], edge["label"], labels[u])) 170 | graph_code = np.array(graph_code, dtype=np.int64) 171 | graph_code.view( 172 | [("v", "int64"), ("u", "int64"), ("vl", "int64"), ("el", "int64"), ("ul", "int64")]).sort( 173 | axis=0, order=["v", "u", "el"]) 174 | return graph_code 175 | 176 | @staticmethod 177 | def preprocess(x): 178 | pattern_code = EdgeSeqDataset.graph2edgeseq(x["pattern"]) 179 | graph_code = EdgeSeqDataset.graph2edgeseq(x["graph"]) 180 | subisomorphisms = np.array(x["subisomorphisms"], dtype=np.int32).reshape(-1, x["pattern"].vcount()) 181 | 182 | x = { 183 | "id": x["id"], 184 | "pattern": pattern_code, 185 | "graph": graph_code, 186 | "counts": x["counts"], 187 | "subisomorphisms": subisomorphisms} 188 | return x 189 | 190 | @staticmethod 191 | def preprocess_batch(data, use_tqdm=False): 192 | d = list() 193 | if use_tqdm: 194 | data = tqdm(data) 195 | for x in data: 196 | d.append(EdgeSeqDataset.preprocess(x)) 197 | return d 198 | 199 | @staticmethod 200 | def batchify(batch): 201 | _id = [x["id"] for x in batch] 202 | 
pattern = EdgeSeq.batch([EdgeSeq(x["pattern"]) for x in batch]) 203 | pattern_len = torch.tensor([x["pattern"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1) 204 | graph = EdgeSeq.batch([EdgeSeq(x["graph"]) for x in batch]) 205 | graph_len = torch.tensor([x["graph"].shape[0] for x in batch], dtype=torch.int32).view(-1, 1) 206 | counts = torch.tensor([x["counts"] for x in batch], dtype=torch.float32).view(-1, 1) 207 | return _id, pattern, pattern_len, graph, graph_len, counts 208 | 209 | 210 | ############################################## 211 | ######### GraphAdj Data Part ########### 212 | ############################################## 213 | class GraphAdjDataset(data.Dataset): 214 | def __init__(self, data=None): 215 | super(GraphAdjDataset, self).__init__() 216 | 217 | if data: 218 | self.data = GraphAdjDataset.preprocess_batch(data, use_tqdm=True) 219 | else: 220 | self.data = list() 221 | self._to_tensor() 222 | 223 | def _to_tensor(self): 224 | for x in self.data: 225 | for k in ["pattern", "graph"]: 226 | y = x[k] 227 | for k, v in y.ndata.items(): 228 | if isinstance(v, np.ndarray): 229 | y.ndata[k] = torch.from_numpy(v) 230 | for k, v in y.edata.items(): 231 | if isinstance(v, np.ndarray): 232 | y.edata[k] = torch.from_numpy(v) 233 | if isinstance(x["subisomorphisms"], np.ndarray): 234 | x["subisomorphisms"] = torch.from_numpy(x["subisomorphisms"]) 235 | 236 | def __len__(self): 237 | return len(self.data) 238 | 239 | def __getitem__(self, idx): 240 | return self.data[idx] 241 | 242 | def save(self, filename): 243 | cache = defaultdict(list) 244 | for x in self.data: 245 | for k in list(x.keys()): 246 | if k.startswith("_"): 247 | cache[k].append(x.pop(k)) 248 | with open(filename, "wb") as f: 249 | torch.save(self.data, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) 250 | if len(cache) > 0: 251 | keys = cache.keys() 252 | for i in range(len(self.data)): 253 | for k in keys: 254 | self.data[i][k] = cache[k][i] 255 | 256 | def load(self, filename): 
257 | with open(filename, "rb") as f: 258 | data = torch.load(f) 259 | del self.data 260 | self.data = data 261 | 262 | return self 263 | 264 | @staticmethod 265 | def comp_indeg_norm(graph): 266 | import igraph as ig 267 | if isinstance(graph, ig.Graph): 268 | # 10x faster 269 | in_deg = np.array(graph.indegree(), dtype=np.float32) 270 | elif isinstance(graph, dgl.DGLGraph): 271 | in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy() 272 | else: 273 | raise NotImplementedError 274 | norm = 1.0 / in_deg 275 | norm[np.isinf(norm)] = 0 276 | return norm 277 | 278 | @staticmethod 279 | def graph2dglgraph(graph): 280 | dglgraph = dgl.DGLGraph(multigraph=True) 281 | dglgraph.add_nodes(graph.vcount()) 282 | edges = graph.get_edgelist() 283 | dglgraph.add_edges([e[0] for e in edges], [e[1] for e in edges]) 284 | dglgraph.readonly(True) 285 | return dglgraph 286 | 287 | @staticmethod 288 | def preprocess(x): 289 | pattern = x["pattern"] 290 | pattern_dglgraph = GraphAdjDataset.graph2dglgraph(pattern) 291 | pattern_dglgraph.ndata["indeg"] = np.array(pattern.indegree(), dtype=np.float32) 292 | pattern_dglgraph.ndata["label"] = np.array(pattern.vs["label"], dtype=np.int64) 293 | pattern_dglgraph.ndata["id"] = np.arange(0, pattern.vcount(), dtype=np.int64) 294 | pattern_dglgraph.edata["label"] = np.array(pattern.es["label"], dtype=np.int64) 295 | 296 | graph = x["graph"] 297 | graph_dglgraph = GraphAdjDataset.graph2dglgraph(graph) 298 | graph_dglgraph.ndata["indeg"] = np.array(graph.indegree(), dtype=np.float32) 299 | graph_dglgraph.ndata["label"] = np.array(graph.vs["label"], dtype=np.int64) 300 | graph_dglgraph.ndata["id"] = np.arange(0, graph.vcount(), dtype=np.int64) 301 | graph_dglgraph.edata["label"] = np.array(graph.es["label"], dtype=np.int64) 302 | 303 | subisomorphisms = np.array(x["subisomorphisms"], dtype=np.int32).reshape(-1, x["pattern"].vcount()) 304 | 305 | x = { 306 | "id": x["id"], 307 | "pattern": pattern_dglgraph, 308 | "graph": 
graph_dglgraph, 309 | "counts": x["counts"], 310 | "subisomorphisms": subisomorphisms} 311 | return x 312 | 313 | @staticmethod 314 | def preprocess_batch(data, use_tqdm=False): 315 | d = list() 316 | if use_tqdm: 317 | data = tqdm(data) 318 | for x in data: 319 | d.append(GraphAdjDataset.preprocess(x)) 320 | return d 321 | 322 | @staticmethod 323 | def batchify(batch): 324 | _id = [x["id"] for x in batch] 325 | pattern = dgl.batch([x["pattern"] for x in batch]) 326 | pattern_len = torch.tensor([x["pattern"].number_of_nodes() for x in batch], dtype=torch.int32).view(-1, 1) 327 | graph = dgl.batch([x["graph"] for x in batch]) 328 | graph_len = torch.tensor([x["graph"].number_of_nodes() for x in batch], dtype=torch.int32).view(-1, 1) 329 | counts = torch.tensor([x["counts"] for x in batch], dtype=torch.float32).view(-1, 1) 330 | return _id, pattern, pattern_len, graph, graph_len, counts 331 | 332 | -------------------------------------------------------------------------------- /src/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import extend_dimensions 5 | 6 | 7 | class NormalEmbedding(nn.Module): 8 | def __init__(self, input_dim, emb_dim): 9 | super(NormalEmbedding, self).__init__() 10 | self.input_dim = input_dim 11 | self.emb_dim = emb_dim 12 | self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False) 13 | 14 | # init 15 | nn.init.normal_(self.emb_layer.weight, 0.0, 1.0) 16 | 17 | def increase_input_size(self, new_input_dim): 18 | assert new_input_dim >= self.input_dim 19 | if new_input_dim != self.input_dim: 20 | new_emb_layer = extend_dimensions(self.emb_layer, new_input_dim=new_input_dim, upper=False) 21 | del self.emb_layer 22 | self.emb_layer = new_emb_layer 23 | self.input_dim = new_input_dim 24 | 25 | def forward(self, x): 26 | emb = self.emb_layer(x) 27 | return emb 28 | 29 | class OrthogonalEmbedding(nn.Module): 
30 | def __init__(self, input_dim, emb_dim): 31 | super(OrthogonalEmbedding, self).__init__() 32 | self.input_dim = input_dim 33 | self.emb_dim = emb_dim 34 | self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False) 35 | 36 | # init 37 | nn.init.orthogonal_(self.emb_layer.weight) 38 | 39 | def increase_input_size(self, new_input_dim): 40 | assert new_input_dim >= self.input_dim 41 | if new_input_dim != self.input_dim: 42 | new_emb_layer = extend_dimensions(self.emb_layer, new_input_dim=new_input_dim, upper=False) 43 | del self.emb_layer 44 | self.emb_layer = new_emb_layer 45 | self.input_dim = new_input_dim 46 | 47 | def forward(self, x): 48 | emb = self.emb_layer(x) 49 | return emb 50 | 51 | class EquivariantEmbedding(nn.Module): 52 | def __init__(self, input_dim, emb_dim): 53 | super(EquivariantEmbedding, self).__init__() 54 | self.input_dim = input_dim 55 | self.emb_dim = emb_dim 56 | self.emb_layer = nn.Linear(input_dim, emb_dim, bias=False) 57 | 58 | # init 59 | nn.init.normal_(self.emb_layer.weight[:,0], 0.0, 1.0) 60 | emb_column = self.emb_layer.weight[:,0] 61 | with torch.no_grad(): 62 | for i in range(1, self.input_dim): 63 | self.emb_layer.weight[:,i].data.copy_(torch.roll(emb_column, i, 0)) 64 | 65 | def increase_input_size(self, new_input_dim): 66 | assert new_input_dim >= self.input_dim 67 | if new_input_dim != self.input_dim: 68 | new_emb_layer = extend_dimensions(self.emb_layer, new_input_dim=new_input_dim, upper=False) 69 | del self.emb_layer 70 | self.emb_layer = new_emb_layer 71 | self.input_dim = new_input_dim 72 | 73 | def forward(self, x): 74 | emb = self.emb_layer(x) 75 | return emb -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import numpy as np 5 | from utils import get_best_epochs, compute_mae, compute_rmse, compute_p_r_f1, compute_tp 6 | 7 | 8 | if __name__ 
== "__main__": 9 | assert len(sys.argv) == 2 10 | model_dir = sys.argv[1] 11 | 12 | # get the best epoch 13 | if os.path.exists(os.path.join(model_dir, "finetune_log.txt")): 14 | best_epochs = get_best_epochs(os.path.join(model_dir, "finetune_log.txt")) 15 | elif os.path.exists(os.path.join(model_dir, "train_log.txt")): 16 | best_epochs = get_best_epochs(os.path.join(model_dir, "train_log.txt")) 17 | else: 18 | raise FileNotFoundError("finetune_log.txt and train_log.txt cannot be found in %s" % (os.path.join(model_dir))) 19 | print("retrieve the best epoch for training set ({:0>3d}), dev set ({:0>3d}), and test set ({:0>3d})".format( 20 | best_epochs["train"], best_epochs["dev"], best_epochs["test"])) 21 | 22 | with open(os.path.join(model_dir, "dev%d.json" % (best_epochs["dev"])), "r") as f: 23 | results = json.load(f) 24 | pred = np.array(results["data"]["pred"]) 25 | counts = np.array(results["data"]["counts"]) 26 | print("dev-RMSE: %.4f\tdev-MAE: %.4f\tdev-F1_Zero: %.4f\tdev-F1_NonZero: %.4f\tdev-Time: %.4f" % ( 27 | compute_rmse(pred, counts), compute_mae(pred, counts), 28 | compute_p_r_f1(pred < 0.5, counts < 0.5)[2], compute_p_r_f1(pred > 0.5, counts > 0.5)[2], 29 | results["time"]["total"])) 30 | 31 | with open(os.path.join(model_dir, "test%d.json" % (best_epochs["dev"])), "r") as f: 32 | results = json.load(f) 33 | pred = np.array(results["data"]["pred"]) 34 | counts = np.array(results["data"]["counts"]) 35 | print("test-RMSE: %.4f\ttest-MAE: %.4f\ttest-F1_Zero: %.4f\ttest-F1_NonZero: %.4f\ttest-Time: %.4f" % ( 36 | compute_rmse(pred, counts), compute_mae(pred, counts), 37 | compute_p_r_f1(pred < 0.5, counts < 0.5)[2], compute_p_r_f1(pred > 0.5, counts > 0.5)[2], 38 | results["time"]["total"])) -------------------------------------------------------------------------------- /src/filternet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 
| # class MaxGatedFilterNet(nn.Module): 6 | # def __init__(self, pattern_dim, graph_dim): 7 | # super(MaxGatedFilterNet, self).__init__() 8 | # self.g_layer = nn.Linear(graph_dim, pattern_dim) 9 | # self.f_layer = nn.Linear(pattern_dim, 1) 10 | 11 | # # init 12 | # scale = (1/pattern_dim)**0.5 13 | # nn.init.normal_(self.g_layer.weight, 0.0, scale) 14 | # nn.init.zeros_(self.g_layer.bias) 15 | # nn.init.normal_(self.f_layer.weight, 0.0, scale) 16 | # nn.init.ones_(self.f_layer.bias) 17 | 18 | # def forward(self, p_x, g_x): 19 | # max_x = torch.max(p_x, dim=1, keepdim=True)[0].float() 20 | # g_x = self.g_layer(g_x.float()) 21 | # f = self.f_layer(g_x * max_x) 22 | # return F.sigmoid(f) 23 | 24 | class MaxGatedFilterNet(nn.Module): 25 | def __init__(self): 26 | super(MaxGatedFilterNet, self).__init__() 27 | 28 | def forward(self, p_x, g_x): 29 | max_x = torch.max(p_x, dim=1, keepdim=True)[0] 30 | if max_x.dim() == 2: 31 | return g_x <= max_x 32 | else: 33 | return (g_x <= max_x).all(keepdim=True, dim=2) 34 | -------------------------------------------------------------------------------- /src/finetune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import dgl 5 | import logging 6 | import datetime 7 | import math 8 | import sys 9 | import gc 10 | import re 11 | import subprocess 12 | import json 13 | import torch.nn.functional as F 14 | import warnings 15 | import shutil 16 | from functools import partial 17 | from collections import OrderedDict 18 | from torch.utils.data import DataLoader 19 | try: 20 | from torch.utils.tensorboard import SummaryWriter 21 | except BaseException as e: 22 | from tensorboardX import SummaryWriter 23 | from dataset import Sampler, EdgeSeqDataset, GraphAdjDataset 24 | from utils import anneal_fn, get_enc_len, load_data, get_best_epochs, get_linear_schedule_with_warmup 25 | from cnn import CNN 26 | from rnn import RNN 27 | from txl import TXL 28 
| from rgcn import RGCN 29 | from rgin import RGIN 30 | from train import train, evaluate 31 | 32 | warnings.filterwarnings("ignore") 33 | INF = float("inf") 34 | 35 | finetune_config = { 36 | "max_npv": 16, # max_number_pattern_vertices: 8, 16, 32 37 | "max_npe": 16, # max_number_pattern_edges: 8, 16, 32 38 | "max_npvl": 16, # max_number_pattern_vertex_labels: 8, 16, 32 39 | "max_npel": 16, # max_number_pattern_edge_labels: 8, 16, 32 40 | 41 | "max_ngv": 512, # max_number_graph_vertices: 64, 512,4096 42 | "max_nge": 2048, # max_number_graph_edges: 256, 2048, 16384 43 | "max_ngvl": 64, # max_number_graph_vertex_labels: 16, 64, 256 44 | "max_ngel": 64, # max_number_graph_edge_labels: 16, 64, 256 45 | 46 | # "base": 2, 47 | 48 | "gpu_id": -1, 49 | "num_workers": 12, 50 | 51 | "epochs": 100, 52 | "batch_size": 128, 53 | "update_every": 4, # actual batch_size = batch_size * update_every 54 | "print_every": 100, 55 | "share_emb": True, # sharing embedding requires the same vector length 56 | "share_arch": True, # sharing architectures 57 | "dropout": 0.2, 58 | "dropatt": 0.2, 59 | 60 | "predict_net": "SumPredictNet", # MeanPredictNet, SumPredictNet, MaxPredictNet, 61 | # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, 62 | # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, 63 | # DIAMNet 64 | "predict_net_hidden_dim": 128, 65 | "predict_net_num_heads": 4, 66 | "predict_net_mem_len": 4, 67 | "predict_net_mem_init": "mean", # mean, sum, max, attn, circular_mean, circular_sum, circular_max, circular_attn, lstm 68 | "predict_net_recurrent_steps": 3, 69 | 70 | "reg_loss": "MSE", # MAE, MSE, SMAE 71 | "bp_loss": "MSE", # MAE, MSE, SMAE 72 | "bp_loss_slp": "anneal_cosine$1.0$0.01", # 0, 0.01, logistic$1.0$0.01, linear$1.0$0.01, cosine$1.0$0.01, 73 | # cyclical_logistic$1.0$0.01, cyclical_linear$1.0$0.01, cyclical_cosine$1.0$0.01 74 | # anneal_logistic$1.0$0.01, anneal_linear$1.0$0.01, anneal_cosine$1.0$0.01 75 | "lr": 0.001, 76 | "weight_decay": 
0.00001, 77 | "max_grad_norm": 8, 78 | 79 | "pattern_dir": "../data/middle/patterns", 80 | "graph_dir": "../data/middle/graphs", 81 | "metadata_dir": "../data/middle/metadata", 82 | "save_data_dir": "../data/middle", 83 | "save_model_dir": "../dumps/middle", 84 | "load_model_dir": "../dumps/middle/XXXX" 85 | } 86 | 87 | if __name__ == "__main__": 88 | torch.manual_seed(0) 89 | np.random.seed(0) 90 | 91 | for i in range(1, len(sys.argv), 2): 92 | arg = sys.argv[i] 93 | value = sys.argv[i+1] 94 | 95 | if arg.startswith("--"): 96 | arg = arg[2:] 97 | if arg not in finetune_config: 98 | print("Warning: %s is not supported now." % (arg)) 99 | continue 100 | finetune_config[arg] = value 101 | try: 102 | value = eval(value) 103 | if isinstance(value, (int, float)): 104 | finetune_config[arg] = value 105 | except: 106 | pass 107 | 108 | # load config 109 | if os.path.exists(os.path.join(finetune_config["load_model_dir"], "train_config.json")): 110 | with open(os.path.join(finetune_config["load_model_dir"], "train_config.json"), "r") as f: 111 | train_config = json.load(f) 112 | elif os.path.exists(os.path.join(finetune_config["load_model_dir"], "finetune_config.json")): 113 | with open(os.path.join(finetune_config["load_model_dir"], "finetune_config.json"), "r") as f: 114 | train_config = json.load(f) 115 | else: 116 | raise FileNotFoundError("finetune_config.json and train_config.json cannot be found in %s" % (os.path.join(finetune_config["load_model_dir"]))) 117 | 118 | for key in train_config: 119 | if key not in finetune_config: 120 | finetune_config[key] = train_config[key] 121 | 122 | ts = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 123 | model_name = "%s_%s_%s" % (finetune_config["model"], finetune_config["predict_net"], ts) 124 | save_model_dir = finetune_config["save_model_dir"] 125 | os.makedirs(save_model_dir, exist_ok=True) 126 | 127 | # save config 128 | with open(os.path.join(save_model_dir, "finetune_config.json"), "w") as f: 129 | 
json.dump(finetune_config, f) 130 | 131 | # set logger 132 | logger = logging.getLogger() 133 | logger.setLevel(logging.INFO) 134 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%Y/%m/%d %H:%M:%S') 135 | console = logging.StreamHandler() 136 | console.setFormatter(fmt) 137 | logger.addHandler(console) 138 | logfile = logging.FileHandler(os.path.join(save_model_dir, "finetune_log.txt"), 'w') 139 | logfile.setFormatter(fmt) 140 | logger.addHandler(logfile) 141 | 142 | # set device 143 | device = torch.device("cuda:%d" % finetune_config["gpu_id"] if finetune_config["gpu_id"] != -1 else "cpu") 144 | if finetune_config["gpu_id"] != -1: 145 | torch.cuda.set_device(device) 146 | 147 | # check model 148 | if finetune_config["model"] not in ["CNN", "RNN", "TXL", "RGCN", "RGIN"]: 149 | raise NotImplementedError("Currently, the %s model is not supported" % (finetune_config["model"])) 150 | 151 | # reset the pattern parameters 152 | if finetune_config["share_emb"]: 153 | finetune_config["max_npv"], finetune_config["max_npvl"], finetune_config["max_npe"], finetune_config["max_npel"] = \ 154 | finetune_config["max_ngv"], finetune_config["max_ngvl"], finetune_config["max_nge"], finetune_config["max_ngel"] 155 | 156 | # get the best epoch 157 | if os.path.exists(os.path.join(finetune_config["load_model_dir"], "finetune_log.txt")): 158 | best_epochs = get_best_epochs(os.path.join(finetune_config["load_model_dir"], "finetune_log.txt")) 159 | elif os.path.exists(os.path.join(finetune_config["load_model_dir"], "train_log.txt")): 160 | best_epochs = get_best_epochs(os.path.join(finetune_config["load_model_dir"], "train_log.txt")) 161 | else: 162 | raise FileNotFoundError("finetune_log.txt and train_log.txt cannot be found in %s" % (os.path.join(finetune_config["load_model_dir"]))) 163 | logger.info("retrieve the best epoch for training set ({:0>3d}), dev set ({:0>3d}), and test set ({:0>3d})".format( 164 | best_epochs["train"], best_epochs["dev"], best_epochs["test"])) 165 | 
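`get_best_epochs` (imported from utils.py, which is not shown in this excerpt) scans the training log and returns the best epoch per data split; the script then reloads that checkpoint. A minimal sketch of that contract under an assumed log-line format — the regex and function name here are hypothetical, not the repository's actual parser:

```python
import re

# Hypothetical log-line shape; the real train_log.txt layout may differ.
LINE = re.compile(r"(?P<split>train|dev|test) epoch (?P<epoch>\d+) .* loss (?P<loss>[\d.]+)")

def get_best_epochs_sketch(lines):
    """Return {split: epoch} minimizing the logged loss for each split."""
    best = {}
    for line in lines:
        m = LINE.search(line)
        if not m:
            continue
        split, epoch, loss = m.group("split"), int(m.group("epoch")), float(m.group("loss"))
        if split not in best or loss < best[split][1]:
            best[split] = (epoch, loss)
    return {split: epoch for split, (epoch, loss) in best.items()}
```

Whatever the exact format, the key point is that model selection uses `best_epochs["dev"]`: both evaluate.py and the checkpoint loading below read the dev-best epoch, never the test-best one.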
166 | # load the model 167 | for key in ["dropout", "dropatt"]: 168 | train_config[key] = finetune_config[key] 169 | 170 | if train_config["model"] == "CNN": 171 | model = CNN(train_config) 172 | elif train_config["model"] == "RNN": 173 | model = RNN(train_config) 174 | elif train_config["model"] == "TXL": 175 | model = TXL(train_config) 176 | elif train_config["model"] == "RGCN": 177 | model = RGCN(train_config) 178 | elif train_config["model"] == "RGIN": 179 | model = RGIN(train_config) 180 | else: 181 | raise NotImplementedError("Currently, the %s model is not supported" % (train_config["model"])) 182 | 183 | model.load_state_dict(torch.load( 184 | os.path.join(finetune_config["load_model_dir"], "epoch%d.pt" % (best_epochs["dev"])), map_location=torch.device("cpu"))) 185 | model.increase_net(finetune_config) 186 | if not all([train_config[key] == finetune_config[key] for key in [ 187 | "max_npv", "max_npe", "max_npvl", "max_npel", "max_ngv", "max_nge", "max_ngvl", "max_ngel", 188 | "share_emb", "share_arch"]]): 189 | model.increase_input_size(finetune_config) 190 | if not all([train_config[key] == finetune_config[key] for key in [ 191 | "predict_net", "predict_net_hidden_dim", 192 | "predict_net_num_heads", "predict_net_mem_len", "predict_net_mem_init", "predict_net_recurrent_steps"]]): 193 | new_predict_net = model.create_predict_net(finetune_config["predict_net"], 194 | pattern_dim=model.predict_net.pattern_dim, graph_dim=model.predict_net.graph_dim, 195 | hidden_dim=finetune_config["predict_net_hidden_dim"], 196 | num_heads=finetune_config["predict_net_num_heads"], recurrent_steps=finetune_config["predict_net_recurrent_steps"], 197 | mem_len=finetune_config["predict_net_mem_len"], mem_init=finetune_config["predict_net_mem_init"]) 198 | del model.predict_net 199 | model.predict_net = new_predict_net 200 | model = model.to(device) 201 | torch.cuda.empty_cache() 202 | logger.info("load the model based on the dev set (epoch: {:0>3d})".format(best_epochs["dev"])) 
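The two guarded rebuilds above (`increase_input_size` and the predict-net replacement) both hinge on a key-by-key comparison between the loaded `train_config` and the new `finetune_config`. A minimal sketch of that check, with key names copied from the script and illustrative values:

```python
# Keys whose change forces the embeddings/encoders to be resized
# (copied from the all(...) check above; values below are illustrative).
SIZE_KEYS = ["max_npv", "max_npe", "max_npvl", "max_npel",
             "max_ngv", "max_nge", "max_ngvl", "max_ngel",
             "share_emb", "share_arch"]

def needs_rebuild(train_config, finetune_config, keys):
    # any mismatch means the pre-trained component was sized for a
    # different graph scale and must be enlarged or replaced
    return not all(train_config[k] == finetune_config[k] for k in keys)

old = {"max_npv": 8, "max_npe": 8, "max_npvl": 8, "max_npel": 8,
       "max_ngv": 64, "max_nge": 256, "max_ngvl": 16, "max_ngel": 16,
       "share_emb": True, "share_arch": True}
new = dict(old, max_ngv=512, max_nge=2048)  # fine-tuning on larger graphs
```

With identical configs nothing is rebuilt; enlarging `max_ngv`/`max_nge` triggers the resize, which is why fine-tuning on bigger graphs still reuses the pre-trained weights.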
203 | logger.info(model) 204 | logger.info("num of parameters: %d" % (sum(p.numel() for p in model.parameters() if p.requires_grad))) 205 | 206 | # load data 207 | os.makedirs(finetune_config["save_data_dir"], exist_ok=True) 208 | data_loaders = OrderedDict({"train": None, "dev": None, "test": None}) 209 | if all([os.path.exists(os.path.join(finetune_config["save_data_dir"], 210 | "%s_%s_dataset.pt" % ( 211 | data_type, "dgl" if finetune_config["model"] in ["RGCN", "RGIN"] else "edgeseq"))) for data_type in data_loaders]): 212 | 213 | logger.info("loading data from pt...") 214 | for data_type in data_loaders: 215 | if finetune_config["model"] in ["RGCN", "RGIN"]: 216 | dataset = GraphAdjDataset(list()) 217 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 218 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 219 | data_loader = DataLoader(dataset, 220 | batch_sampler=sampler, 221 | collate_fn=GraphAdjDataset.batchify, 222 | pin_memory=data_type=="train") 223 | else: 224 | dataset = EdgeSeqDataset(list()) 225 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 226 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 227 | data_loader = DataLoader(dataset, 228 | batch_sampler=sampler, 229 | collate_fn=EdgeSeqDataset.batchify, 230 | pin_memory=data_type=="train") 231 | data_loaders[data_type] = data_loader 232 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 233 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), finetune_config["batch_size"])) 234 | else: 235 | data = load_data(finetune_config["graph_dir"], finetune_config["pattern_dir"], 
finetune_config["metadata_dir"], num_workers=finetune_config["num_workers"]) 236 | logger.info("{}/{}/{} data loaded".format(len(data["train"]), len(data["dev"]), len(data["test"]))) 237 | for data_type, x in data.items(): 238 | if finetune_config["model"] in ["RGCN", "RGIN", "RSIN"]: 239 | if os.path.exists(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))): 240 | dataset = GraphAdjDataset(list()) 241 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 242 | else: 243 | dataset = GraphAdjDataset(x) 244 | dataset.save(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 245 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 246 | data_loader = DataLoader(dataset, 247 | batch_sampler=sampler, 248 | collate_fn=GraphAdjDataset.batchify, 249 | pin_memory=data_type=="train") 250 | else: 251 | if os.path.exists(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))): 252 | dataset = EdgeSeqDataset(list()) 253 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 254 | else: 255 | dataset = EdgeSeqDataset(x) 256 | dataset.save(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 257 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 258 | data_loader = DataLoader(dataset, 259 | batch_sampler=sampler, 260 | collate_fn=EdgeSeqDataset.batchify, 261 | pin_memory=data_type=="train") 262 | data_loaders[data_type] = data_loader 263 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 264 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), 
finetune_config["batch_size"])) 265 | 266 | # optimizer and losses 267 | writer = SummaryWriter(save_model_dir) 268 | optimizer = torch.optim.AdamW(model.parameters(), lr=finetune_config["lr"], weight_decay=finetune_config["weight_decay"], amsgrad=True) 269 | optimizer.zero_grad() 270 | scheduler = get_linear_schedule_with_warmup(optimizer, 271 | len(data_loaders["train"]), train_config["epochs"]*len(data_loaders["train"]), min_percent=0.0001) 272 | 273 | best_reg_losses = {"train": INF, "dev": INF, "test": INF} 274 | best_reg_epochs = {"train": -1, "dev": -1, "test": -1} 275 | 276 | for epoch in range(finetune_config["epochs"]): 277 | for data_type, data_loader in data_loaders.items(): 278 | 279 | if data_type == "train": 280 | mean_reg_loss, mean_bp_loss = train(model, optimizer, scheduler, data_type, data_loader, device, 281 | finetune_config, epoch, logger=logger, writer=writer) 282 | torch.save(model.state_dict(), os.path.join(save_model_dir, 'epoch%d.pt' % (epoch))) 283 | else: 284 | mean_reg_loss, mean_bp_loss, evaluate_results = evaluate(model, data_type, data_loader, device, 285 | finetune_config, epoch, logger=logger, writer=writer) 286 | with open(os.path.join(save_model_dir, '%s%d.json' % (data_type, epoch)), "w") as f: 287 | json.dump(evaluate_results, f) 288 | 289 | if mean_reg_loss <= best_reg_losses[data_type]: 290 | best_reg_losses[data_type] = mean_reg_loss 291 | best_reg_epochs[data_type] = epoch 292 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, mean_reg_loss, epoch)) 293 | for data_type in data_loaders.keys(): 294 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, best_reg_losses[data_type], best_reg_epochs[data_type])) 295 | -------------------------------------------------------------------------------- /src/finetune_mutag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as 
np 4 | import dgl 5 | import logging 6 | import datetime 7 | import math 8 | import sys 9 | import gc 10 | import re 11 | import subprocess 12 | import json 13 | import torch.nn.functional as F 14 | import warnings 15 | from functools import partial 16 | from collections import OrderedDict 17 | from torch.utils.data import DataLoader 18 | try: 19 | from torch.utils.tensorboard import SummaryWriter 20 | except BaseException as e: 21 | from tensorboardX import SummaryWriter 22 | from dataset import Sampler, EdgeSeqDataset, GraphAdjDataset 23 | from utils import anneal_fn, get_enc_len, load_data, get_best_epochs, get_linear_schedule_with_warmup 24 | from mlp import MLP 25 | from rnn import RNN 26 | from transformerxl import TXL 27 | from cnn import CNN 28 | from resnet import ResNet 29 | from rgcn import RGCN 30 | from rgin import RGIN 31 | from rsin import RSIN 32 | from train import train, evaluate 33 | 34 | warnings.filterwarnings("ignore") 35 | INF = float("inf") 36 | 37 | finetune_config = { 38 | "max_npv": 8, # max_number_pattern_vertices: 8, 16, 32 39 | "max_npe": 8, # max_number_pattern_edges: 8, 16, 32 40 | "max_npvl": 8, # max_number_pattern_vertex_labels: 8, 16, 32 41 | "max_npel": 8, # max_number_pattern_edge_labels: 8, 16, 32 42 | 43 | "max_ngv": 64, # max_number_graph_vertices: 64, 512, 4096 44 | "max_nge": 256, # max_number_graph_edges: 256, 2048, 16384 45 | "max_ngvl": 16, # max_number_graph_vertex_labels: 16, 64, 256 46 | "max_ngel": 16, # max_number_graph_edge_labels: 16, 64, 256 47 | 48 | # "base": 2, 49 | 50 | "gpu_id": -1, 51 | "num_workers": 12, 52 | 53 | "epochs": 100, 54 | "batch_size": 64, 55 | "update_every": 1, # actual batch_size = batch_size * update_every 56 | "print_every": 100, 57 | "share_emb": True, # sharing embedding requires the same vector length 58 | "share_arch": True, # sharing architectures 59 | "dropout": 0.2, 60 | "dropatt": 0.2, 61 | 62 | "predict_net": "SumPredictNet", # MeanPredictNet, SumPredictNet, MaxPredictNet, 63 | #
MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, 64 | # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, 65 | # DIAMNet 66 | "predict_net_hidden_dim": 128, 67 | "predict_net_num_heads": 4, 68 | "predict_net_mem_len": 4, 69 | "predict_net_mem_init": "mean", # mean, sum, max, attn, circular_mean, circular_sum, circular_max, circular_attn, lstm 70 | "predict_net_recurrent_steps": 3, 71 | 72 | "reg_loss": "MSE", # MAE, MSE, SMAE 73 | "bp_loss": "MSE", # MAE, MSE, SMAE 74 | "bp_loss_slp": "anneal_cosine$1.0$0.01", # 0, 0.01, logistic$1.0$0.01, linear$1.0$0.01, cosine$1.0$0.01, 75 | # cyclical_logistic$1.0$0.01, cyclical_linear$1.0$0.01, cyclical_cosine$1.0$0.01 76 | # anneal_logistic$1.0$0.01, anneal_linear$1.0$0.01, anneal_cosine$1.0$0.01 77 | "lr": 0.001, 78 | "weight_decay": 0.00001, 79 | "max_grad_norm": 8, 80 | 81 | "train_ratio": 1.0, 82 | "pattern_dir": "../data/MUTAG/patterns", 83 | "graph_dir": "../data/MUTAG/raw", 84 | "metadata_dir": "../data/MUTAG/metadata", 85 | "save_data_dir": "../data/MUTAG", 86 | "save_model_dir": "../dumps/MUTAG", 87 | "load_model_dir": "../dumps/small/RGCN_DIAMNet" 88 | } 89 | 90 | if __name__ == "__main__": 91 | torch.manual_seed(0) 92 | np.random.seed(0) 93 | 94 | for i in range(1, len(sys.argv), 2): 95 | arg = sys.argv[i] 96 | value = sys.argv[i+1] 97 | 98 | if arg.startswith("--"): 99 | arg = arg[2:] 100 | if arg not in finetune_config: 101 | print("Warning: %s is not supported now."
% (arg)) 102 | continue 103 | finetune_config[arg] = value 104 | try: 105 | value = eval(value) 106 | if isinstance(value, (int, float)): 107 | finetune_config[arg] = value 108 | except: 109 | pass 110 | 111 | 112 | # load config 113 | if os.path.exists(os.path.join(finetune_config["load_model_dir"], "train_config.json")): 114 | with open(os.path.join(finetune_config["load_model_dir"], "train_config.json"), "r") as f: 115 | train_config = json.load(f) 116 | elif os.path.exists(os.path.join(finetune_config["load_model_dir"], "finetune_config.json")): 117 | with open(os.path.join(finetune_config["load_model_dir"], "finetune_config.json"), "r") as f: 118 | train_config = json.load(f) 119 | else: 120 | raise FileNotFoundError("finetune_config.json and train_config.json cannot be found in %s" % (os.path.join(finetune_config["load_model_dir"]))) 121 | 122 | for key in train_config: 123 | if key not in finetune_config: 124 | finetune_config[key] = train_config[key] 125 | 126 | ts = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 127 | model_name = "%s_%s_%s" % (finetune_config["model"], finetune_config["predict_net"], ts) 128 | save_model_dir = finetune_config["save_model_dir"] 129 | os.makedirs(save_model_dir, exist_ok=True) 130 | 131 | # save config 132 | with open(os.path.join(save_model_dir, "finetune_config.json"), "w") as f: 133 | json.dump(finetune_config, f) 134 | 135 | # set logger 136 | logger = logging.getLogger() 137 | logger.setLevel(logging.INFO) 138 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%Y/%m/%d %H:%M:%S') 139 | console = logging.StreamHandler() 140 | console.setFormatter(fmt) 141 | logger.addHandler(console) 142 | logfile = logging.FileHandler(os.path.join(save_model_dir, "finetune_log.txt"), 'w') 143 | logfile.setFormatter(fmt) 144 | logger.addHandler(logfile) 145 | 146 | # set device 147 | device = torch.device("cuda:%d" % finetune_config["gpu_id"] if finetune_config["gpu_id"] != -1 else "cpu") 148 | if 
finetune_config["gpu_id"] != -1: 149 | torch.cuda.set_device(device) 150 | 151 | # check model 152 | if finetune_config["model"] not in ["MLP", "CNN", "DN", "RN", "RNN", "TXL", "RGCN", "RGIN"]: 153 | raise NotImplementedError("Currently, the %s model is not supported" % (finetune_config["model"])) 154 | 155 | # reset the pattern parameters 156 | if finetune_config["share_emb"]: 157 | finetune_config["max_npv"], finetune_config["max_npvl"], finetune_config["max_npe"], finetune_config["max_npel"] = \ 158 | finetune_config["max_ngv"], finetune_config["max_ngvl"], finetune_config["max_nge"], finetune_config["max_ngel"] 159 | 160 | # get the best epoch 161 | if os.path.exists(os.path.join(finetune_config["load_model_dir"], "finetune_log.txt")): 162 | best_epochs = get_best_epochs(os.path.join(finetune_config["load_model_dir"], "finetune_log.txt")) 163 | elif os.path.exists(os.path.join(finetune_config["load_model_dir"], "train_log.txt")): 164 | best_epochs = get_best_epochs(os.path.join(finetune_config["load_model_dir"], "train_log.txt")) 165 | else: 166 | raise FileNotFoundError("finetune_log.txt and train_log.txt cannot be found in %s" % (os.path.join(finetune_config["load_model_dir"]))) 167 | logger.info("retrieve the best epoch for training set ({:0>3d}), dev set ({:0>3d}), and test set ({:0>3d})".format( 168 | best_epochs["train"], best_epochs["dev"], best_epochs["test"])) 169 | 170 | # load the model 171 | for key in ["dropout", "dropatt"]: 172 | train_config[key] = finetune_config[key] 173 | 174 | if train_config["model"] == "CNN": 175 | model = CNN(train_config) 176 | elif train_config["model"] == "RNN": 177 | model = RNN(train_config) 178 | elif train_config["model"] == "TXL": 179 | model = TXL(train_config) 180 | elif train_config["model"] == "RGCN": 181 | model = RGCN(train_config) 182 | elif train_config["model"] == "RGIN": 183 | model = RGIN(train_config) 184 | else: 185 | raise NotImplementedError("Currently, the %s model is not supported" %
(train_config["model"])) 186 | 187 | model.load_state_dict(torch.load( 188 | os.path.join(finetune_config["load_model_dir"], "epoch%d.pt" % (best_epochs["dev"])), map_location=torch.device("cpu"))) 189 | model.increase_net(finetune_config) 190 | if not all([train_config[key] == finetune_config[key] for key in [ 191 | "max_npv", "max_npe", "max_npvl", "max_npel", "max_ngv", "max_nge", "max_ngvl", "max_ngel", "share_emb", "share_arch"]]): 192 | model.increase_input_size(finetune_config) 193 | if not all([train_config[key] == finetune_config[key] for key in [ 194 | "predict_net", "predict_net_hidden_dim", 195 | "predict_net_num_heads", "predict_net_mem_len", "predict_net_mem_init", "predict_net_recurrent_steps"]]): 196 | new_predict_net = model.create_predict_net(finetune_config["predict_net"], 197 | pattern_dim=model.predict_net.pattern_dim, graph_dim=model.predict_net.graph_dim, 198 | hidden_dim=finetune_config["predict_net_hidden_dim"], 199 | num_heads=finetune_config["predict_net_num_heads"], recurrent_steps=finetune_config["predict_net_recurrent_steps"], 200 | mem_len=finetune_config["predict_net_mem_len"], mem_init=finetune_config["predict_net_mem_init"]) 201 | del model.predict_net 202 | model.predict_net = new_predict_net 203 | model = model.to(device) 204 | torch.cuda.empty_cache() 205 | logger.info("load the model based on the dev set (epoch: {:0>3d})".format(best_epochs["dev"])) 206 | logger.info(model) 207 | logger.info("num of parameters: %d" % (sum(p.numel() for p in model.parameters() if p.requires_grad))) 208 | 209 | # load data 210 | os.makedirs(finetune_config["save_data_dir"], exist_ok=True) 211 | data_loaders = OrderedDict({"train": None, "dev": None, "test": None}) 212 | if all([os.path.exists(os.path.join(finetune_config["save_data_dir"], 213 | "%s_%s_dataset.pt" % ( 214 | data_type, "dgl" if finetune_config["model"] in ["RGCN", "RGIN", "RGIN"] else "edgeseq"))) for data_type in data_loaders]): 215 | 216 | logger.info("loading data from pt...") 
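Both data-loading branches below follow the same cache-or-build pattern: each split's preprocessed `*_dataset.pt` file is loaded when it exists, otherwise the dataset is built from the raw data and saved for the next run. A stdlib sketch of the pattern, with `pickle` standing in for the repository's `dataset.save`/`dataset.load`:

```python
import os
import pickle
import tempfile

# Cache-or-build: load the preprocessed file if present, otherwise build
# the data and persist it so later runs skip the expensive preprocessing.
def cache_or_build(path, build_fn):
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    data = build_fn()
    with open(path, "wb") as f:
        pickle.dump(data, f)
    return data

cache = os.path.join(tempfile.mkdtemp(), "train_dgl_dataset.pt")
first = cache_or_build(cache, lambda: [1, 2, 3])  # builds and saves
second = cache_or_build(cache, lambda: [])        # loads the cached copy
```

The second call ignores its build function entirely, which mirrors why the script can be re-run cheaply once the `.pt` files exist in `save_data_dir`.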
217 | for data_type in data_loaders: 218 | if finetune_config["model"] in ["RGCN", "RGIN", "RGIN"]: 219 | dataset = GraphAdjDataset(list()) 220 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 221 | if data_type == "train": 222 | np.random.shuffle(dataset.data) 223 | dataset.data = dataset.data[:math.ceil(len(dataset.data)*finetune_config["train_ratio"])] 224 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 225 | data_loader = DataLoader(dataset, 226 | batch_sampler=sampler, 227 | collate_fn=GraphAdjDataset.batchify, 228 | pin_memory=data_type=="train") 229 | else: 230 | dataset = EdgeSeqDataset(list()) 231 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 232 | if data_type == "train": 233 | np.random.shuffle(dataset.data) 234 | dataset.data = dataset.data[:math.ceil(len(dataset.data)*finetune_config["train_ratio"])] 235 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 236 | data_loader = DataLoader(dataset, 237 | batch_sampler=sampler, 238 | collate_fn=EdgeSeqDataset.batchify, 239 | pin_memory=data_type=="train") 240 | data_loaders[data_type] = data_loader 241 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 242 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), finetune_config["batch_size"])) 243 | else: 244 | data = load_data(finetune_config["graph_dir"], finetune_config["pattern_dir"], finetune_config["metadata_dir"], num_workers=finetune_config["num_workers"]) 245 | logger.info("{}/{}/{} data loaded".format(len(data["train"]), len(data["dev"]), len(data["test"]))) 246 | for data_type, x in data.items(): 247 | if finetune_config["model"] in ["RGCN", 
"RGIN", "RGIN"]: 248 | if os.path.exists(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))): 249 | dataset = GraphAdjDataset(list()) 250 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 251 | else: 252 | dataset = GraphAdjDataset(x) 253 | dataset.save(os.path.join(finetune_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 254 | if data_type == "train": 255 | np.random.shuffle(dataset.data) 256 | dataset.data = dataset.data[:math.ceil(len(dataset.data)*finetune_config["train_ratio"])] 257 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 258 | data_loader = DataLoader(dataset, 259 | batch_sampler=sampler, 260 | collate_fn=GraphAdjDataset.batchify, 261 | pin_memory=data_type=="train") 262 | else: 263 | if os.path.exists(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))): 264 | dataset = EdgeSeqDataset(list()) 265 | dataset.load(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 266 | else: 267 | dataset = EdgeSeqDataset(x) 268 | dataset.save(os.path.join(finetune_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 269 | if data_type == "train": 270 | np.random.shuffle(dataset.data) 271 | dataset.data = dataset.data[:math.ceil(len(dataset.data)*finetune_config["train_ratio"])] 272 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=finetune_config["batch_size"], shuffle=data_type=="train", drop_last=False) 273 | data_loader = DataLoader(dataset, 274 | batch_sampler=sampler, 275 | collate_fn=EdgeSeqDataset.batchify, 276 | pin_memory=data_type=="train") 277 | data_loaders[data_type] = data_loader 278 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 279 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) 
generated".format(data_type, len(data_loader), finetune_config["batch_size"])) 280 | 281 | # optimizer and losses 282 | writer = SummaryWriter(save_model_dir) 283 | optimizer = torch.optim.AdamW(model.parameters(), lr=finetune_config["lr"], weight_decay=finetune_config["weight_decay"], amsgrad=True) 284 | optimizer.zero_grad() 285 | scheduler = get_linear_schedule_with_warmup(optimizer, 286 | len(data_loaders["train"]), train_config["epochs"]*len(data_loaders["train"]), min_percent=0.0001) 287 | 288 | best_reg_losses = {"train": INF, "dev": INF, "test": INF} 289 | best_reg_epochs = {"train": -1, "dev": -1, "test": -1} 290 | 291 | for epoch in range(finetune_config["epochs"]): 292 | for data_type, data_loader in data_loaders.items(): 293 | 294 | if data_type == "train": 295 | mean_reg_loss, mean_bp_loss = train(model, optimizer, scheduler, data_type, data_loader, device, 296 | finetune_config, epoch, logger=logger, writer=writer) 297 | torch.save(model.state_dict(), os.path.join(save_model_dir, 'epoch%d.pt' % (epoch))) 298 | else: 299 | mean_reg_loss, mean_bp_loss, evaluate_results = evaluate(model, data_type, data_loader, device, 300 | finetune_config, epoch, logger=logger, writer=writer) 301 | with open(os.path.join(save_model_dir, '%s%d.json' % (data_type, epoch)), "w") as f: 302 | json.dump(evaluate_results, f) 303 | 304 | if mean_reg_loss <= best_reg_losses[data_type]: 305 | best_reg_losses[data_type] = mean_reg_loss 306 | best_reg_epochs[data_type] = epoch 307 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, mean_reg_loss, epoch)) 308 | for data_type in data_loaders.keys(): 309 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, best_reg_losses[data_type], best_reg_epochs[data_type])) 310 | -------------------------------------------------------------------------------- /src/rgcn.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl 5 | import dgl.function as fn 6 | import copy 7 | from functools import partial 8 | from dgl.nn.pytorch.conv import RelGraphConv 9 | from basemodel import GraphAdjModel 10 | from utils import map_activation_str_to_layer, split_and_batchify_graph_feats 11 | import numpy as np 12 | 13 | class RGCN(GraphAdjModel): 14 | def __init__(self, config): 15 | super(RGCN, self).__init__(config) 16 | 17 | self.ignore_norm = config["rgcn_ignore_norm"] 18 | 19 | # create networks 20 | p_emb_dim, g_emb_dim = self.get_emb_dim() 21 | self.g_net, g_dim = self.create_net( 22 | name="graph", input_dim=g_emb_dim, hidden_dim=config["rgcn_hidden_dim"], 23 | num_layers=config["rgcn_graph_num_layers"], 24 | num_rels=self.max_ngel, num_bases=config["rgcn_num_bases"], regularizer=config["rgcn_regularizer"], 25 | act_func=self.act_func, dropout=self.dropout) 26 | self.p_net, p_dim = (self.g_net, g_dim) if self.share_arch else self.create_net( 27 | name="pattern", input_dim=p_emb_dim, hidden_dim=config["rgcn_hidden_dim"], 28 | num_layers=config["rgcn_pattern_num_layers"], 29 | num_rels=self.max_npel, num_bases=config["rgcn_num_bases"], regularizer=config["rgcn_regularizer"], 30 | act_func=self.act_func, dropout=self.dropout) 31 | 32 | # create predict layers 33 | if self.add_enc: 34 | p_enc_dim, g_enc_dim = self.get_enc_dim() 35 | p_dim += p_enc_dim 36 | g_dim += g_enc_dim 37 | if self.add_degree: 38 | p_dim += 1 39 | g_dim += 1 40 | self.predict_net = self.create_predict_net(config["predict_net"], 41 | pattern_dim=p_dim, graph_dim=g_dim, hidden_dim=config["predict_net_hidden_dim"], 42 | num_heads=config["predict_net_num_heads"], recurrent_steps=config["predict_net_recurrent_steps"], 43 | mem_len=config["predict_net_mem_len"], mem_init=config["predict_net_mem_init"]) 44 | 45 | def create_net(self, name, input_dim, **kw): 46 | num_layers = kw.get("num_layers", 1) 47 | hidden_dim = kw.get("hidden_dim", 64) 48 | num_rels =
kw.get("num_rels", 1) 49 | num_bases = kw.get("num_bases", 8) 50 | regularizer = kw.get("regularizer", "basis") 51 | act_func = kw.get("act_func", "relu") 52 | dropout = kw.get("dropout", 0.0) 53 | 54 | rgcns = nn.ModuleList() 55 | for i in range(num_layers): 56 | rgcns.add_module("%s_rgc%d" % (name, i), RelGraphConv( 57 | in_feat=hidden_dim if i > 0 else input_dim, out_feat=hidden_dim, num_rels=num_rels, 58 | regularizer=regularizer, num_bases=num_bases, 59 | activation=map_activation_str_to_layer(act_func), self_loop=True, dropout=dropout)) 60 | 61 | for m in rgcns.modules(): 62 | if isinstance(m, RelGraphConv): 63 | if hasattr(m, "weight") and m.weight is not None: 64 | nn.init.normal_(m.weight, 0.0, 1/(hidden_dim)**0.5) 65 | if hasattr(m, "w_comp") and m.w_comp is not None: 66 | nn.init.normal_(m.w_comp, 0.0, 1/(hidden_dim)**0.5) 67 | if hasattr(m, "loop_weight") and m.loop_weight is not None: 68 | nn.init.normal_(m.loop_weight, 0.0, 1/(hidden_dim)**0.5) 69 | if hasattr(m, "h_bias") and m.h_bias is not None: 70 | nn.init.zeros_(m.h_bias) 71 | 72 | return rgcns, hidden_dim 73 | 74 | def increase_input_size(self, config): 75 | old_p_enc_dim, old_g_enc_dim = self.get_enc_dim() 76 | old_max_npel, old_max_ngel = self.max_npel, self.max_ngel 77 | super(RGCN, self).increase_input_size(config) 78 | new_p_enc_dim, new_g_enc_dim = self.get_enc_dim() 79 | new_max_npel, new_max_ngel = self.max_npel, self.max_ngel 80 | 81 | # increase networks 82 | if new_max_ngel != old_max_ngel: 83 | for g_rgcn in self.g_net: 84 | num_bases = g_rgcn.num_bases 85 | device = g_rgcn.weight.device 86 | regularizer = g_rgcn.regularizer 87 | if regularizer == "basis": 88 | if num_bases < old_max_ngel: 89 | new_w_comp = nn.Parameter( 90 | torch.zeros((new_max_ngel, num_bases), dtype=torch.float32, device=device, requires_grad=True)) 91 | with torch.no_grad(): 92 | new_w_comp[:old_max_ngel].data.copy_(g_rgcn.w_comp) 93 | else: 94 | new_w_comp = nn.Parameter( 95 | torch.zeros((new_max_ngel, 
num_bases), dtype=torch.float32, device=device, requires_grad=True)) 96 | with torch.no_grad(): 97 | ind = np.diag_indices(num_bases) 98 | new_w_comp[ind[0], ind[1]] = 1.0 99 | del g_rgcn.w_comp 100 | g_rgcn.w_comp = new_w_comp 101 | elif regularizer == "bdd": 102 | new_weight = nn.Parameter( 103 | torch.zeros((new_max_ngel, g_rgcn.weight.size(1)), 104 | dtype=torch.float32, device=device, requires_grad=True)) 105 | with torch.no_grad(): 106 | new_weight[:old_max_ngel].data.copy_(g_rgcn.weight) 107 | del g_rgcn.weight 108 | g_rgcn.weight = new_weight 109 | else: 110 | raise NotImplementedError 111 | if self.share_arch: 112 | del self.p_net 113 | self.p_net = self.g_net 114 | elif new_max_npel != old_max_npel: 115 | for p_rgcn in self.p_net: 116 | num_bases = p_rgcn.num_bases 117 | device = p_rgcn.weight.device 118 | regularizer = p_rgcn.regularizer 119 | if regularizer == "basis": 120 | if num_bases < old_max_npel: 121 | new_w_comp = nn.Parameter( 122 | torch.zeros((new_max_npel, num_bases), dtype=torch.float32, device=device, requires_grad=True)) 123 | with torch.no_grad(): 124 | new_w_comp[:old_max_npel].data.copy_(p_rgcn.w_comp) 125 | else: 126 | new_w_comp = nn.Parameter( 127 | torch.zeros((new_max_npel, num_bases), dtype=torch.float32, device=device, requires_grad=True)) 128 | with torch.no_grad(): 129 | ind = np.diag_indices(num_bases) 130 | new_w_comp[ind[0], ind[1]] = 1.0 131 | del p_rgcn.w_comp 132 | p_rgcn.w_comp = new_w_comp 133 | elif regularizer == "bdd": 134 | new_weight = nn.Parameter( 135 | torch.zeros((new_max_npel, p_rgcn.weight.size(1)), dtype=torch.float32, device=device, requires_grad=True)) 136 | with torch.no_grad(): 137 | new_weight[:old_max_npel].data.copy_(p_rgcn.weight) 138 | del p_rgcn.weight 139 | p_rgcn.weight = new_weight 140 | else: 141 | raise NotImplementedError 142 | 143 | # increase predict network 144 | if self.add_enc and (new_g_enc_dim != old_g_enc_dim or new_p_enc_dim != old_p_enc_dim): 145 | self.predict_net.increase_input_size(
146 | self.predict_net.pattern_dim+new_p_enc_dim-old_p_enc_dim, 147 | self.predict_net.graph_dim+new_g_enc_dim-old_g_enc_dim) 148 | 149 | def increase_net(self, config): 150 | p_emb_dim, g_emb_dim = self.get_emb_dim() 151 | g_net, g_dim = self.create_net( 152 | name="graph", input_dim=g_emb_dim, hidden_dim=config["rgcn_hidden_dim"], 153 | num_layers=config["rgcn_graph_num_layers"], 154 | num_rels=self.max_ngel, num_bases=config["rgcn_num_bases"], regularizer=config["rgcn_regularizer"], 155 | act_func=self.act_func, dropout=self.dropout) 156 | assert len(g_net) >= len(self.g_net) 157 | with torch.no_grad(): 158 | for old_g_rgcn, new_g_rgcn in zip(self.g_net, g_net): 159 | new_g_rgcn.load_state_dict(old_g_rgcn.state_dict()) 160 | del self.g_net 161 | self.g_net = g_net 162 | 163 | if self.share_arch: 164 | self.p_net = self.g_net 165 | else: 166 | p_net, p_dim = self.create_net( 167 | name="pattern", input_dim=p_emb_dim, hidden_dim=config["rgcn_hidden_dim"], 168 | num_layers=config["rgcn_pattern_num_layers"], 169 | num_rels=self.max_npel, num_bases=config["rgcn_num_bases"], regularizer=config["rgcn_regularizer"], 170 | act_func=self.act_func, dropout=self.dropout) 171 | assert len(p_net) >= len(self.p_net) 172 | with torch.no_grad(): 173 | for old_p_rgcn, new_p_rgcn in zip(self.p_net, p_net): 174 | new_p_rgcn.load_state_dict(old_p_rgcn.state_dict()) 175 | del self.p_net 176 | self.p_net = p_net 177 | 178 | def forward(self, pattern, pattern_len, graph, graph_len): 179 | bsz = pattern_len.size(0) 180 | 181 | gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len) 182 | zero_mask = (gate == 0) if gate is not None else None 183 | pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph, graph_len) 184 | if zero_mask is not None: 185 | graph_emb.masked_fill_(zero_mask, 0.0) 186 | 187 | pattern_output = pattern_emb 188 | if self.ignore_norm: 189 | pattern_norm = None 190 | else: 191 | if "norm" in pattern.edata: 192 | pattern_norm = 
pattern.edata["norm"] 193 | else: 194 | pattern.apply_edges(lambda e: {"norm": 1.0/e.dst["indeg"]}) 195 | pattern.edata["norm"].masked_fill_(torch.isinf(pattern.edata["norm"]), 0.0) 196 | pattern.edata["norm"] = pattern.edata["norm"].unsqueeze(-1) 197 | pattern_norm = pattern.edata["norm"] 198 | for p_rgcn in self.p_net: 199 | o = p_rgcn(pattern, pattern_output, pattern.edata["label"], pattern_norm) 200 | pattern_output = o + pattern_output 201 | 202 | graph_output = graph_emb 203 | if self.ignore_norm: 204 | graph_norm = None 205 | else: 206 | if "norm" in graph.edata: 207 | graph_norm = graph.edata["norm"] 208 | else: 209 | graph.apply_edges(lambda e: {"norm": 1.0/e.dst["indeg"]}) 210 | graph.edata["norm"].masked_fill_(torch.isinf(graph.edata["norm"]), 0.0) 211 | graph.edata["norm"] = graph.edata["norm"].unsqueeze(-1) 212 | graph_norm = graph.edata["norm"] 213 | for g_rgcn in self.g_net: 214 | o = g_rgcn(graph, graph_output, graph.edata["label"], graph_norm) 215 | graph_output = o + graph_output 216 | if zero_mask is not None: 217 | graph_output.masked_fill_(zero_mask, 0.0) 218 | 219 | if self.add_enc and self.add_degree: 220 | pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len) 221 | if zero_mask is not None: 222 | graph_enc.masked_fill_(zero_mask, 0.0) 223 | pattern_output = torch.cat([pattern_enc, pattern_output, pattern.ndata["indeg"].unsqueeze(-1)], dim=1) 224 | graph_output = torch.cat([graph_enc, graph_output, graph.ndata["indeg"].unsqueeze(-1)], dim=1) 225 | elif self.add_enc: 226 | pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len) 227 | if zero_mask is not None: 228 | graph_enc.masked_fill_(zero_mask, 0.0) 229 | pattern_output = torch.cat([pattern_enc, pattern_output], dim=1) 230 | graph_output = torch.cat([graph_enc, graph_output], dim=1) 231 | elif self.add_degree: 232 | pattern_output = torch.cat([pattern_output, pattern.ndata["indeg"].unsqueeze(-1)], dim=1) 233 | graph_output = 
torch.cat([graph_output, graph.ndata["indeg"].unsqueeze(-1)], dim=1) 234 | 235 | pred = self.predict_net( 236 | split_and_batchify_graph_feats(pattern_output, pattern_len)[0], pattern_len, 237 | split_and_batchify_graph_feats(graph_output, graph_len)[0], graph_len) 238 | 239 | return pred 240 | -------------------------------------------------------------------------------- /src/rgin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl 5 | import dgl.function as fn 6 | import copy 7 | import numpy as np 8 | from functools import partial 9 | from dgl.nn.pytorch.conv import RelGraphConv 10 | from basemodel import GraphAdjModel 11 | from utils import map_activation_str_to_layer, split_and_batchify_graph_feats 12 | 13 | 14 | class RGINLayer(nn.Module): 15 | def __init__(self, in_feat, out_feat, num_rels, regularizer="basis", num_bases=None, act_func="relu", dropout=0.0): 16 | super(RGINLayer, self).__init__() 17 | self.rgc_layer = RelGraphConv( 18 | in_feat=in_feat, out_feat=out_feat, num_rels=num_rels, 19 | regularizer=regularizer, num_bases=num_bases, 20 | activation=None, self_loop=True, dropout=0.0) 21 | self.mlp = nn.Sequential( 22 | nn.Linear(out_feat, out_feat), 23 | # nn.BatchNorm1d(out_feat), 24 | map_activation_str_to_layer(act_func), 25 | nn.Linear(out_feat, out_feat), 26 | map_activation_str_to_layer(act_func)) 27 | self.drop = nn.Dropout(dropout) 28 | 29 | # init 30 | if hasattr(self.rgc_layer, "weight") and self.rgc_layer.weight is not None: 31 | nn.init.normal_(self.rgc_layer.weight, 0.0, 1/(out_feat)**0.5) 32 | if hasattr(self.rgc_layer, "w_comp") and self.rgc_layer.w_comp is not None: 33 | nn.init.normal_(self.rgc_layer.w_comp, 0.0, 1/(out_feat)**0.5) 34 | if hasattr(self.rgc_layer, "loop_weight") and self.rgc_layer.loop_weight is not None: 35 | nn.init.normal_(self.rgc_layer.loop_weight, 0.0, 1/(out_feat)**0.5) 36 | if 
hasattr(self.rgc_layer, "h_bias") and self.rgc_layer.h_bias is not None: 37 | nn.init.zeros_(self.rgc_layer.h_bias) 38 | for m in self.mlp.modules(): 39 | if isinstance(m, nn.Linear): 40 | nn.init.normal_(m.weight, 0.0, 1/(out_feat)**0.5) 41 | if hasattr(m, "bias") and m.bias is not None: 42 | nn.init.zeros_(m.bias) 43 | elif isinstance(m, nn.BatchNorm1d): 44 | nn.init.ones_(m.weight) 45 | nn.init.zeros_(m.bias) 46 | 47 | def forward(self, g, x, etypes): 48 | g = self.rgc_layer(g, x, etypes, norm=None) 49 | g = self.mlp(g) 50 | g = self.drop(g) 51 | return g 52 | 53 | 54 | class RGIN(GraphAdjModel): 55 | def __init__(self, config): 56 | super(RGIN, self).__init__(config) 57 | 58 | # create networks 59 | p_emb_dim, g_emb_dim = self.get_emb_dim() 60 | self.g_net, g_dim = self.create_net( 61 | name="graph", input_dim=g_emb_dim, hidden_dim=config["rgin_hidden_dim"], 62 | num_layers=config["rgin_graph_num_layers"], 63 | num_rels=self.max_ngel, num_bases=config["rgin_num_bases"], regularizer=config["rgin_regularizer"], 64 | act_func=self.act_func, dropout=self.dropout) 65 | self.p_net, p_dim = (self.g_net, g_dim) if self.share_arch else self.create_net( 66 | name="pattern", input_dim=p_emb_dim, hidden_dim=config["rgin_hidden_dim"], 67 | num_layers=config["rgin_pattern_num_layers"], 68 | num_rels=self.max_npel, num_bases=config["rgin_num_bases"], regularizer=config["rgin_regularizer"], 69 | act_func=self.act_func, dropout=self.dropout) 70 | 71 | if self.add_enc: 72 | p_enc_dim, g_enc_dim = self.get_enc_dim() 73 | p_dim += p_enc_dim 74 | g_dim += g_enc_dim 75 | if self.add_degree: 76 | p_dim += 1 77 | g_dim += 1 78 | self.predict_net = self.create_predict_net(config["predict_net"], 79 | pattern_dim=p_dim, graph_dim=g_dim, hidden_dim=config["predict_net_hidden_dim"], 80 | num_heads=config["predict_net_num_heads"], recurrent_steps=config["predict_net_recurrent_steps"], 81 | mem_len=config["predict_net_mem_len"], mem_init=config["predict_net_mem_init"]) 82 | 83 | def 
create_net(self, name, input_dim, **kw): 84 | num_layers = kw.get("num_layers", 1) 85 | hidden_dim = kw.get("hidden_dim", 64) 86 | num_rels = kw.get("num_rels", 1) 87 | num_bases = kw.get("num_bases", 8) 88 | regularizer = kw.get("regularizer", "basis") 89 | act_func = kw.get("act_func", "relu") 90 | dropout = kw.get("dropout", 0.0) 91 | 92 | rgins = nn.ModuleList() 93 | for i in range(num_layers): 94 | rgins.add_module("%s_rgi%d" % (name, i), RGINLayer( 95 | in_feat=hidden_dim if i > 0 else input_dim, out_feat=hidden_dim, num_rels=num_rels, 96 | regularizer=regularizer, num_bases=num_bases, 97 | act_func=act_func, dropout=dropout)) 98 | 99 | return rgins, hidden_dim 100 | 101 | def increase_input_size(self, config): 102 | old_p_enc_dim, old_g_enc_dim = self.get_enc_dim() 103 | old_max_npel, old_max_ngel = self.max_npel, self.max_ngel 104 | super(RGIN, self).increase_input_size(config) 105 | new_p_enc_dim, new_g_enc_dim = self.get_enc_dim() 106 | new_max_npel, new_max_ngel = self.max_npel, self.max_ngel 107 | 108 | # increase networks 109 | if new_max_ngel != old_max_ngel: 110 | for g_rgin in self.g_net: 111 | num_bases = g_rgin.rgc_layer.num_bases 112 | device = g_rgin.rgc_layer.weight.device 113 | regularizer = g_rgin.rgc_layer.regularizer 114 | if regularizer == "basis": 115 | if num_bases < old_max_ngel: 116 | new_w_comp = nn.Parameter( 117 | torch.zeros((new_max_ngel, num_bases), dtype=torch.float32, device=device, requires_grad=True)) 118 | with torch.no_grad(): 119 | new_w_comp[:old_max_ngel].data.copy_(g_rgin.rgc_layer.w_comp) 120 | else: 121 | new_w_comp = nn.Parameter( 122 | torch.zeros((new_max_ngel, num_bases), dtype=torch.float32, device=device, requires_grad=True)) 123 | with torch.no_grad(): 124 | ind = np.diag_indices(num_bases) 125 | new_w_comp[ind[0], ind[1]] = 1.0 126 | del g_rgin.rgc_layer.w_comp 127 | g_rgin.rgc_layer.w_comp = new_w_comp 128 | elif regularizer == "bdd": 129 | new_weight = nn.Parameter( 130 | torch.zeros((new_max_ngel, 
g_rgin.rgc_layer.weight.size(1)), 131 | dtype=torch.float32, device=device, requires_grad=True)) 132 | with torch.no_grad(): 133 | new_weight[:old_max_ngel].data.copy_(g_rgin.rgc_layer.weight) 134 | del g_rgin.rgc_layer.weight 135 | g_rgin.rgc_layer.weight = new_weight 136 | else: 137 | raise NotImplementedError 138 | if self.share_arch: 139 | del self.p_net 140 | self.p_net = self.g_net 141 | elif new_max_npel != old_max_npel: 142 | for p_rgin in self.p_net: 143 | num_bases = p_rgin.rgc_layer.num_bases 144 | device = p_rgin.rgc_layer.weight.device 145 | regularizer = p_rgin.rgc_layer.regularizer 146 | if regularizer == "basis": 147 | if num_bases < old_max_npel: 148 | new_w_comp = nn.Parameter( 149 | torch.zeros((new_max_npel, num_bases), dtype=torch.float32, device=device, requires_grad=True)) 150 | with torch.no_grad(): 151 | new_w_comp[:old_max_npel].data.copy_(p_rgin.rgc_layer.w_comp) 152 | else: 153 | new_w_comp = nn.Parameter( 154 | torch.zeros((new_max_npel, num_bases), dtype=torch.float32, device=device, requires_grad=True)) 155 | with torch.no_grad(): 156 | ind = np.diag_indices(num_bases) 157 | new_w_comp[ind[0], ind[1]] = 1.0 158 | del p_rgin.rgc_layer.w_comp 159 | p_rgin.rgc_layer.w_comp = new_w_comp 160 | elif regularizer == "bdd": 161 | new_weight = nn.Parameter( 162 | torch.zeros((new_max_npel, p_rgin.rgc_layer.weight.size(1)), dtype=torch.float32, device=device, requires_grad=True)) 163 | with torch.no_grad(): 164 | new_weight[:old_max_npel].data.copy_(p_rgin.rgc_layer.weight) 165 | del p_rgin.rgc_layer.weight 166 | p_rgin.rgc_layer.weight = new_weight 167 | else: 168 | raise NotImplementedError 169 | 170 | # increase predict network 171 | if self.add_enc and (new_g_enc_dim != old_g_enc_dim or new_p_enc_dim != old_p_enc_dim): 172 | self.predict_net.increase_input_size( 173 | self.predict_net.pattern_dim+new_p_enc_dim-old_p_enc_dim, 174 | self.predict_net.graph_dim+new_g_enc_dim-old_g_enc_dim) 175 | 176 | def increase_net(self, config): 177 | p_emb_dim, 
g_emb_dim = self.get_emb_dim() 178 | g_net, g_dim = self.create_net( 179 | name="graph", input_dim=g_emb_dim, hidden_dim=config["rgin_hidden_dim"], 180 | num_layers=config["rgin_graph_num_layers"], 181 | num_rels=self.max_ngel, num_bases=config["rgin_num_bases"], regularizer=config["rgin_regularizer"], 182 | act_func=self.act_func, dropout=self.dropout) 183 | assert len(g_net) >= len(self.g_net) 184 | with torch.no_grad(): 185 | for old_g_rgin, new_g_rgin in zip(self.g_net, g_net): 186 | new_g_rgin.load_state_dict(old_g_rgin.state_dict()) 187 | del self.g_net 188 | self.g_net = g_net 189 | 190 | if self.share_arch: 191 | self.p_net = self.g_net 192 | else: 193 | p_net, p_dim = self.create_net( 194 | name="pattern", input_dim=p_emb_dim, hidden_dim=config["rgin_hidden_dim"], 195 | num_layers=config["rgin_pattern_num_layers"], 196 | num_rels=self.max_npel, num_bases=config["rgin_num_bases"], regularizer=config["rgin_regularizer"], 197 | act_func=self.act_func, dropout=self.dropout) 198 | assert len(p_net) >= len(self.p_net) 199 | with torch.no_grad(): 200 | for old_p_rgin, new_p_rgin in zip(self.p_net, p_net): 201 | new_p_rgin.load_state_dict(old_p_rgin.state_dict()) 202 | del self.p_net 203 | self.p_net = p_net 204 | 205 | def forward(self, pattern, pattern_len, graph, graph_len): 206 | bsz = pattern_len.size(0) 207 | 208 | gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len) 209 | zero_mask = (gate == 0) if gate is not None else None 210 | pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph, graph_len) 211 | if zero_mask is not None: 212 | graph_emb.masked_fill_(zero_mask, 0.0) 213 | 214 | pattern_output = pattern_emb 215 | for p_rgin in self.p_net: 216 | o = p_rgin(pattern, pattern_output, pattern.edata["label"]) 217 | pattern_output = o + pattern_output 218 | 219 | graph_output = graph_emb 220 | for g_rgin in self.g_net: 221 | o = g_rgin(graph, graph_output, graph.edata["label"]) 222 | graph_output = o + graph_output 223 | if zero_mask 
is not None: 224 | graph_output.masked_fill_(zero_mask, 0.0) 225 | 226 | if self.add_enc and self.add_degree: 227 | pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len) 228 | if zero_mask is not None: 229 | graph_enc.masked_fill_(zero_mask, 0.0) 230 | pattern_output = torch.cat([pattern_enc, pattern_output, pattern.ndata["indeg"].unsqueeze(-1)], dim=1) 231 | graph_output = torch.cat([graph_enc, graph_output, graph.ndata["indeg"].unsqueeze(-1)], dim=1) 232 | elif self.add_enc: 233 | pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len) 234 | if zero_mask is not None: 235 | graph_enc.masked_fill_(zero_mask, 0.0) 236 | pattern_output = torch.cat([pattern_enc, pattern_output], dim=1) 237 | graph_output = torch.cat([graph_enc, graph_output], dim=1) 238 | elif self.add_degree: 239 | pattern_output = torch.cat([pattern_output, pattern.ndata["indeg"].unsqueeze(-1)], dim=1) 240 | graph_output = torch.cat([graph_output, graph.ndata["indeg"].unsqueeze(-1)], dim=1) 241 | 242 | pred = self.predict_net( 243 | split_and_batchify_graph_feats(pattern_output, pattern_len)[0], pattern_len, 244 | split_and_batchify_graph_feats(graph_output, graph_len)[0], graph_len) 245 | 246 | return pred 247 | -------------------------------------------------------------------------------- /src/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | import math 6 | from utils import segment_length 7 | from basemodel import EdgeSeqModel 8 | from utils import map_activation_str_to_layer, batch_convert_len_to_mask 9 | 10 | 11 | class RNNLayer(nn.Module): 12 | def __init__(self, input_dim, hidden_dim, rnn_type, bidirectional, dropout): 13 | super(RNNLayer, self).__init__() 14 | if rnn_type == "GRU": 15 | rnn_layer = nn.GRU 16 | elif rnn_type == "LSTM": 17 | rnn_layer = nn.LSTM 18 | else: 19 | raise 
NotImplementedError("Currently, %s is not supported!" % (rnn_type)) 20 | self.rnn = rnn_layer(input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional) 21 | self.drop = nn.Dropout(dropout) 22 | 23 | # init 24 | for layer_weights in self.rnn._all_weights: 25 | for w in layer_weights: 26 | if "weight" in w: 27 | weight = getattr(self.rnn, w) 28 | nn.init.orthogonal_(weight) 29 | elif "bias" in w: 30 | bias = getattr(self.rnn, w) 31 | if bias is not None: 32 | nn.init.zeros_(bias) 33 | 34 | def forward(self, x): 35 | x, h = self.rnn(x) 36 | x = self.drop(x) 37 | return x 38 | 39 | class RNN(EdgeSeqModel): 40 | def __init__(self, config): 41 | super(RNN, self).__init__(config) 42 | 43 | # create networks 44 | p_emb_dim, g_emb_dim = self.get_emb_dim() 45 | self.g_net, g_dim = self.create_net( 46 | name="graph", input_dim=g_emb_dim, hidden_dim=config["rnn_hidden_dim"], 47 | num_layers=config["rnn_graph_num_layers"], 48 | rnn_type=config["rnn_type"], bidirectional=config["rnn_bidirectional"], 49 | dropout=self.dropout) 50 | self.p_net, p_dim = (self.g_net, g_dim) if self.share_arch else self.create_net( 51 | name="pattern", input_dim=p_emb_dim, hidden_dim=config["rnn_hidden_dim"], 52 | num_layers=config["rnn_pattern_num_layers"], 53 | rnn_type=config["rnn_type"], bidirectional=config["rnn_bidirectional"], 54 | dropout=self.dropout) 55 | 56 | # create predict layers 57 | if self.add_enc: 58 | p_enc_dim, g_enc_dim = self.get_enc_dim() 59 | p_dim += p_enc_dim 60 | g_dim += g_enc_dim 61 | self.predict_net = self.create_predict_net(config["predict_net"], 62 | pattern_dim=p_dim, graph_dim=g_dim, hidden_dim=config["predict_net_hidden_dim"], 63 | num_heads=config["predict_net_num_heads"], recurrent_steps=config["predict_net_recurrent_steps"], 64 | mem_len=config["predict_net_mem_len"], mem_init=config["predict_net_mem_init"]) 65 | 66 | def create_net(self, name, input_dim, **kw): 67 | num_layers = kw.get("num_layers", 3) 68 | hidden_dim = kw.get("hidden_dim", 64) 69 | 
rnn_type = kw.get("rnn_type", "LSTM") 70 | bidirectional = kw.get("bidirectional", False) 71 | dropout = kw.get("dropout", 0.0) 72 | 73 | num_features = hidden_dim*2 if bidirectional else hidden_dim 74 | rnns = nn.ModuleList() 75 | for i in range(num_layers): 76 | rnns.add_module("%s_rnn%d" % (name, i), RNNLayer( 77 | input_dim=input_dim if i == 0 else num_features, hidden_dim=hidden_dim, 78 | rnn_type=rnn_type, bidirectional=bidirectional, 79 | dropout=dropout)) 80 | 81 | return rnns, num_features 82 | 83 | def increase_input_size(self, config): 84 | old_p_enc_dim, old_g_enc_dim = self.get_enc_dim() 85 | super(RNN, self).increase_input_size(config) 86 | new_p_enc_dim, new_g_enc_dim = self.get_enc_dim() 87 | 88 | # increase predict network 89 | if self.add_enc and (new_g_enc_dim != old_g_enc_dim or new_p_enc_dim != old_p_enc_dim): 90 | self.predict_net.increase_input_size( 91 | self.predict_net.pattern_dim+new_p_enc_dim-old_p_enc_dim, 92 | self.predict_net.graph_dim+new_g_enc_dim-old_g_enc_dim) 93 | 94 | def increase_net(self, config): 95 | p_emb_dim, g_emb_dim = self.get_emb_dim() 96 | g_net, g_dim = self.create_net( 97 | name="graph", input_dim=g_emb_dim, hidden_dim=config["rnn_hidden_dim"], 98 | num_layers=config["rnn_graph_num_layers"], 99 | rnn_type=config["rnn_type"], bidirectional=config["rnn_bidirectional"], 100 | dropout=self.dropout) 101 | assert len(g_net) >= len(self.g_net) 102 | with torch.no_grad(): 103 | for old_g_rnn, new_g_rnn in zip(self.g_net, g_net): 104 | new_g_rnn.load_state_dict(old_g_rnn.state_dict()) 105 | del self.g_net 106 | self.g_net = g_net 107 | 108 | if self.share_arch: 109 | self.p_net = self.g_net 110 | else: 111 | p_net, p_dim = self.create_net( 112 | name="pattern", input_dim=p_emb_dim, hidden_dim=config["rnn_hidden_dim"], 113 | num_layers=config["rnn_pattern_num_layers"], 114 | rnn_type=config["rnn_type"], bidirectional=config["rnn_bidirectional"], 115 | dropout=self.dropout) 116 | assert len(p_net) >= len(self.p_net) 117 | with 
torch.no_grad(): 118 | for old_p_rnn, new_p_rnn in zip(self.p_net, p_net): 119 | new_p_rnn.load_state_dict(old_p_rnn.state_dict()) 120 | del self.p_net 121 | self.p_net = p_net 122 | 123 | def forward(self, pattern, pattern_len, graph, graph_len): 124 | bsz = pattern_len.size(0) 125 | 126 | gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len) 127 | zero_mask = (gate == 0).unsqueeze(-1) if gate is not None else None 128 | pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph, graph_len) 129 | if zero_mask is not None: 130 | graph_emb.masked_fill_(zero_mask, 0.0) 131 | 132 | pattern_output = pattern_emb 133 | for p_rnn in self.p_net: 134 | o = p_rnn(pattern_output) 135 | pattern_output = o + pattern_output 136 | pattern_mask = (batch_convert_len_to_mask(pattern_len)==0).unsqueeze(-1) 137 | pattern_output.masked_fill_(pattern_mask, 0.0) 138 | 139 | graph_output = graph_emb 140 | for g_rnn in self.g_net: 141 | o = g_rnn(graph_output) 142 | graph_output = o + graph_output 143 | if zero_mask is not None: 144 | graph_output.masked_fill_(zero_mask, 0.0) 145 | graph_mask = (batch_convert_len_to_mask(graph_len)==0).unsqueeze(-1) 146 | graph_output.masked_fill_(graph_mask, 0.0) 147 | 148 | if self.add_enc: 149 | pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len) 150 | if zero_mask is not None: 151 | graph_enc.masked_fill_(zero_mask, 0.0) 152 | pattern_output = torch.cat([pattern_enc, pattern_output], dim=2) 153 | graph_output = torch.cat([graph_enc, graph_output], dim=2) 154 | 155 | pred = self.predict_net(pattern_output, pattern_len, graph_output, graph_len) 156 | 157 | return pred 158 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import dgl 5 | import logging 6 | import datetime 7 | import math 8 | import sys 9 | import gc 10 | 
import json 11 | import time 12 | import torch.nn.functional as F 13 | import warnings 14 | from functools import partial 15 | from collections import OrderedDict 16 | from torch.utils.data import DataLoader 17 | try: 18 | from torch.utils.tensorboard import SummaryWriter 19 | except BaseException as e: 20 | from tensorboardX import SummaryWriter 21 | from dataset import Sampler, EdgeSeqDataset, GraphAdjDataset 22 | from utils import anneal_fn, get_enc_len, load_data, get_linear_schedule_with_warmup 23 | from cnn import CNN 24 | from rnn import RNN 25 | from txl import TXL 26 | from rgcn import RGCN 27 | from rgin import RGIN 28 | 29 | warnings.filterwarnings("ignore") 30 | INF = float("inf") 31 | 32 | train_config = { 33 | "max_npv": 8, # max_number_pattern_vertices: 8, 16, 32 34 | "max_npe": 8, # max_number_pattern_edges: 8, 16, 32 35 | "max_npvl": 8, # max_number_pattern_vertex_labels: 8, 16, 32 36 | "max_npel": 8, # max_number_pattern_edge_labels: 8, 16, 32 37 | 38 | "max_ngv": 64, # max_number_graph_vertices: 64, 512, 4096 39 | "max_nge": 256, # max_number_graph_edges: 256, 2048, 16384 40 | "max_ngvl": 16, # max_number_graph_vertex_labels: 16, 64, 256 41 | "max_ngel": 16, # max_number_graph_edge_labels: 16, 64, 256 42 | 43 | "base": 2, 44 | 45 | "gpu_id": -1, 46 | "num_workers": 12, 47 | 48 | "epochs": 100, 49 | "batch_size": 512, 50 | "update_every": 1, # actual batch_size = batch_size * update_every 51 | "print_every": 100, 52 | "init_emb": "Equivariant", # None, Orthogonal, Normal, Equivariant 53 | "share_emb": True, # sharing embedding requires the same vector length 54 | "share_arch": True, # sharing architectures 55 | "dropout": 0.2, 56 | "dropatt": 0.2, 57 | 58 | "reg_loss": "MSE", # MAE, MSE, SMSE 59 | "bp_loss": "MSE", # MAE, MSE, SMSE 60 | "bp_loss_slp": "anneal_cosine$1.0$0.01", # 0, 0.01, logistic$1.0$0.01, linear$1.0$0.01, cosine$1.0$0.01, 61 | # cyclical_logistic$1.0$0.01, cyclical_linear$1.0$0.01, cyclical_cosine$1.0$0.01 62 | # anneal_logistic$1.0$0.01, 
anneal_linear$1.0$0.01, anneal_cosine$1.0$0.01 63 | "lr": 0.001, 64 | "weight_decay": 0.00001, 65 | "max_grad_norm": 8, 66 | 67 | "model": "CNN", # CNN, RNN, TXL, RGCN, RGIN 68 | 69 | "emb_dim": 128, 70 | "activation_function": "leaky_relu", # sigmoid, softmax, tanh, relu, leaky_relu, prelu, gelu 71 | 72 | "filter_net": "MaxGatedFilterNet", # None, MaxGatedFilterNet 73 | "predict_net": "SumPredictNet", # MeanPredictNet, SumPredictNet, MaxPredictNet, 74 | # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, 75 | # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, 76 | # DIAMNet 77 | "predict_net_add_enc": True, 78 | "predict_net_add_degree": True, 79 | "predict_net_hidden_dim": 128, 80 | "predict_net_num_heads": 4, 81 | "predict_net_mem_len": 4, 82 | "predict_net_mem_init": "mean", # mean, sum, max, attn, circular_mean, circular_sum, circular_max, circular_attn, lstm 83 | "predict_net_recurrent_steps": 3, 84 | 85 | "cnn_hidden_dim": 128, 86 | "cnn_conv_channels": (128, 128, 128), 87 | "cnn_conv_kernel_sizes": (2, 3, 4), 88 | "cnn_conv_strides": (1, 1, 1), 89 | "cnn_conv_paddings": (1, 1, 1), 90 | "cnn_pool_kernel_sizes": (2, 3, 4), 91 | "cnn_pool_strides": (1, 1, 1), 92 | "cnn_pool_paddings": (0, 1, 2), 93 | 94 | "rnn_type": "LSTM", # GRU, LSTM 95 | "rnn_bidirectional": False, 96 | "rnn_graph_num_layers": 3, 97 | "rnn_pattern_num_layers": 3, 98 | "rnn_hidden_dim": 128, 99 | 100 | "txl_graph_num_layers": 3, 101 | "txl_pattern_num_layers": 3, 102 | "txl_d_model": 128, 103 | "txl_d_inner": 128, 104 | "txl_n_head": 4, 105 | "txl_d_head": 4, 106 | "txl_pre_lnorm": True, 107 | "txl_tgt_len": 64, 108 | "txl_ext_len": 0, # unused in the current settings 109 | "txl_mem_len": 64, 110 | "txl_clamp_len": -1, # max positional embedding index 111 | "txl_attn_type": 0, # 0 for Dai et al, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. 
112 | "txl_same_len": False, 113 | 114 | "rgcn_num_bases": 8, 115 | "rgcn_regularizer": "bdd", # basis, bdd 116 | "rgcn_graph_num_layers": 3, 117 | "rgcn_pattern_num_layers": 3, 118 | "rgcn_hidden_dim": 128, 119 | "rgcn_ignore_norm": False, # ignore_norm=True -> RGCN-SUM 120 | 121 | "rgin_num_bases": 8, 122 | "rgin_regularizer": "bdd", # basis, bdd 123 | "rgin_graph_num_layers": 3, 124 | "rgin_pattern_num_layers": 3, 125 | "rgin_hidden_dim": 128, 126 | 127 | "pattern_dir": "../data/debug/patterns", 128 | "graph_dir": "../data/debug/graphs", 129 | "metadata_dir": "../data/debug/metadata", 130 | "save_data_dir": "../data/debug", 131 | "save_model_dir": "../dumps/debug", 132 | } 133 | 134 | def train(model, optimizer, scheduler, data_type, data_loader, device, config, epoch, logger=None, writer=None): 135 | epoch_step = len(data_loader) 136 | total_step = config["epochs"] * epoch_step 137 | total_reg_loss = 0 138 | total_bp_loss = 0 139 | total_cnt = 1e-6 140 | 141 | if config["reg_loss"] == "MAE": 142 | reg_crit = lambda pred, target: F.l1_loss(F.relu(pred), target) 143 | elif config["reg_loss"] == "MSE": 144 | reg_crit = lambda pred, target: F.mse_loss(F.relu(pred), target) 145 | elif config["reg_loss"] == "SMSE": 146 | reg_crit = lambda pred, target: F.smooth_l1_loss(F.relu(pred), target) 147 | else: 148 | raise NotImplementedError 149 | 150 | if config["bp_loss"] == "MAE": 151 | bp_crit = lambda pred, target, neg_slp: F.l1_loss(F.leaky_relu(pred, neg_slp), target) 152 | elif config["bp_loss"] == "MSE": 153 | bp_crit = lambda pred, target, neg_slp: F.mse_loss(F.leaky_relu(pred, neg_slp), target) 154 | elif config["bp_loss"] == "SMSE": 155 | bp_crit = lambda pred, target, neg_slp: F.smooth_l1_loss(F.leaky_relu(pred, neg_slp), target) 156 | else: 157 | raise NotImplementedError 158 | 159 | model.train() 160 | 161 | for batch_id, batch in enumerate(data_loader): 162 | ids, pattern, pattern_len, graph, graph_len, counts = batch 163 | cnt = counts.shape[0] 164 | total_cnt += 
cnt 165 | 166 | pattern.to(device) 167 | graph.to(device) 168 | pattern_len, graph_len, counts = pattern_len.to(device), graph_len.to(device), counts.to(device) 169 | 170 | pred = model(pattern, pattern_len, graph, graph_len) 171 | 172 | reg_loss = reg_crit(pred, counts) 173 | 174 | if isinstance(config["bp_loss_slp"], (int, float)): 175 | neg_slp = float(config["bp_loss_slp"]) 176 | else: 177 | bp_loss_slp, l0, l1 = config["bp_loss_slp"].rsplit("$", 3) 178 | neg_slp = anneal_fn(bp_loss_slp, batch_id+epoch*epoch_step, T=total_step//4, lambda0=float(l0), lambda1=float(l1)) 179 | bp_loss = bp_crit(pred, counts, neg_slp) 180 | 181 | reg_loss_item = reg_loss.item() 182 | bp_loss_item = bp_loss.item() 183 | total_reg_loss += reg_loss_item * cnt 184 | total_bp_loss += bp_loss_item * cnt 185 | 186 | if writer: 187 | writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item, epoch*epoch_step+batch_id) 188 | writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item, epoch*epoch_step+batch_id) 189 | 190 | if logger and (batch_id % config["print_every"] == 0 or batch_id == epoch_step-1): 191 | logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\tbatch: {:0>5d}/{:0>5d}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}\tground: {:.3f}\tpredict: {:.3f}".format( 192 | epoch, config["epochs"], data_type, batch_id, epoch_step, 193 | reg_loss_item, bp_loss_item, 194 | counts[0].item(), pred[0].item())) 195 | 196 | bp_loss.backward() 197 | if (config["update_every"] < 2 or batch_id % config["update_every"] == 0 or batch_id == epoch_step-1): 198 | if config["max_grad_norm"] > 0: 199 | torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"]) 200 | if scheduler is not None: 201 | scheduler.step(epoch*epoch_step+batch_id) 202 | optimizer.step() 203 | optimizer.zero_grad() 204 | 205 | mean_reg_loss = total_reg_loss/total_cnt 206 | mean_bp_loss = total_bp_loss/total_cnt 207 | if writer: 208 | writer.add_scalar("%s/REG-%s-epoch" % 
(data_type, config["reg_loss"]), mean_reg_loss, epoch) 209 | writer.add_scalar("%s/BP-%s-epoch" % (data_type, config["bp_loss"]), mean_bp_loss, epoch) 210 | if logger: 211 | logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}".format( 212 | epoch, config["epochs"], data_type, mean_reg_loss, mean_bp_loss)) 213 | 214 | gc.collect() 215 | return mean_reg_loss, mean_bp_loss 216 | 217 | def evaluate(model, data_type, data_loader, device, config, epoch, logger=None, writer=None): 218 | epoch_step = len(data_loader) 219 | total_step = config["epochs"] * epoch_step 220 | total_reg_loss = 0 221 | total_bp_loss = 0 222 | total_cnt = 1e-6 223 | 224 | evaluate_results = {"data": {"id": list(), "counts": list(), "pred": list()}, 225 | "error": {"mae": 0.0, "mse": 0.0}, 226 | "time": {"avg": list(), "total": 0.0}} 227 | 228 | if config["reg_loss"] == "MAE": 229 | reg_crit = lambda pred, target: F.l1_loss(F.relu(pred), target, reduction="none") 230 | elif config["reg_loss"] == "MSE": 231 | reg_crit = lambda pred, target: F.mse_loss(F.relu(pred), target, reduction="none") 232 | elif config["reg_loss"] == "SMSE": 233 | reg_crit = lambda pred, target: F.smooth_l1_loss(F.relu(pred), target, reduction="none") 234 | else: 235 | raise NotImplementedError 236 | 237 | if config["bp_loss"] == "MAE": 238 | bp_crit = lambda pred, target, neg_slp: F.l1_loss(F.leaky_relu(pred, neg_slp), target, reduction="none") 239 | elif config["bp_loss"] == "MSE": 240 | bp_crit = lambda pred, target, neg_slp: F.mse_loss(F.leaky_relu(pred, neg_slp), target, reduction="none") 241 | elif config["bp_loss"] == "SMSE": 242 | bp_crit = lambda pred, target, neg_slp: F.smooth_l1_loss(F.leaky_relu(pred, neg_slp), target, reduction="none") 243 | else: 244 | raise NotImplementedError 245 | 246 | model.eval() 247 | 248 | with torch.no_grad(): 249 | for batch_id, batch in enumerate(data_loader): 250 | ids, pattern, pattern_len, graph, graph_len, counts = batch 251 | cnt = counts.shape[0] 252 | 
total_cnt += cnt 253 | 254 | evaluate_results["data"]["id"].extend(ids) 255 | evaluate_results["data"]["counts"].extend(counts.view(-1).tolist()) 256 | 257 | pattern.to(device) 258 | graph.to(device) 259 | pattern_len, graph_len, counts = pattern_len.to(device), graph_len.to(device), counts.to(device) 260 | 261 | st = time.time() 262 | pred = model(pattern, pattern_len, graph, graph_len) 263 | et = time.time() 264 | evaluate_results["time"]["total"] += (et-st) 265 | avg_t = (et-st) / (cnt + 1e-8) 266 | evaluate_results["time"]["avg"].extend([avg_t]*cnt) 267 | evaluate_results["data"]["pred"].extend(pred.cpu().view(-1).tolist()) 268 | 269 | reg_loss = reg_crit(pred, counts) 270 | 271 | if isinstance(config["bp_loss_slp"], (int, float)): 272 | neg_slp = float(config["bp_loss_slp"]) 273 | else: 274 | bp_loss_slp, l0, l1 = config["bp_loss_slp"].rsplit("$", 3) 275 | neg_slp = anneal_fn(bp_loss_slp, batch_id+epoch*epoch_step, T=total_step//4, lambda0=float(l0), lambda1=float(l1)) 276 | bp_loss = bp_crit(pred, counts, neg_slp) 277 | 278 | reg_loss_item = reg_loss.mean().item() 279 | bp_loss_item = bp_loss.mean().item() 280 | total_reg_loss += reg_loss_item * cnt 281 | total_bp_loss += bp_loss_item * cnt 282 | 283 | evaluate_results["error"]["mae"] += F.l1_loss(F.relu(pred), counts, reduction="none").sum().item() 284 | evaluate_results["error"]["mse"] += F.mse_loss(F.relu(pred), counts, reduction="none").sum().item() 285 | 286 | if writer: 287 | writer.add_scalar("%s/REG-%s" % (data_type, config["reg_loss"]), reg_loss_item, epoch*epoch_step+batch_id) 288 | writer.add_scalar("%s/BP-%s" % (data_type, config["bp_loss"]), bp_loss_item, epoch*epoch_step+batch_id) 289 | 290 | if logger and batch_id == epoch_step-1: 291 | logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\tbatch: {:0>5d}/{:0>5d}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}\tground: {:.3f}\tpredict: {:.3f}".format( 292 | epoch, config["epochs"], data_type, batch_id, epoch_step, 293 | reg_loss_item, bp_loss_item, 294 | 
counts[0].item(), pred[0].item())) 295 | mean_reg_loss = total_reg_loss/total_cnt 296 | mean_bp_loss = total_bp_loss/total_cnt 297 | if writer: 298 | writer.add_scalar("%s/REG-%s-epoch" % (data_type, config["reg_loss"]), mean_reg_loss, epoch) 299 | writer.add_scalar("%s/BP-%s-epoch" % (data_type, config["bp_loss"]), mean_bp_loss, epoch) 300 | if logger: 301 | logger.info("epoch: {:0>3d}/{:0>3d}\tdata_type: {:<5s}\treg loss: {:0>10.3f}\tbp loss: {:0>16.3f}".format( 302 | epoch, config["epochs"], data_type, mean_reg_loss, mean_bp_loss)) 303 | 304 | evaluate_results["error"]["mae"] = evaluate_results["error"]["mae"] / total_cnt 305 | evaluate_results["error"]["mse"] = evaluate_results["error"]["mse"] / total_cnt 306 | 307 | gc.collect() 308 | return mean_reg_loss, mean_bp_loss, evaluate_results 309 | 310 | 311 | if __name__ == "__main__": 312 | torch.manual_seed(0) 313 | np.random.seed(0) 314 | 315 | for i in range(1, len(sys.argv), 2): 316 | arg = sys.argv[i] 317 | value = sys.argv[i+1] 318 | 319 | if arg.startswith("--"): 320 | arg = arg[2:] 321 | if arg not in train_config: 322 | print("Warning: %s is not supported now." 
% (arg)) 323 | continue 324 | train_config[arg] = value 325 | try: 326 | value = eval(value) 327 | if isinstance(value, (int, float)): 328 | train_config[arg] = value 329 | except: 330 | pass 331 | 332 | ts = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 333 | model_name = "%s_%s_%s" % (train_config["model"], train_config["predict_net"], ts) 334 | save_model_dir = train_config["save_model_dir"] 335 | os.makedirs(save_model_dir, exist_ok=True) 336 | 337 | # save config 338 | with open(os.path.join(save_model_dir, "train_config.json"), "w") as f: 339 | json.dump(train_config, f) 340 | 341 | # set logger 342 | logger = logging.getLogger() 343 | logger.setLevel(logging.INFO) 344 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%Y/%m/%d %H:%M:%S') 345 | console = logging.StreamHandler() 346 | console.setFormatter(fmt) 347 | logger.addHandler(console) 348 | logfile = logging.FileHandler(os.path.join(save_model_dir, "train_log.txt"), 'w') 349 | logfile.setFormatter(fmt) 350 | logger.addHandler(logfile) 351 | 352 | # set device 353 | device = torch.device("cuda:%d" % train_config["gpu_id"] if train_config["gpu_id"] != -1 else "cpu") 354 | if train_config["gpu_id"] != -1: 355 | torch.cuda.set_device(device) 356 | 357 | # reset the pattern parameters 358 | if train_config["share_emb"]: 359 | train_config["max_npv"], train_config["max_npvl"], train_config["max_npe"], train_config["max_npel"] = \ 360 | train_config["max_ngv"], train_config["max_ngvl"], train_config["max_nge"], train_config["max_ngel"] 361 | 362 | # construct the model 363 | if train_config["model"] == "CNN": 364 | model = CNN(train_config) 365 | elif train_config["model"] == "RNN": 366 | model = RNN(train_config) 367 | elif train_config["model"] == "TXL": 368 | model = TXL(train_config) 369 | elif train_config["model"] == "RGCN": 370 | model = RGCN(train_config) 371 | elif train_config["model"] == "RGIN": 372 | model = RGIN(train_config) 373 | else: 374 | raise NotImplementedError("Currently, 
the %s model is not supported" % (train_config["model"])) 375 | model = model.to(device) 376 | logger.info(model) 377 | logger.info("num of parameters: %d" % (sum(p.numel() for p in model.parameters() if p.requires_grad))) 378 | 379 | # load data 380 | os.makedirs(train_config["save_data_dir"], exist_ok=True) 381 | data_loaders = OrderedDict({"train": None, "dev": None, "test": None}) 382 | if all([os.path.exists(os.path.join(train_config["save_data_dir"], 383 | "%s_%s_dataset.pt" % ( 384 | data_type, "dgl" if train_config["model"] in ["RGCN", "RGIN"] else "edgeseq"))) for data_type in data_loaders]): 385 | 386 | logger.info("loading data from pt...") 387 | for data_type in data_loaders: 388 | if train_config["model"] in ["RGCN", "RGIN"]: 389 | dataset = GraphAdjDataset(list()) 390 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 391 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 392 | data_loader = DataLoader(dataset, 393 | batch_sampler=sampler, 394 | collate_fn=GraphAdjDataset.batchify, 395 | pin_memory=data_type=="train") 396 | else: 397 | dataset = EdgeSeqDataset(list()) 398 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 399 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 400 | data_loader = DataLoader(dataset, 401 | batch_sampler=sampler, 402 | collate_fn=EdgeSeqDataset.batchify, 403 | pin_memory=data_type=="train") 404 | data_loaders[data_type] = data_loader 405 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 406 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), train_config["batch_size"])) 407 | else: 408 | data = load_data(train_config["graph_dir"], 
train_config["pattern_dir"], train_config["metadata_dir"], num_workers=train_config["num_workers"]) 409 | logger.info("{}/{}/{} data loaded".format(len(data["train"]), len(data["dev"]), len(data["test"]))) 410 | for data_type, x in data.items(): 411 | if train_config["model"] in ["RGCN", "RGIN"]: 412 | if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))): 413 | dataset = GraphAdjDataset(list()) 414 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 415 | else: 416 | dataset = GraphAdjDataset(x) 417 | dataset.save(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 418 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 419 | data_loader = DataLoader(dataset, 420 | batch_sampler=sampler, 421 | collate_fn=GraphAdjDataset.batchify, 422 | pin_memory=data_type=="train") 423 | else: 424 | if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))): 425 | dataset = EdgeSeqDataset(list()) 426 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 427 | else: 428 | dataset = EdgeSeqDataset(x) 429 | dataset.save(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 430 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 431 | data_loader = DataLoader(dataset, 432 | batch_sampler=sampler, 433 | collate_fn=EdgeSeqDataset.batchify, 434 | pin_memory=data_type=="train") 435 | data_loaders[data_type] = data_loader 436 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 437 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), train_config["batch_size"])) 438 | 
439 | # optimizer and losses 440 | writer = SummaryWriter(save_model_dir) 441 | optimizer = torch.optim.AdamW(model.parameters(), lr=train_config["lr"], weight_decay=train_config["weight_decay"], amsgrad=True) 442 | optimizer.zero_grad() 443 | scheduler = None 444 | # scheduler = get_linear_schedule_with_warmup(optimizer, 445 | # len(data_loaders["train"]), train_config["epochs"]*len(data_loaders["train"]), min_percent=0.0001) 446 | 447 | best_reg_losses = {"train": INF, "dev": INF, "test": INF} 448 | best_reg_epochs = {"train": -1, "dev": -1, "test": -1} 449 | 450 | for epoch in range(train_config["epochs"]): 451 | for data_type, data_loader in data_loaders.items(): 452 | 453 | if data_type == "train": 454 | mean_reg_loss, mean_bp_loss = train(model, optimizer, scheduler, data_type, data_loader, device, 455 | train_config, epoch, logger=logger, writer=writer) 456 | torch.save(model.state_dict(), os.path.join(save_model_dir, 'epoch%d.pt' % (epoch))) 457 | else: 458 | mean_reg_loss, mean_bp_loss, evaluate_results = evaluate(model, data_type, data_loader, device, 459 | train_config, epoch, logger=logger, writer=writer) 460 | with open(os.path.join(save_model_dir, '%s%d.json' % (data_type, epoch)), "w") as f: 461 | json.dump(evaluate_results, f) 462 | if mean_reg_loss <= best_reg_losses[data_type]: 463 | best_reg_losses[data_type] = mean_reg_loss 464 | best_reg_epochs[data_type] = epoch 465 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, mean_reg_loss, epoch)) 466 | for data_type in data_loaders.keys(): 467 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, best_reg_losses[data_type], best_reg_epochs[data_type])) 468 | -------------------------------------------------------------------------------- /src/train_mutag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import dgl 5 | import 
logging 6 | import datetime 7 | import math 8 | import sys 9 | import gc 10 | import json 11 | import time 12 | import torch.nn.functional as F 13 | import warnings 14 | from functools import partial 15 | from collections import OrderedDict 16 | from torch.utils.data import DataLoader 17 | try: 18 | from torch.utils.tensorboard import SummaryWriter 19 | except ImportError: 20 | from tensorboardX import SummaryWriter 21 | from dataset import Sampler, EdgeSeqDataset, GraphAdjDataset 22 | from utils import anneal_fn, get_enc_len, load_data, get_linear_schedule_with_warmup 23 | from cnn import CNN 24 | from rnn import RNN 25 | from txl import TXL 26 | from rgcn import RGCN 27 | from rgin import RGIN 28 | from train import evaluate, train 29 | 30 | warnings.filterwarnings("ignore") 31 | INF = float("inf") 32 | 33 | train_config = { 34 | "max_npv": 4, # max_number_pattern_vertices: 8, 16, 32 35 | "max_npe": 3, # max_number_pattern_edges: 8, 16, 32 36 | "max_npvl": 2, # max_number_pattern_vertex_labels: 8, 16, 32 37 | "max_npel": 2, # max_number_pattern_edge_labels: 8, 16, 32 38 | 39 | "max_ngv": 28, # max_number_graph_vertices: 64, 512, 4096 40 | "max_nge": 66, # max_number_graph_edges: 256, 2048, 16384 41 | "max_ngvl": 7, # max_number_graph_vertex_labels: 16, 64, 256 42 | "max_ngel": 4, # max_number_graph_edge_labels: 16, 64, 256 43 | 44 | "base": 2, 45 | 46 | "gpu_id": -1, 47 | "num_workers": 12, 48 | 49 | "epochs": 100, 50 | "batch_size": 64, 51 | "update_every": 1, # effective batch_size = batch_size * update_every 52 | "print_every": 100, 53 | "init_emb": "Equivariant", # None, Orthogonal, Normal, Equivariant 54 | "share_emb": True, # sharing embedding requires the same vector length 55 | "share_arch": True, # sharing architectures 56 | "dropout": 0.2, 57 | "dropatt": 0.2, 58 | 59 | "reg_loss": "MSE", # MAE, MSE 60 | "bp_loss": "MSE", # MAE, MSE 61 | "bp_loss_slp": "anneal_cosine$1.0$0.01", # 0, 0.01, logistic$1.0$0.01, linear$1.0$0.01, cosine$1.0$0.01, 62 | #
cyclical_logistic$1.0$0.01, cyclical_linear$1.0$0.01, cyclical_cosine$1.0$0.01 63 | # anneal_logistic$1.0$0.01, anneal_linear$1.0$0.01, anneal_cosine$1.0$0.01 64 | "lr": 0.001, 65 | "weight_decay": 0.00001, 66 | "max_grad_norm": 8, 67 | 68 | "model": "CNN", # CNN, RNN, TXL, RGCN, RGIN 69 | 70 | "emb_dim": 128, 71 | "activation_function": "leaky_relu", # sigmoid, softmax, tanh, relu, leaky_relu, prelu, gelu 72 | 73 | "filter_net": "MaxGatedFilterNet", # None, MaxGatedFilterNet 74 | "predict_net": "SumPredictNet", # MeanPredictNet, SumPredictNet, MaxPredictNet, 75 | # MeanAttnPredictNet, SumAttnPredictNet, MaxAttnPredictNet, 76 | # MeanMemAttnPredictNet, SumMemAttnPredictNet, MaxMemAttnPredictNet, 77 | # DIAMNet 78 | "predict_net_add_enc": True, 79 | "predict_net_add_degree": True, 80 | "predict_net_hidden_dim": 128, 81 | "predict_net_num_heads": 4, 82 | "predict_net_mem_len": 4, 83 | "predict_net_mem_init": "mean", # mean, sum, max, attn, circular_mean, circular_sum, circular_max, circular_attn, lstm 84 | "predict_net_recurrent_steps": 3, 85 | 86 | "cnn_hidden_dim": 128, 87 | "cnn_conv_channels": (128, 128, 128), 88 | "cnn_conv_kernel_sizes": (2, 3, 4), 89 | "cnn_conv_strides": (1, 1, 1), 90 | "cnn_conv_paddings": (0, 1, 1), 91 | "cnn_pool_kernel_sizes": (2, 3, 4), 92 | "cnn_pool_strides": (1, 1, 1), 93 | "cnn_pool_paddings": (1, 1, 2), 94 | 95 | "rnn_type": "LSTM", # GRU, LSTM 96 | "rnn_bidirectional": False, 97 | "rnn_graph_num_layers": 3, 98 | "rnn_pattern_num_layers": 3, 99 | "rnn_hidden_dim": 128, 100 | 101 | "txl_graph_num_layers": 3, 102 | "txl_pattern_num_layers": 3, 103 | "txl_d_model": 128, 104 | "txl_d_inner": 128, 105 | "txl_n_head": 4, 106 | "txl_d_head": 4, 107 | "txl_pre_lnorm": True, 108 | "txl_tgt_len": 64, 109 | "txl_ext_len": 0, # unused in current settings 110 | "txl_mem_len": 64, 111 | "txl_clamp_len": -1, # max positional embedding index 112 | "txl_attn_type": 0, # 0 for Dai et al., 1 for Shaw et al., 2 for Vaswani et al., 3 for Al Rfou et
al. 113 | "txl_same_len": False, 114 | 115 | "rgcn_num_bases": 8, 116 | "rgcn_regularizer": "bdd", # basis, bdd 117 | "rgcn_graph_num_layers": 3, 118 | "rgcn_pattern_num_layers": 3, 119 | "rgcn_hidden_dim": 128, 120 | "rgcn_ignore_norm": False, # ignore_norm=True -> RGCN-SUM 121 | 122 | "rgin_num_bases": 8, 123 | "rgin_regularizer": "bdd", # basis, bdd 124 | "rgin_graph_num_layers": 3, 125 | "rgin_pattern_num_layers": 3, 126 | "rgin_hidden_dim": 128, 127 | 128 | "train_ratio": 1.0, 129 | "pattern_dir": "../data/MUTAG/patterns", 130 | "graph_dir": "../data/MUTAG/raw", 131 | "metadata_dir": "../data/MUTAG/metadata", 132 | "save_data_dir": "../data/MUTAG/", 133 | "save_model_dir": "../dumps/MUTAG", 134 | } 135 | 136 | 137 | if __name__ == "__main__": 138 | torch.manual_seed(0) 139 | np.random.seed(0) 140 | 141 | for i in range(1, len(sys.argv), 2): 142 | arg = sys.argv[i] 143 | value = sys.argv[i+1] 144 | 145 | if arg.startswith("--"): 146 | arg = arg[2:] 147 | if arg not in train_config: 148 | print("Warning: %s is not supported now."
% (arg)) 149 | continue 150 | train_config[arg] = value 151 | try: 152 | value = eval(value) 153 | if isinstance(value, (int, float)): 154 | train_config[arg] = value 155 | except: 156 | pass 157 | 158 | ts = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 159 | model_name = "%s_%s_%s" % (train_config["model"], train_config["predict_net"], ts) 160 | save_model_dir = train_config["save_model_dir"] 161 | os.makedirs(save_model_dir, exist_ok=True) 162 | 163 | # save config 164 | with open(os.path.join(save_model_dir, "train_config.json"), "w") as f: 165 | json.dump(train_config, f) 166 | 167 | # set logger 168 | logger = logging.getLogger() 169 | logger.setLevel(logging.INFO) 170 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%Y/%m/%d %H:%M:%S') 171 | console = logging.StreamHandler() 172 | console.setFormatter(fmt) 173 | logger.addHandler(console) 174 | logfile = logging.FileHandler(os.path.join(save_model_dir, "train_log.txt"), 'w') 175 | logfile.setFormatter(fmt) 176 | logger.addHandler(logfile) 177 | 178 | # set device 179 | device = torch.device("cuda:%d" % train_config["gpu_id"] if train_config["gpu_id"] != -1 else "cpu") 180 | if train_config["gpu_id"] != -1: 181 | torch.cuda.set_device(device) 182 | 183 | # reset the pattern parameters 184 | if train_config["share_emb"]: 185 | train_config["max_npv"], train_config["max_npvl"], train_config["max_npe"], train_config["max_npel"] = \ 186 | train_config["max_ngv"], train_config["max_ngvl"], train_config["max_nge"], train_config["max_ngel"] 187 | 188 | # construct the model 189 | if train_config["model"] == "CNN": 190 | model = CNN(train_config) 191 | elif train_config["model"] == "RNN": 192 | model = RNN(train_config) 193 | elif train_config["model"] == "TXL": 194 | model = TXL(train_config) 195 | elif train_config["model"] == "RGCN": 196 | model = RGCN(train_config) 197 | elif train_config["model"] == "RGIN": 198 | model = RGIN(train_config) 199 | else: 200 | raise NotImplementedError("Currently, 
the %s model is not supported" % (train_config["model"])) 201 | 202 | model = model.to(device) 203 | logger.info(model) 204 | logger.info("num of parameters: %d" % (sum(p.numel() for p in model.parameters() if p.requires_grad))) 205 | 206 | # load data 207 | os.makedirs(train_config["save_data_dir"], exist_ok=True) 208 | data_loaders = OrderedDict({"train": None, "dev": None, "test": None}) 209 | if all([os.path.exists(os.path.join(train_config["save_data_dir"], 210 | "%s_%s_dataset.pt" % ( 211 | data_type, "dgl" if train_config["model"] in ["RGCN", "RGIN"] else "edgeseq"))) for data_type in data_loaders]): 212 | 213 | logger.info("loading data from pt...") 214 | for data_type in data_loaders: 215 | if train_config["model"] in ["RGCN", "RGIN"]: 216 | dataset = GraphAdjDataset(list()) 217 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 218 | if data_type == "train": 219 | np.random.shuffle(dataset.data) 220 | dataset.data = dataset.data[:math.ceil(len(dataset.data)*train_config["train_ratio"])] 221 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 222 | data_loader = DataLoader(dataset, 223 | batch_sampler=sampler, 224 | collate_fn=GraphAdjDataset.batchify, 225 | pin_memory=data_type=="train") 226 | else: 227 | dataset = EdgeSeqDataset(list()) 228 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 229 | if data_type == "train": 230 | np.random.shuffle(dataset.data) 231 | dataset.data = dataset.data[:math.ceil(len(dataset.data)*train_config["train_ratio"])] 232 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 233 | data_loader = DataLoader(dataset, 234 | batch_sampler=sampler, 235 | collate_fn=EdgeSeqDataset.batchify, 236 | pin_memory=data_type=="train") 237 | data_loaders[data_type] = 
data_loader 238 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 239 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), train_config["batch_size"])) 240 | else: 241 | data = load_data(train_config["graph_dir"], train_config["pattern_dir"], train_config["metadata_dir"], num_workers=train_config["num_workers"]) 242 | logger.info("{}/{}/{} data loaded".format(len(data["train"]), len(data["dev"]), len(data["test"]))) 243 | for data_type, x in data.items(): 244 | if train_config["model"] in ["RGCN", "RGIN"]: 245 | if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))): 246 | dataset = GraphAdjDataset(list()) 247 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 248 | else: 249 | dataset = GraphAdjDataset(x) 250 | dataset.save(os.path.join(train_config["save_data_dir"], "%s_dgl_dataset.pt" % (data_type))) 251 | if data_type == "train": 252 | np.random.shuffle(dataset.data) 253 | dataset.data = dataset.data[:math.ceil(len(dataset.data)*train_config["train_ratio"])] 254 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 255 | data_loader = DataLoader(dataset, 256 | batch_sampler=sampler, 257 | collate_fn=GraphAdjDataset.batchify, 258 | pin_memory=data_type=="train") 259 | else: 260 | if os.path.exists(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))): 261 | dataset = EdgeSeqDataset(list()) 262 | dataset.load(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 263 | else: 264 | dataset = EdgeSeqDataset(x) 265 | dataset.save(os.path.join(train_config["save_data_dir"], "%s_edgeseq_dataset.pt" % (data_type))) 266 | if data_type == "train": 267 | np.random.shuffle(dataset.data) 268 | dataset.data = 
dataset.data[:math.ceil(len(dataset.data)*train_config["train_ratio"])] 269 | sampler = Sampler(dataset, group_by=["graph", "pattern"], batch_size=train_config["batch_size"], shuffle=data_type=="train", drop_last=False) 270 | data_loader = DataLoader(dataset, 271 | batch_sampler=sampler, 272 | collate_fn=EdgeSeqDataset.batchify, 273 | pin_memory=data_type=="train") 274 | data_loaders[data_type] = data_loader 275 | logger.info("data (data_type: {:<5s}, len: {}) generated".format(data_type, len(dataset.data))) 276 | logger.info("data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(data_type, len(data_loader), train_config["batch_size"])) 277 | 278 | # optimizer and losses 279 | writer = SummaryWriter(save_model_dir) 280 | optimizer = torch.optim.AdamW(model.parameters(), lr=train_config["lr"], weight_decay=train_config["weight_decay"], amsgrad=True) 281 | optimizer.zero_grad() 282 | scheduler = None 283 | # scheduler = get_linear_schedule_with_warmup(optimizer, 284 | # len(data_loaders["train"]), train_config["epochs"]*len(data_loaders["train"]), min_percent=0.0001) 285 | 286 | best_reg_losses = {"train": INF, "dev": INF, "test": INF} 287 | best_reg_epochs = {"train": -1, "dev": -1, "test": -1} 288 | 289 | for epoch in range(train_config["epochs"]): 290 | for data_type, data_loader in data_loaders.items(): 291 | 292 | if data_type == "train": 293 | mean_reg_loss, mean_bp_loss = train(model, optimizer, scheduler, data_type, data_loader, device, 294 | train_config, epoch, logger=logger, writer=writer) 295 | torch.save(model.state_dict(), os.path.join(save_model_dir, 'epoch%d.pt' % (epoch))) 296 | else: 297 | mean_reg_loss, mean_bp_loss, evaluate_results = evaluate(model, data_type, data_loader, device, 298 | train_config, epoch, logger=logger, writer=writer) 299 | with open(os.path.join(save_model_dir, '%s%d.json' % (data_type, epoch)), "w") as f: 300 | json.dump(evaluate_results, f) 301 | if mean_reg_loss <= best_reg_losses[data_type]: 302 | 
best_reg_losses[data_type] = mean_reg_loss 303 | best_reg_epochs[data_type] = epoch 304 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, mean_reg_loss, epoch)) 305 | for data_type in data_loaders.keys(): 306 | logger.info("data_type: {:<5s}\tbest mean loss: {:.3f} (epoch: {:0>3d})".format(data_type, best_reg_losses[data_type], best_reg_epochs[data_type])) 307 | 308 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import math 6 | import numpy as np 7 | import re 8 | import os 9 | import json 10 | from torch.optim.lr_scheduler import LambdaLR 11 | from collections import OrderedDict 12 | from multiprocessing import Pool 13 | from tqdm import tqdm 14 | from sklearn.metrics import precision_recall_fscore_support 15 | import copy # required by clones below 16 | ########################################################## 17 | ################## Evaluation Functions ################## 18 | ########################################################## 19 | def compute_mae(predict, count): 20 | error = np.absolute(predict-count) 21 | return error.mean() 22 | 23 | def compute_rmse(predict, count): 24 | error = np.power(predict-count, 2) 25 | return np.power(error.mean(), 0.5) 26 | 27 | def compute_p_r_f1(predict, count): 28 | p, r, f1, _ = precision_recall_fscore_support(count, predict, average="binary") # sklearn expects (y_true, y_pred) 29 | return p, r, f1 30 | 31 | def compute_tp(predict, count): 32 | true_count = count == 1 33 | true_pred = predict == 1 34 | true_pred_count = true_count * true_pred 35 | return np.count_nonzero(true_pred_count) / np.count_nonzero(true_count) 36 | 37 | ########################################################## 38 | #################### Parsing Functions ################### 39 |
########################################################## 40 | def parse_pattern_info(x): 41 | p = re.findall(r"N(\d+)_E(\d+)_NL(\d+)_EL(\d+)", x)[0] 42 | return {"V": int(p[0]), "E": int(p[1]), "VL": int(p[2]), "EL": int(p[3])} 43 | 44 | def parse_graph_info(x): 45 | g = re.findall(r"N(\d+)_E(\d+)_NL(\d+)_EL(\d+)_A([\d\.]+)", x)[0] 46 | return {"V": int(g[0]), "E": int(g[1]), "VL": int(g[2]), "EL": int(g[3]), "alpha": float(g[4])} 47 | 48 | ########################################################## 49 | ######### Representation and Encoding Functions ########## 50 | ########################################################## 51 | def get_enc_len(x, base=10): 52 | # return math.floor(math.log(x, base)+1.0) 53 | l = 0 54 | while x: 55 | l += 1 56 | x = x // base 57 | return l 58 | 59 | def int2onehot(x, len_x, base=10): 60 | if isinstance(x, (int, list)): 61 | x = np.array(x) 62 | x_shape = x.shape 63 | x = x.reshape(-1) 64 | one_hot = np.zeros((len_x*base, x.shape[0]), dtype=np.float32) 65 | x = x % (base**len_x) 66 | idx = one_hot.shape[0] - base 67 | while np.any(x): 68 | x, y = x//base, x%base 69 | cond = y.reshape(1, -1) == np.arange(0, base, dtype=y.dtype).reshape(base, 1) 70 | one_hot[idx:idx+base] = np.where(cond, 1.0, 0.0) 71 | idx -= base 72 | while idx >= 0: 73 | one_hot[idx] = 1.0 74 | idx -= base 75 | one_hot = one_hot.transpose(1, 0).reshape(*x_shape, len_x*base) 76 | return one_hot 77 | 78 | ########################################################## 79 | ################ Deep Learning Functions ################# 80 | ########################################################## 81 | def segment_data(data, max_len): 82 | bsz = data.size(0) 83 | pad_len = max_len - data.size(1) % max_len 84 | if pad_len != max_len: 85 | pad_size = list(data.size()) 86 | pad_size[1] = pad_len 87 | zero_pad = torch.zeros(pad_size, device=data.device, dtype=data.dtype, requires_grad=False) 88 | data = torch.cat([data, zero_pad], dim=1) 89 | return torch.split(data, max_len, 
dim=1) 90 | 91 | def segment_length(data_len, max_len): 92 | bsz = data_len.size(0) 93 | list_len = math.ceil(data_len.max().float()/max_len) 94 | segment_lens = torch.arange(0, max_len*list_len, max_len, dtype=data_len.dtype, device=data_len.device, requires_grad=False).view(1, list_len) 95 | diff = data_len.view(-1, 1) - segment_lens 96 | fill_max = diff > max_len 97 | fill_zero = diff < 0 98 | segment_lens = diff.masked_fill(fill_max, max_len) 99 | segment_lens.masked_fill_(fill_zero, 0) 100 | return torch.split(segment_lens.view(bsz, -1), 1, dim=1) 101 | 102 | def clones(module, N): 103 | "Produce N identical layers." 104 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 105 | 106 | def split_and_batchify_graph_feats(batched_graph_feats, graph_sizes): 107 | bsz = graph_sizes.size(0) 108 | dim, dtype, device = batched_graph_feats.size(-1), batched_graph_feats.dtype, batched_graph_feats.device 109 | 110 | min_size, max_size = graph_sizes.min(), graph_sizes.max() 111 | mask = torch.ones((bsz, max_size), dtype=torch.uint8, device=device, requires_grad=False) 112 | 113 | if min_size == max_size: 114 | return batched_graph_feats.view(bsz, max_size, -1), mask 115 | else: 116 | graph_sizes_list = graph_sizes.view(-1).tolist() 117 | unbatched_graph_feats = list(torch.split(batched_graph_feats, graph_sizes_list, dim=0)) 118 | for i, l in enumerate(graph_sizes_list): 119 | if l == max_size: 120 | continue 121 | elif l > max_size: 122 | unbatched_graph_feats[i] = unbatched_graph_feats[i][:max_size] 123 | else: 124 | mask[i, l:].fill_(0) 125 | zeros = torch.zeros((max_size-l, dim), dtype=dtype, device=device, requires_grad=False) 126 | unbatched_graph_feats[i] = torch.cat([unbatched_graph_feats[i], zeros], dim=0) 127 | return torch.stack(unbatched_graph_feats, dim=0), mask 128 | 129 | def gather_indices_by_lens(lens): 130 | result = list() 131 | i, j = 0, 1 132 | max_j = len(lens) 133 | indices = np.arange(0, max_j) 134 | while j < max_j: 135 | if lens[i] != 
lens[j]: 136 | result.append(indices[i:j]) 137 | i = j 138 | j += 1 139 | if i != j: 140 | result.append(indices[i:j]) 141 | return result 142 | 143 | def batch_convert_array_to_array(batch_array, max_seq_len=-1): 144 | batch_lens = [v.shape[0] for v in batch_array] 145 | if max_seq_len == -1: 146 | max_seq_len = max(batch_lens) 147 | result = np.zeros([len(batch_array), max_seq_len] + list(batch_array[0].shape)[1:], dtype=batch_array[0].dtype) 148 | for i, t in enumerate(batch_array): 149 | len_t = batch_lens[i] 150 | if len_t < max_seq_len: 151 | result[i, :len_t] = t 152 | elif len_t == max_seq_len: 153 | result[i] = t 154 | else: 155 | result[i] = t[:max_seq_len] 156 | return result 157 | 158 | def batch_convert_tensor_to_tensor(batch_tensor, max_seq_len=-1): 159 | batch_lens = [v.shape[0] for v in batch_tensor] 160 | if max_seq_len == -1: 161 | max_seq_len = max(batch_lens) 162 | result = torch.zeros([len(batch_tensor), max_seq_len] + list(batch_tensor[0].size())[1:], dtype=batch_tensor[0].dtype, requires_grad=False) 163 | for i, t in enumerate(batch_tensor): 164 | len_t = batch_lens[i] 165 | if len_t < max_seq_len: 166 | result[i, :len_t].data.copy_(t) 167 | elif len_t == max_seq_len: 168 | result[i].data.copy_(t) 169 | else: 170 | result[i].data.copy_(t[:max_seq_len]) 171 | return result 172 | 173 | def batch_convert_len_to_mask(batch_lens, max_seq_len=-1): 174 | if max_seq_len == -1: 175 | max_seq_len = max(batch_lens) 176 | mask = torch.ones((len(batch_lens), max_seq_len), dtype=torch.uint8, device=batch_lens[0].device, requires_grad=False) 177 | for i, l in enumerate(batch_lens): 178 | mask[i, l:].fill_(0) 179 | return mask 180 | 181 | def convert_dgl_graph_to_edgeseq(graph, x_emb, x_len, e_emb): 182 | uid, vid, eid = graph.all_edges(form="all", order="srcdst") 183 | e = e_emb[eid] 184 | if x_emb is not None: 185 | u, v = x_emb[uid], x_emb[vid] 186 | e = torch.cat([u, v, e], dim=1) 187 | e_len = torch.tensor(graph.batch_num_edges, dtype=x_len.dtype, 
device=x_len.device).view(x_len.size()) 188 | return e, e_len 189 | 190 | def mask_seq_by_len(x, len_x): 191 | x_size = list(x.size()) 192 | if x_size[1] == len_x.max(): 193 | mask = batch_convert_len_to_mask(len_x) 194 | mask_size = x_size[0:2] + [1]*(len(x_size)-2) 195 | x = x * mask.view(*mask_size) 196 | return x 197 | 198 | def extend_dimensions(old_layer, new_input_dim=-1, new_output_dim=-1, upper=False): 199 | if isinstance(old_layer, nn.Linear): 200 | old_output_dim, old_input_dim = old_layer.weight.size() 201 | if new_input_dim == -1: 202 | new_input_dim = old_input_dim 203 | if new_output_dim == -1: 204 | new_output_dim = old_output_dim 205 | assert new_input_dim >= old_input_dim and new_output_dim >= old_output_dim 206 | 207 | if new_input_dim != old_input_dim or new_output_dim != old_output_dim: 208 | use_bias = old_layer.bias is not None 209 | new_layer = nn.Linear(new_input_dim, new_output_dim, bias=use_bias) 210 | with torch.no_grad(): 211 | nn.init.zeros_(new_layer.weight) 212 | if upper: 213 | new_layer.weight[:old_output_dim, :old_input_dim].data.copy_(old_layer.weight) 214 | else: 215 | new_layer.weight[-old_output_dim:, -old_input_dim:].data.copy_(old_layer.weight) 216 | if use_bias: 217 | nn.init.zeros_(new_layer.bias) 218 | if upper: 219 | new_layer.bias[:old_output_dim].data.copy_(old_layer.bias) 220 | else: 221 | new_layer.bias[-old_output_dim:].data.copy_(old_layer.bias) 222 | else: 223 | new_layer = old_layer 224 | elif isinstance(old_layer, nn.LayerNorm): 225 | old_input_dim = old_layer.normalized_shape 226 | if len(old_input_dim) != 1: 227 | raise NotImplementedError 228 | old_input_dim = old_input_dim[0] 229 | assert new_input_dim >= old_input_dim 230 | if new_input_dim != old_input_dim and old_layer.elementwise_affine: 231 | new_layer = nn.LayerNorm(new_input_dim, elementwise_affine=True) 232 | with torch.no_grad(): 233 | nn.init.ones_(new_layer.weight) 234 | nn.init.zeros_(new_layer.bias) 235 | if upper: 236 | 
new_layer.weight[:old_input_dim].data.copy_(old_layer.weight) 237 | new_layer.bias[:old_input_dim].data.copy_(old_layer.bias) 238 | else: 239 | new_layer.weight[-old_input_dim:].data.copy_(old_layer.weight) 240 | new_layer.bias[-old_input_dim:].data.copy_(old_layer.bias) 241 | else: 242 | new_layer = old_layer 243 | elif isinstance(old_layer, nn.LSTM): 244 | old_input_dim, old_output_dim = old_layer.input_size, old_layer.hidden_size 245 | if new_input_dim == -1: 246 | new_input_dim = old_input_dim 247 | if new_output_dim == -1: 248 | new_output_dim = old_output_dim 249 | assert new_input_dim >= old_input_dim and new_output_dim >= old_output_dim 250 | 251 | if new_input_dim != old_input_dim or new_output_dim != old_output_dim: 252 | new_layer = nn.LSTM(new_input_dim, new_output_dim, 253 | num_layers=old_layer.num_layers, bidirectional=old_layer.bidirectional, 254 | batch_first=old_layer.batch_first, bias=old_layer.bias) 255 | for layer_weights in new_layer._all_weights: 256 | for w in layer_weights: 257 | with torch.no_grad(): 258 | if "weight" in w: 259 | new_weight = getattr(new_layer, w) 260 | old_weight = getattr(old_layer, w) 261 | nn.init.zeros_(new_weight) 262 | if upper: 263 | new_weight[:old_weight.shape[0], :old_weight.shape[1]].data.copy_(old_weight) 264 | else: 265 | new_weight[-old_weight.shape[0]:, -old_weight.shape[1]:].data.copy_(old_weight) 266 | if "bias" in w: 267 | new_bias = getattr(new_layer, w) 268 | old_bias = getattr(old_layer, w) 269 | if new_bias is not None: 270 | nn.init.zeros_(new_bias) 271 | if upper: 272 | new_bias[:old_bias.shape[0]].data.copy_(old_bias) 273 | else: 274 | new_bias[-old_bias.shape[0]:].data.copy_(old_bias) 275 | return new_layer 276 | 277 | 278 | _act_map = {"none": lambda x: x, 279 | "relu": nn.ReLU(), 280 | "tanh": nn.Tanh(), 281 | "softmax": nn.Softmax(dim=-1), 282 | "sigmoid": nn.Sigmoid(), 283 | "leaky_relu": nn.LeakyReLU(1/5.5), 284 | "prelu": nn.PReLU(), 285 | "gelu": nn.GELU()} 286 | 287 | def 
map_activation_str_to_layer(act_str): 288 | try: 289 | return _act_map[act_str] 290 | except KeyError: 291 | raise NotImplementedError("Error: %s activation function is not supported now." % (act_str)) 292 | 293 | def anneal_fn(fn, t, T, lambda0=0.0, lambda1=1.0): 294 | if not fn or fn == "none": 295 | return lambda1 296 | elif fn == "logistic": 297 | K = 8 / T 298 | return float(lambda0 + (lambda1-lambda0)/(1+np.exp(-K*(t-T/2)))) 299 | elif fn == "linear": 300 | return float(lambda0 + (lambda1-lambda0) * t/T) 301 | elif fn == "cosine": 302 | return float(lambda0 + (lambda1-lambda0) * (1 - math.cos(math.pi * t/T))/2) 303 | elif fn.startswith("cyclical"): 304 | R = 0.5 305 | t = t % T 306 | if t <= R * T: 307 | return anneal_fn(fn.split("_", 1)[1], t, R*T, lambda0, lambda1) 308 | else: 309 | return anneal_fn(fn.split("_", 1)[1], t-R*T, R*T, lambda1, lambda0) 310 | elif fn.startswith("anneal"): 311 | R = 0.5 312 | t = t % T 313 | if t <= R * T: 314 | return anneal_fn(fn.split("_", 1)[1], t, R*T, lambda0, lambda1) 315 | else: 316 | return lambda1 317 | else: 318 | raise NotImplementedError 319 | 320 | def get_constant_schedule(optimizer, last_epoch=-1): 321 | """ Create a schedule with a constant learning rate. 322 | """ 323 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 324 | 325 | 326 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 327 | """ Create a schedule with a constant learning rate preceded by a warmup 328 | period during which the learning rate increases linearly between 0 and 1.
329 | """ 330 | 331 | def lr_lambda(current_step): 332 | if current_step < num_warmup_steps: 333 | return float(current_step) / float(max(1.0, num_warmup_steps)) 334 | return 1.0 335 | 336 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 337 | 338 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1, min_percent=0.0): 339 | """ Create a schedule with a learning rate that decreases linearly after 340 | linearly increasing during a warmup period. 341 | """ 342 | 343 | def lr_lambda(current_step): 344 | if current_step < num_warmup_steps: 345 | return float(current_step) / float(max(1, num_warmup_steps)) 346 | return max(min_percent, float(num_training_steps - current_step) / float(max(1.0, num_training_steps - num_warmup_steps))) 347 | 348 | return LambdaLR(optimizer, lr_lambda, last_epoch) 349 | 350 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1, min_percent=0.0): 351 | """ Create a schedule with a learning rate that decreases following the 352 | values of the cosine function between 0 and `pi * cycles` after a warmup 353 | period during which it increases linearly between 0 and 1. 
354 | """ 355 | 356 | def lr_lambda(current_step): 357 | if current_step < num_warmup_steps: 358 | return float(current_step) / float(max(1, num_warmup_steps)) 359 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 360 | return max(min_percent, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 361 | 362 | return LambdaLR(optimizer, lr_lambda, last_epoch) 363 | 364 | def get_cosine_with_hard_restarts_schedule_with_warmup( 365 | optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1, min_percent=0.0): 366 | """ Create a schedule with a learning rate that decreases following the 367 | values of the cosine function with several hard restarts, after a warmup 368 | period during which it increases linearly between 0 and 1. 369 | """ 370 | 371 | def lr_lambda(current_step): 372 | if current_step < num_warmup_steps: 373 | return float(current_step) / float(max(1, num_warmup_steps)) 374 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 375 | if progress >= 1.0: 376 | return min_percent 377 | return max(min_percent, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) 378 | 379 | return LambdaLR(optimizer, lr_lambda, last_epoch) 380 | 381 | ############################################## 382 | ############ OS Function Parts ############### 383 | ############################################## 384 | def _get_subdirs(dirpath, leaf_only=True): 385 | subdirs = list() 386 | is_leaf = True 387 | for filename in os.listdir(dirpath): 388 | if os.path.isdir(os.path.join(dirpath, filename)): 389 | is_leaf = False 390 | subdirs.extend(_get_subdirs(os.path.join(dirpath, filename), leaf_only=leaf_only)) 391 | if not leaf_only or is_leaf: 392 | subdirs.append(dirpath) 393 | return subdirs 394 | 395 | def _read_graphs_from_dir(dirpath): 396 | import igraph as ig 397 | graphs = dict() 398 | for filename in 
os.listdir(dirpath): 399 | if not os.path.isdir(os.path.join(dirpath, filename)): 400 | names = os.path.splitext(os.path.basename(filename)) 401 | if names[1] != ".gml": 402 | continue 403 | try: 404 | graph = ig.read(os.path.join(dirpath, filename)) 405 | graph.vs["label"] = [int(x) for x in graph.vs["label"]] 406 | graph.es["label"] = [int(x) for x in graph.es["label"]] 407 | graph.es["key"] = [int(x) for x in graph.es["key"]] 408 | graphs[names[0]] = graph 409 | except BaseException as e: 410 | print(e) 411 | break 412 | return graphs 413 | 414 | def read_graphs_from_dir(dirpath, num_workers=4): 415 | graphs = dict() 416 | subdirs = _get_subdirs(dirpath) 417 | with Pool(num_workers if num_workers > 0 else os.cpu_count()) as pool: 418 | results = list() 419 | for subdir in subdirs: 420 | results.append((subdir, pool.apply_async(_read_graphs_from_dir, args=(subdir, )))) 421 | pool.close() 422 | 423 | for subdir, x in tqdm(results): 424 | x = x.get() 425 | graphs[os.path.basename(subdir)] = x 426 | return graphs 427 | 428 | def read_patterns_from_dir(dirpath, num_workers=4): 429 | patterns = dict() 430 | subdirs = _get_subdirs(dirpath) 431 | with Pool(num_workers if num_workers > 0 else os.cpu_count()) as pool: 432 | results = list() 433 | for subdir in subdirs: 434 | results.append((subdir, pool.apply_async(_read_graphs_from_dir, args=(subdir, )))) 435 | pool.close() 436 | 437 | for subdir, x in tqdm(results): 438 | x = x.get() 439 | patterns.update(x) 440 | return patterns 441 | 442 | def _read_metadata_from_dir(dirpath): 443 | meta = dict() 444 | for filename in os.listdir(dirpath): 445 | if not os.path.isdir(os.path.join(dirpath, filename)): 446 | names = os.path.splitext(os.path.basename(filename)) 447 | if names[1] != ".meta": 448 | continue 449 | try: 450 | with open(os.path.join(dirpath, filename), "r") as f: 451 | meta[names[0]] = json.load(f) 452 | except BaseException as e: 453 | print(e) 454 | return meta 455 | 456 | def read_metadata_from_dir(dirpath, 
num_workers=4): 457 | meta = dict() 458 | subdirs = _get_subdirs(dirpath) 459 | with Pool(num_workers if num_workers > 0 else os.cpu_count()) as pool: 460 | results = list() 461 | for subdir in subdirs: 462 | results.append((subdir, pool.apply_async(_read_metadata_from_dir, args=(subdir, )))) 463 | pool.close() 464 | 465 | for subdir, x in tqdm(results): 466 | x = x.get() 467 | meta[os.path.basename(subdir)] = x 468 | return meta 469 | 470 | def load_data(graph_dir, pattern_dir, metadata_dir, num_workers=4): 471 | patterns = read_patterns_from_dir(pattern_dir, num_workers=num_workers) 472 | graphs = read_graphs_from_dir(graph_dir, num_workers=num_workers) 473 | meta = read_metadata_from_dir(metadata_dir, num_workers=num_workers) 474 | 475 | train_data, dev_data, test_data = list(), list(), list() 476 | for p, pattern in patterns.items(): 477 | if p in graphs: 478 | for g, graph in graphs[p].items(): 479 | x = dict() 480 | x["id"] = ("%s-%s" % (p, g)) 481 | x["pattern"] = pattern 482 | x["graph"] = graph 483 | x["subisomorphisms"] = meta[p][g]["subisomorphisms"] 484 | x["counts"] = meta[p][g]["counts"] 485 | 486 | g_idx = int(g.rsplit("_", 1)[-1]) 487 | if g_idx % 10 == 0: 488 | dev_data.append(x) 489 | elif g_idx % 10 == 1: 490 | test_data.append(x) 491 | else: 492 | train_data.append(x) 493 | elif len(graphs) == 1 and "raw" in graphs.keys(): 494 | for g, graph in graphs["raw"].items(): 495 | x = dict() 496 | x["id"] = ("%s-%s" % (p, g)) 497 | x["pattern"] = pattern 498 | x["graph"] = graph 499 | x["subisomorphisms"] = meta[p][g]["subisomorphisms"] 500 | x["counts"] = meta[p][g]["counts"] 501 | 502 | g_idx = int(g.rsplit("_", 1)[-1]) 503 | if g_idx % 3 == 0: 504 | dev_data.append(x) 505 | elif g_idx % 3 == 1: 506 | test_data.append(x) 507 | else: 508 | train_data.append(x) 509 | return OrderedDict({"train": train_data, "dev": dev_data, "test": test_data}) 510 | 511 | def get_best_epochs(log_file): 512 | regex = 
re.compile(r"data_type:\s+(\w+)\s+best\s+([\s\w\-]+).*?\(epoch:\s+(\d+)\)") 513 | best_epochs = dict() 514 | # get the best epoch 515 | try: 516 | lines = subprocess.check_output(["tail", log_file, "-n3"]).decode("utf-8").split("\n")[0:-1] 517 | print(lines) 518 | except: 519 | with open(log_file, "r") as f: 520 | lines = f.readlines() 521 | 522 | for line in lines[-3:]: 523 | matched_results = regex.findall(line) 524 | for matched_result in matched_results: 525 | if "loss" in matched_result[1]: 526 | best_epochs[matched_result[0]] = int(matched_result[2]) 527 | if len(best_epochs) != 3: 528 | for line in lines: 529 | matched_results = regex.findall(line) 530 | for matched_result in matched_results: 531 | if "loss" in matched_result[1]: 532 | best_epochs[matched_result[0]] = int(matched_result[2]) 533 | return best_epochs 534 | --------------------------------------------------------------------------------
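The curves produced by `anneal_fn` are easiest to check in isolation. The sketch below (with the hypothetical helper name `cosine_anneal`) mirrors the formula of the `"cosine"` branch only: it starts at `lambda0`, passes the midpoint at `t = T/2`, and reaches `lambda1` at `t = T`.

```python
import math

def cosine_anneal(t, T, lambda0=0.0, lambda1=1.0):
    # same formula as the "cosine" branch of anneal_fn:
    # half of a cosine wave mapped onto [lambda0, lambda1]
    return lambda0 + (lambda1 - lambda0) * (1 - math.cos(math.pi * t / T)) / 2

values = [cosine_anneal(t, 100) for t in (0, 50, 100)]
```

The same three-point spot check (start, midpoint, end) works for the `"linear"` and `"logistic"` branches as well.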
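Each warmup scheduler above only wraps a plain function of the step count in a `LambdaLR`, so its shape can be sketched without torch. The standalone function below (hypothetical name `warmup_then_linear_decay`) follows the same formula as the `lr_lambda` inside `get_linear_schedule_with_warmup`; the step counts are illustrative.

```python
def warmup_then_linear_decay(step, num_warmup_steps, num_training_steps, min_percent=0.0):
    # ramp the multiplier 0 -> 1 over the warmup,
    # then decay linearly to min_percent at num_training_steps
    if step < num_warmup_steps:
        return float(step) / float(max(1, num_warmup_steps))
    return max(min_percent,
               float(num_training_steps - step) / float(max(1.0, num_training_steps - num_warmup_steps)))

# with 10 warmup steps out of 100: 0 at step 0, peak 1.0 at step 10, back to 0 at step 100
multipliers = [warmup_then_linear_decay(s, 10, 100) for s in (0, 5, 10, 55, 100)]
```

The returned value multiplies the optimizer's base learning rate, which is why it peaks at exactly 1.0 at the end of the warmup.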
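`load_data` derives the train/dev/test split deterministically from the numeric suffix of each graph name rather than by random sampling. A minimal sketch of that modulo-10 rule, with invented graph names of the assumed `*_<idx>` form:

```python
# hypothetical graph names following the "*_<idx>" convention load_data assumes
names = ["G_%d" % i for i in range(20)]

splits = {"train": [], "dev": [], "test": []}
for g in names:
    g_idx = int(g.rsplit("_", 1)[-1])
    if g_idx % 10 == 0:        # every 10th graph -> dev
        splits["dev"].append(g)
    elif g_idx % 10 == 1:      # the next one -> test
        splits["test"].append(g)
    else:                      # remaining 80% -> train
        splits["train"].append(g)
```

Because the split depends only on the name, rerunning the loader always reproduces the same 80/10/10 partition (or 1/3 each under the `% 3` rule used for the single `"raw"` directory).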