├── LAN.pdf
├── README.md
├── .gitignore
├── cross_graph_learn.py
├── wl_labelling.py
├── init_node_sel_model_train.py
├── neigh_pruning_model_training.py
└── routing_with_neigh_pruning.py
/LAN.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csypeng/LAN/HEAD/LAN.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LAN: Learning-based Approximate k-NN Search in Graph Databases
2 | 
3 | Due to space limitations, please refer to Table I in our paper to download the datasets.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/cross_graph_learn.py:
--------------------------------------------------------------------------------
1 | 
2 | import dgl
3 | from dgl.data import DGLDataset
4 | import dgl.function as fn
5 | from dgl.nn.pytorch.conv import GINConv
6 | from dgl.udf import EdgeBatch
7 | from dgl.heterograph import DGLHeteroGraph
8 | import networkx as nx
9 | from networkx.classes.graph import Graph as NXGraph
10 | import numpy as np
11 | import os
12 | import random
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from torch.utils.data import DataLoader
17 | from typing import List
18 | from wl_labelling import compress_graph
19 | import time
20 | 
21 | GPUID=0
22 | torch.cuda.set_device(GPUID)
23 | 
24 | 
25 | 
26 | def read_and_split_to_individual_graph(fname):
27 | 
28 |     f = open(fname)
29 |     lines = f.read()
30 |     f.close()
31 | 
32 |     lines2 = lines.split("t # ")
33 | 
34 |     lines3 = [g.strip().split("\n") for g in lines2]
35 | 
36 |     glist = []
37 |     max_node_label = 0
38 |     max_edge_label = 0
39 |     for idx in range(1, len(lines3)):
40 |         cur_g = lines3[idx]
41 | 
42 |         gid_line = cur_g[0].strip().split(' ')
43 |         gid = gid_line[0]
44 | 
45 |         g = nx.Graph(id = gid)
46 | 
47 |         for idy in range(1, len(cur_g)):
48 |             tmp = cur_g[idy].split(' ')
49 |             if tmp[0] == 'v':
50 |                 g.add_node(tmp[1], label=int(tmp[2]))
51 |                 max_node_label = max(max_node_label, int(tmp[2]))
52 | 
53 |             if tmp[0] == 'e':
54 |                 g.add_edge(tmp[1], tmp[2])
55 |                 max_edge_label = max(max_edge_label, int(tmp[3]) if len(tmp) > 3 else 0)
56 | 
57 |         glist.append(g)
58 | 
59 |     return glist, max_node_label, max_edge_label
60 | 
61 | 
62 | 
63 | class GINDataset(DGLDataset):
64 | 
65 |     def __init__(self, name, gid2gmap, self_loop=False, degree_as_nlabel=False,
66 |                  raw_dir=None, force_reload=False, verbose=False):
67 | 
68 |         self._name = name
69 |         gin_url = ""
70 | 
71 |         self.gid2gmap = gid2gmap # key: graph id, value: networkx graph
72 | 
73 |         self.g1List = [] # list of networkx graphs
74 |         self.g2List = [] # list of networkx graphs
75 |         self.ground_truth = [] # list of ground truth
76 | 
77 | 
78 |         super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
79 |                                          raw_dir=raw_dir, force_reload=force_reload, verbose=verbose)
80 | 
81 |     @property
82 |     def raw_path(self):
83 |         return os.path.join(".", self.raw_dir)
84 | 
85 | 
86 |     def download(self):
87 |         pass
88 | 
89 | 
90 |     def __len__(self):
91 |         return len(self.g1List)
92 | 
93 | 
94 |     def __getitem__(self, idx):
95 |         return self.g1List[idx], self.g2List[idx], self.ground_truth[idx]
96 | 
97 | 
98 |     def _file_path(self):
99 |         return self.file
100 | 
101 | 
102 |     def process(self):
103 | 
104 |         for k in self.gid2gmap.keys():
105 |             g1 = self.gid2gmap[k]
106 | 
107 |             # randomly delete an edge from g1
108 |             g2 = nx.Graph(g1)
109 |             rand_edge = random.sample(list(g1.edges()), 1) # materialize the edge view for random.sample
110 |             g2.remove_edge(rand_edge[0][0], rand_edge[0][1])
111 | 
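# g2 differs from g1 by exactly one deleted edge, i.e. the pair is at graph edit
# distance 1, which is why a constant ground-truth value of 1 is appended below.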
112 |             self.g1List.append(g1)
113 |             self.g2List.append(g2)
114 |             self.ground_truth.append(1)
115 | 
116 | 
117 | 
118 |     def save(self):
119 |         pass
120 | 
121 |     def load(self):
122 |         pass
123 | 
124 |     def has_cache(self):
125 |         pass
126 | 
127 | 
128 | 
129 | 
130 | 
131 | class myGINConv(nn.Module):
132 | 
133 |     def __init__(self):
134 |         super(myGINConv, self).__init__()
135 | 
136 |     def forward(self, graph):
137 |         graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
138 | 
139 | 
140 | 
141 | 
142 | class Model(nn.Module):
143 |     def __init__(self):
144 |         super(Model, self).__init__()
145 | 
146 | 
147 |         self.RELU = torch.nn.ReLU(inplace=True)
148 | 
149 |         # self.conv1_for_g = myGINConv()
150 |         # self.conv2_for_g = myGINConv()
151 |         self.conv1_for_g = GINConv(None, 'mean')
152 |         self.conv2_for_g = GINConv(None, 'mean')
153 | 
154 | 
155 |         self.fc = nn.Linear(64, 256, bias=True)
156 |         self.fc2 = nn.Linear(256, 256, bias=True)
157 |         self.fc3 = nn.Linear(256, 1, bias=True)
158 | 
159 |         # input dim 60 = one-hot size covering node labels 0..59 (max_node_label is 59)
160 |         self.fc_init_node = nn.Linear(60, 32, bias=True)
161 | 
162 |         self.fc_att = nn.Linear(64, 1, bias=True)
163 | 
164 | 
165 |     def forward(self, g1, g2):
166 | 
167 |         h0_g1 = self.fc_init_node(g1.ndata['h'])
168 |         h0_g2 = self.fc_init_node(g2.ndata['h'])
169 | 
170 |         g1.ndata['h'] = h0_g1
171 |         g2.ndata['h'] = h0_g2
172 | 
173 |         # h1_g1 = self.conv1_for_g(g1, 'h')
174 |         # h1_g2 = self.conv1_for_g(g2, 'h')
175 |         g1.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
176 |         g2.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
177 | 
178 |         t1_g1 = g1.ndata['h']
179 |         t1_g2 = g2.ndata['h']
180 | 
181 |         t1_g1_rep = t1_g1.repeat(1, g2.number_of_nodes()).view(-1, 32)
182 |         t1_g2_rep = t1_g2.repeat(g1.number_of_nodes(), 1)
183 |         t1_g1_cat_t1_g2 = torch.cat([t1_g1_rep, t1_g2_rep], 1)
184 | 
185 |         alpha = self.fc_att(t1_g1_cat_t1_g2)
186 | 
187 |         # compute attention weights for g1
188 |         alpha_for_g1 = alpha.view(g1.number_of_nodes(), -1)
189 |         alpha_for_g1_sum = alpha_for_g1.sum(1).view(-1, 1)
190 |         alpha_for_g1 = alpha_for_g1 / alpha_for_g1_sum
191 |         alpha_for_g1 = alpha_for_g1.view(-1,1)
192 |         mu_g1 = alpha_for_g1 * t1_g2_rep
193 | 
194 |         mu_g1 = torch.split(mu_g1, g2.number_of_nodes())
195 |         mu_g1 = torch.stack(mu_g1)
196 |         mu_g1 = mu_g1.sum(1)
197 |         h1_g1 = t1_g1 + mu_g1
198 | 
199 |         g1.ndata['h'] = h1_g1
200 | 
201 |         # compute attention weights for g2
202 |         alpha_for_g2 = alpha.view(g1.number_of_nodes(), -1).transpose(0,1)
203 |         alpha_for_g2_sum = alpha_for_g2.sum(1).view(-1,1)
204 |         alpha_for_g2 = alpha_for_g2 / alpha_for_g2_sum
205 |         alpha_for_g2 = alpha_for_g2.contiguous().view(-1, 1)
206 | 
207 |         t1_g1_rep_v2 = t1_g1.repeat(g2.number_of_nodes(), 1)
208 |         mu_g2 = alpha_for_g2 * t1_g1_rep_v2
209 | 
210 |         mu_g2 = torch.split(mu_g2, g1.number_of_nodes())
211 |         mu_g2 = torch.stack(mu_g2)
212 |         mu_g2 = mu_g2.sum(1)
213 |         h1_g2 = t1_g2 + mu_g2
214 | 
215 |         g2.ndata['h'] = h1_g2
216 | 
217 |         g1_emb = dgl.mean_nodes(g1, 'h')
218 |         g2_emb = dgl.mean_nodes(g2, 'h')
219 | 
220 |         g1_emb_cat_g2_emb = torch.cat([g1_emb, g2_emb], 1)
221 | 
222 | 
223 |         H = self.fc(g1_emb_cat_g2_emb)
224 |         H2 = self.fc2(H)
225 |         pred = self.fc3(H2)
226 | 
227 | 
228 |         return pred
229 | 
230 | 
231 | 
232 | 
233 | mseLoss = nn.MSELoss()
234 | 
235 | 
236 | def myloss(preds, gts):
237 |     return mseLoss(preds.view(-1,1), gts.view(-1,1).float())
238 | 
239 | 
240 | # NOTE: this first make_a_dglgraph was dead code, immediately shadowed by the version below:
241 | # def make_a_dglgraph(g: NXGraph) -> DGLHeteroGraph:
242 | #     g = g.to_directed()
243 | #     return dgl.from_networkx(g, node_attrs=["label"])
244 | 
245 | 
246 | def make_a_dglgraph(g: NXGraph) -> DGLHeteroGraph:
247 |     max_deg = 60 # one-hot dimension covering the largest
node label 248 | ones = torch.eye(max_deg) 249 | 250 | g = g.to_directed() 251 | dg = dgl.from_networkx(g, node_attrs=["label"]) 252 | h0 = dg.ndata['label'].view(1,-1).squeeze() 253 | h0 = ones.index_select(0, h0).float() 254 | dg.ndata['h'] = h0 255 | 256 | return dg 257 | 258 | 259 | def collate(samples): 260 | g1List, g2List, gtList = map(list, zip(*samples)) 261 | dg_g1 = make_a_dglgraph(g1List[0]) 262 | dg_g2 = make_a_dglgraph(g2List[0]) 263 | return dg_g1, dg_g2, torch.tensor(gtList) 264 | 265 | 266 | 267 | ##################################################################################################### 268 | ### do the job as follows 269 | ##################################################################################################### 270 | 271 | 272 | 273 | 274 | entire_dataset, max_node_label, max_edge_label = read_and_split_to_individual_graph("aids.txt") 275 | 276 | print('entire_dataset len: ', len(entire_dataset)) 277 | print("max_node_label", max_node_label) # max_node_label 59 278 | print("max_edge_label", max_edge_label) # max_edge_label 3 279 | 280 | 281 | gid2gmap = {} # key: graph id, value: DGLHeteroGraph 282 | for g in entire_dataset: 283 | gid2gmap[g.graph.get('id')] = g 284 | 285 | 286 | gid2dgmap = {} # key: graph id, value: DGLHeteroGraph 287 | for g in entire_dataset: 288 | dg = make_a_dglgraph(g) 289 | dg = dgl.add_self_loop(dg) 290 | gid2dgmap[g.graph.get('id')] = dg 291 | 292 | 293 | train_data = GINDataset("aids", gid2gmap) 294 | 295 | 296 | dataloader = DataLoader( 297 | train_data, 298 | batch_size=1, 299 | collate_fn=collate, 300 | drop_last=False, 301 | shuffle=True) 302 | 303 | 304 | 305 | n_epochs = 100000 # epochs to train 306 | lr = 0.01 # learning rate 307 | l2norm = 0 # L2 norm coefficient 308 | 309 | # create model 310 | model = Model() 311 | model#.cuda() 312 | 313 | optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay = l2norm) 314 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.9) 315 | 316 | torch.backends.cudnn.enabled = True 317 | torch.backends.cudnn.benchmark = True 318 | 319 | with torch.autograd.set_detect_anomaly(True): 320 | for epoch in range(n_epochs): 321 | model.train() 322 | batch_count = 0 323 | for g1, g2, gt in dataloader: 324 | print('='*40+" epoch ", epoch, "batch ", batch_count) 325 | preds = model(g1, g2) 326 | loss = myloss(preds, gt) 327 | optimizer.zero_grad() 328 | loss.backward() 329 | optimizer.step() 330 | 331 | torch.cuda.empty_cache() 332 | 333 | batch_count += 1 334 | 335 | # do a test after an epoch 336 | print("do test ....") 337 | model.eval() 338 | with torch.no_grad(): 339 | for g1, g2, gt in dataloader: 340 | preds = model(g1, g2) 341 | torch.cuda.empty_cache() 342 | print("+++++++"*5, 'test finish') 343 | 344 | 345 | -------------------------------------------------------------------------------- /wl_labelling.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import namedtuple 3 | import dgl 4 | from dgl.heterograph import DGLHeteroGraph 5 | import dgl.function as fn 6 | from dgl.nn import GINConv 7 | from dgl.udf import NodeBatch 8 | import networkx as nx 9 | import os 10 | import torch 11 | from torch import Tensor 12 | from typing import List, Tuple 13 | from visualize_graph import visualize_wl_label_graph 14 | 15 | class gen_wl_label_udf: 16 | """ 17 | A reduce function to generate wl label. 18 | 19 | src: source label, e.g. wl_label_0 20 | dst: destination label, e.g. 
wl_label_1
21 |     src_collection_to_dst: {(0, 1, 2): 4, (1, 2, 3): 5, ...}
22 |     dst_to_src_collection: {4: (0, 1, 2), 5: (1, 2, 3), ...}
23 |     """
24 |     def __init__(self, src: str, dst: str, start_idx: int):
25 |         self.src: str = src
26 |         self.dst: str = dst
27 |         self.start_idx: int = start_idx
28 |         self.src_collection_to_dst = {}
29 |         self.dst_to_src_collection = {}
30 | 
31 |     def __call__(self, nodes: NodeBatch):
32 |         mailbox: dict = nodes.mailbox
33 | 
34 |         # shape: (number of nodes with same number of messages received, number of messages received)
35 |         m: torch.tensor = mailbox[self.src]
36 | 
37 |         # source label, e.g. wl_label_0
38 |         # after unsqueeze, shape should be (n, 1)
39 |         src_label: torch.tensor = nodes.data[self.src].unsqueeze(-1)
40 | 
41 |         # collect messages from neighbors and combine with src_label
42 |         # shape: (number of nodes with same number of messages received, number of messages received + 1)
43 |         # src_label_collection = torch.cat((src_label, m), 1).type(torch.int)
44 |         src_label_collection = m.type(torch.int)
45 | 
46 |         # sort along column dimension
47 |         sorted_src_label_collection, _ = torch.sort(src_label_collection, 1)
48 | 
49 |         # relabel each sorted label collection with a new consecutive label
50 |         dst_label = torch.zeros(len(sorted_src_label_collection))
51 |         for i, l in enumerate(sorted_src_label_collection):
52 |             l_tuple: Tuple[int, ...] = tuple(l.tolist())
53 |             if l_tuple not in self.src_collection_to_dst:
54 |                 # offset by start_idx to distinguish new labels from nodes that received no messages
55 |                 dst_label_idx = self.start_idx + len(self.src_collection_to_dst)
56 |                 self.src_collection_to_dst.update({ l_tuple : dst_label_idx })
57 |                 self.dst_to_src_collection.update({ dst_label_idx : l_tuple })
58 |                 dst_label[i] = dst_label_idx
59 |             else:
60 |                 dst_label[i] = self.src_collection_to_dst[l_tuple]
61 |         return {self.dst: dst_label}
62 | 
63 | 
64 | def compress_graph(gid: int, g: DGLHeteroGraph, max_degree: int) -> dict:
65 |     """
66 |     1. Re-label all nodes with new wl_label_0
67 |     2. Use update_all to generate wl_label_1 from wl_label_0
68 |     3. Use update_all to generate wl_label_2 from wl_label_1
69 |     4. Generate a new graph with wl_label_0, wl_label_1, wl_label_2 and hGx nodes
70 |     5. Create edges
71 |     """
72 | 
73 |     # Relabel all nodes with new wl_label_0.
74 |     # The number of wl_label_0 is the same as the number of unique labels in the graph.
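# Worked example of the relabelling below: if g.ndata["label"] is [6., 6., 8.], then
# unique_label_list is [6, 8], label 6 maps to wl_label_0 = 0 and label 8 maps to
# wl_label_0 = 1, so g.ndata["wl_label_0"] becomes [0, 0, 1]. The two WL rounds that
# follow refine these compressed labels using sorted multisets of neighbour labels.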
75 |     unique_label_tensor: Tensor = torch.unique(g.ndata["label"]).type(torch.long)
76 |     unique_label_list: List = sorted(unique_label_tensor.tolist())
77 |     n_wl_label_0 = len(unique_label_list)
78 |     label_to_wl_label_0 = { l: idx for (idx, l) in enumerate(unique_label_list) }
79 |     wl_label_0_to_label = { idx: l for (idx, l) in enumerate(unique_label_list) }
80 |     n_data_wl_label_0 = g.ndata["label"].clone().detach()
81 |     n_data_wl_label_0.apply_(lambda l: label_to_wl_label_0[int(l)])
82 |     g.ndata["wl_label_0"] = n_data_wl_label_0
83 | 
84 |     # Calculate wl label 1 and the mapping from wl label 0 to wl label 1
85 |     wl_label_0_to_1 = gen_wl_label_udf("wl_label_0", "wl_label_1", n_wl_label_0)
86 |     g.update_all(fn.copy_u('wl_label_0', 'wl_label_0'), wl_label_0_to_1)
87 |     n_wl_label_1 = len(wl_label_0_to_1.dst_to_src_collection)
88 | 
89 |     # Calculate wl label 2 and the mapping from wl label 1 to wl label 2
90 |     wl_label_1_to_2 = gen_wl_label_udf("wl_label_1", "wl_label_2", n_wl_label_0 + n_wl_label_1)
91 |     g.update_all(fn.copy_u('wl_label_1', 'wl_label_1'), wl_label_1_to_2)
92 |     n_wl_label_2 = len(wl_label_1_to_2.dst_to_src_collection)
93 | 
94 |     # Final node for each graph
95 |     hGx = n_wl_label_0 + n_wl_label_1 + n_wl_label_2
96 | 
97 |     # Create a new graph
98 |     edges = []
99 |     edges_to_ids = {}
100 |     edge_weight = []
101 |     wl_label_0_to_1_edge = []
102 |     wl_label_1_to_2_edge = []
103 |     wl_label_2_to_hGx_edge = []
104 | 
105 |     def upsert_edge(edge: Tuple[int, int], label_edge_list: List[Tuple[int, int]]):
106 |         if edge in edges_to_ids:
107 |             # If the edge exists, increase its weight by 1
108 |             edge_id = edges_to_ids[edge]
109 |             edge_weight[edge_id] += 1
110 |         else:
111 |             # Otherwise, add a new edge with weight 1
112 |             edges_to_ids[edge] = len(edge_weight)
113 |             label_edge_list.append(edge)
114 |             edge_weight.append(1)
115 |             edges.append(edge)
116 | 
117 |     processed_wl_label_1 = set()
118 |     processed_wl_label_2 = set()
119 |     # g.nodes() -> tensor([0, 1, 2, 3, 4])
120 |     for node_idx in g.nodes():
121 |         label = g.ndata["label"][node_idx]
122 |         wl_label_0 = int(g.ndata["wl_label_0"][node_idx])
123 |         wl_label_1 = int(g.ndata["wl_label_1"][node_idx])
124 |         wl_label_2 = int(g.ndata["wl_label_2"][node_idx])
125 | 
126 |         # Add edges from wl_label_0 to wl_label_1
127 |         wl_label_0_collection = wl_label_0_to_1.dst_to_src_collection[wl_label_1]
128 |         if wl_label_1 not in processed_wl_label_1:
129 |             for wl_label_0 in wl_label_0_collection:
130 |                 edge = (wl_label_0, wl_label_1)
131 |                 upsert_edge(edge, wl_label_0_to_1_edge)
132 |             processed_wl_label_1.add(wl_label_1)
133 | 
134 |         # Add edges from wl_label_1 to wl_label_2
135 |         wl_label_1_collection = wl_label_1_to_2.dst_to_src_collection[wl_label_2]
136 |         if wl_label_2 not in processed_wl_label_2:
137 |             for wl_label_1 in wl_label_1_collection:
138 |                 edge = (wl_label_1, wl_label_2)
139 |                 upsert_edge(edge, wl_label_1_to_2_edge)
140 |             processed_wl_label_2.add(wl_label_2)
141 | 
142 |         # Add edges from wl_label_2 to hGx
143 |         edge = (wl_label_2, hGx)
144 |         upsert_edge(edge, wl_label_2_to_hGx_edge)
145 | 
146 |     # new_g is a directed graph
147 |     # Only wl_label_0 nodes have a valid one-hot encoding in 'h'
148 |     diagonal_ones = torch.eye(max_degree)
149 |     wl_label_0_h_one_hot = diagonal_ones.index_select(0, torch.LongTensor(unique_label_list))
150 | 
151 |     new_g = dgl.graph(tuple(zip(*edges)))
152 |     new_g.edata["weight"] = torch.tensor(edge_weight)
153 |     new_g.ndata['h'] = torch.zeros(new_g.num_nodes(), max_degree)
154 |     new_g.ndata['h'][:n_wl_label_0] = wl_label_0_h_one_hot
155 |     return {
156 |         "old_g": g,
157 |         "new_g": 
new_g, 158 | "wl_label_0_nodes": sorted(list(wl_label_0_to_label.keys())), 159 | "wl_label_1_nodes": sorted(list(wl_label_0_to_1.dst_to_src_collection.keys())), 160 | "wl_label_2_nodes": sorted(list(wl_label_1_to_2.dst_to_src_collection.keys())), 161 | "wl_label_0_to_label": wl_label_0_to_label, 162 | "wl_label_1_to_0": wl_label_0_to_1.dst_to_src_collection, 163 | "wl_label_2_to_1": wl_label_1_to_2.dst_to_src_collection, 164 | "wl_label_0_to_1_edge": wl_label_0_to_1_edge, 165 | "wl_label_1_to_2_edge": wl_label_1_to_2_edge, 166 | "wl_label_2_to_hGx_edge": wl_label_2_to_hGx_edge 167 | } 168 | 169 | 170 | def read_and_split_to_individual_graph(fname): 171 | 172 | f = open(fname) 173 | lines = f.read() 174 | f.close() 175 | 176 | lines2 = lines.split("t # ") 177 | 178 | lines3 = [g.strip().split("\n") for g in lines2] 179 | 180 | glist = [] 181 | for idx in range(1, len(lines3)): 182 | cur_g = lines3[idx] 183 | 184 | gid_line = cur_g[0].strip().split(' ') 185 | gid = gid_line[0] 186 | 187 | g = nx.Graph(id = gid) 188 | 189 | for idy in range(1, len(cur_g)): 190 | tmp = cur_g[idy].split(' ') 191 | if tmp[0] == 'v': 192 | g.add_node(tmp[1], att="0") 193 | if tmp[0] == 'e': 194 | g.add_edge(tmp[1], tmp[2], att="0") 195 | 196 | glist.append(g) 197 | 198 | return glist 199 | 200 | 201 | def make_a_dglgraph(g: nx.classes.graph.Graph) -> DGLHeteroGraph: 202 | 203 | # max_deg = 40 # largest node degree 204 | # ones = torch.eye(max_deg) 205 | 206 | edges = [[],[]] 207 | for edge in g.edges(): # create un-directed graph 208 | end1 = edge[0] 209 | end2 = edge[1] 210 | edges[0].append(int(end1)) 211 | edges[1].append(int(end2)) 212 | edges[0].append(int(end2)) 213 | edges[1].append(int(end1)) 214 | dg:DGLHeteroGraph = dgl.graph((torch.tensor(edges[0]), torch.tensor(edges[1]))) 215 | 216 | # h0 = dg.in_degrees().view(1,-1).squeeze() # h0.shape -> torch.Size([19]) 217 | dg.ndata['label'] = dg.in_degrees().type(torch.float) 218 | # h0 = ones.index_select(0, h0).float() # convert to one hot tensor, h0.shape -> torch.Size([19, 20]) 219 | # dg.ndata['h'] = h0 220 | 221 | return dg 222 | 223 | if __name__ == "__main__": 224 | 225 | entire_dataset: List[nx.classes.graph.Graph] = read_and_split_to_individual_graph("aids.txt") 226 | entire_dataset = [g for g in entire_dataset if g.number_of_nodes() < 10] 227 | entire_dataset = sorted(entire_dataset, key=lambda x: x.number_of_edges()) 228 | print('entire_dataset len: ', len(entire_dataset)) 229 | 230 | output_dir = "visualized_graphs" 231 | 232 | gid2dgmap = {} # key: graph id, value: DGLHeteroGraph 233 | for g in entire_dataset: 234 | gid = g.graph.get('id') 235 | # if gid != "3134": 236 | # continue 237 | dg = make_a_dglgraph(g) 238 | dg = dgl.add_self_loop(dg) 239 | gid2dgmap[gid] = dg 240 | cg = compress_graph(gid, dg, 20) 241 | visualize_wl_label_graph(gid, cg, output_dir) 242 | 243 | # remove useless *.svg files and keep *.svg.svg files 244 | for i in os.listdir(output_dir): 245 | if i.endswith(".svg") and not i.endswith(".svg.svg"): 246 | os.remove(os.path.join(output_dir, i)) -------------------------------------------------------------------------------- /init_node_sel_model_train.py: -------------------------------------------------------------------------------- 1 | # neighborhood prediction model 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from dgl import DGLGraph 8 | import dgl.function as fn 9 | from functools import partial 10 | import dgl 11 | from torch.utils.data import DataLoader 12 | import numpy as 
np 13 | import time 14 | import random 15 | from dgl.nn.pytorch.conv import GINConv, SGConv, SAGEConv 16 | import os 17 | import networkx as nx 18 | from dgl.data import DGLDataset 19 | 20 | 21 | GPUID=1 22 | torch.cuda.set_device(GPUID) 23 | 24 | 25 | def read_and_split_to_individual_graph(fname, gsizeNoLessThan=15, gsizeLessThan=30, writeTo=None, prefix=None, fileformat=None, removeEdgeLabel=True, removeNodeLabel=True, graph_num=100000000): 26 | ''' 27 | aids graphs are in one file. write each graph into a single file 28 | :parm fname: the file storing all graphs in the format as follows: 29 | t # 1 30 | v 1 31 | v 2 32 | e 1 2 33 | :parm gsize: some graphs are too large to compute GED, just skip them at the current stage 34 | :parm writeTo: path to write 35 | :parm prefix: give a prefix to each file, e.g., "g" or "q4", etc 36 | :parm fileformat: None means following the original format of aids; 37 | :parm removeEdgeLabel: all edges are given label "0" 38 | :parm removeNodeLabel: all nodes are given label "0" 39 | ''' 40 | if writeTo is not None: 41 | if prefix is None: 42 | print("You want to write each graph into a single file.") 43 | print("You need to give the prefix of the filename to store each graph. For example, g, q4, q8") 44 | exit(-1) 45 | else: 46 | if writeTo[-1] == '/': 47 | writeTo = writeTo+prefix 48 | else: 49 | writeTo = writeTo+"/"+prefix 50 | if fileformat is None: 51 | print("please specify fileformat: aids, gexf") 52 | exit(-1) 53 | 54 | 55 | f = open(fname) 56 | lines = f.read() 57 | f.close() 58 | 59 | lines2 = lines.split("t # ") 60 | 61 | lines3 = [g.strip().split("\n") for g in lines2] 62 | 63 | glist = [] 64 | for idx in range(1, len(lines3)): 65 | cur_g = lines3[idx] 66 | 67 | gid_line = cur_g[0].strip().split(' ') 68 | gid = gid_line[0] 69 | if len(gid_line) == 4: 70 | glabel = gid_line[3] 71 | g = nx.Graph(id = gid, label = glabel) 72 | elif len(gid_line) == 6: 73 | glabel = gid_line[3] 74 | g = nx.Graph(id = gid, label = glabel) 75 | else: 76 | g = nx.Graph(id = gid) 77 | 78 | 79 | for idy in range(1, len(cur_g)): 80 | tmp = cur_g[idy].split(' ') 81 | if tmp[0] == 'v': 82 | if removeNodeLabel == False: 83 | g.add_node(tmp[1], att=tmp[2]) 84 | else: 85 | g.add_node(tmp[1], att="0") 86 | if tmp[0] == 'e': 87 | if removeEdgeLabel == False: 88 | g.add_edge(tmp[1], tmp[2], att=tmp[3]) 89 | else: 90 | g.add_edge(tmp[1], tmp[2], att="0") 91 | 92 | 93 | if g.number_of_nodes() >= gsizeNoLessThan and g.number_of_nodes() < gsizeLessThan: 94 | if writeTo is not None: 95 | if fileformat == "aids": 96 | f2 = open(writeTo+g.graph.get('id')+".txt", "w") 97 | f2.write("t # "+g.graph.get('id')+"\n") 98 | 99 | if removeNodeLabel: 100 | for i in range(0, len(g.nodes())): 101 | f2.write("v "+str(i)+" 0\n") 102 | else: 103 | for i in range(0, len(g.nodes())): 104 | f2.write("v "+str(i)+" "+g.nodes[str(i)].get("att")+"\n") 105 | 106 | if removeEdgeLabel: 107 | for e in g.edges(): 108 | f2.write("e "+e[0]+" "+e[1]+" 0\n") 109 | else: 110 | for e in g.edges(): 111 | f2.write("e "+e[0]+" "+e[1]+" "+g[e[0]][e[1]].get("att")+"\n") 112 | f2.close() 113 | if fileformat == "gexf": 114 | nx.write_gexf(g, writeTo+g.graph.get("id")+".gexf") 115 | 116 | 117 | glist.append(g) 118 | if len(glist) > graph_num: 119 | return glist 120 | 121 | return glist 122 | 123 | 124 | 125 | 126 | 127 | 128 | class GINDataset(DGLDataset): 129 | 130 | def __init__(self, name, database, queries, exact_ans, isTrain=True, self_loop=False, degree_as_nlabel=False, 131 | raw_dir=None, force_reload=False, 
verbose=False): 132 | 133 | self._name = name 134 | gin_url = "" 135 | self.database = database 136 | self.queries = queries 137 | self.exact_ans = exact_ans 138 | self.isTrain = isTrain 139 | 140 | 141 | self.qList = [] 142 | self.gPosList = [] 143 | self.ground_truth = [] 144 | 145 | super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel), 146 | raw_dir=raw_dir, force_reload=force_reload, verbose=verbose) 147 | 148 | @property 149 | def raw_path(self): 150 | return os.path.join(".", self.raw_dir) 151 | 152 | 153 | def download(self): 154 | pass 155 | 156 | 157 | def __len__(self): 158 | return len(self.queries) 159 | 160 | 161 | def __getitem__(self, idx): 162 | return self.qList[idx], self.gPosList[idx], self.ground_truth[idx] 163 | 164 | 165 | def _file_path(self): 166 | return self.file 167 | 168 | 169 | def get_topkAll_in_a_list(self, topk, x): 170 | kth = x[topk-1] 171 | res = x[0:topk] 172 | for i in range(topk, len(x)): 173 | if x[i][1] == kth[1]: 174 | res.append(x[i]) 175 | return res 176 | 177 | 178 | def process(self): 179 | for q in self.queries: 180 | self.qList.append(q) 181 | 182 | exact_ans_of_q = self.get_topkAll_in_a_list(200, self.exact_ans[q]) 183 | exact_ans_of_q_IDSet = set() 184 | 185 | gt_label = [] 186 | gPos = [] 187 | for cur_ans in exact_ans_of_q: 188 | exact_ans_of_q_IDSet.add(cur_ans[0]) 189 | 190 | for idx in range(0, len(self.database)): 191 | ele = self.database[idx] 192 | if ele.graph.get('id') in exact_ans_of_q_IDSet: 193 | gt_label.append(1.0) 194 | gPos.append(idx) 195 | else: 196 | if self.isTrain: 197 | rand = np.random.randint(10) 198 | if rand > 8: 199 | gt_label.append(0.0) 200 | gPos.append(idx) 201 | else: 202 | gt_label.append(0.0) 203 | gPos.append(idx) 204 | 205 | 206 | if len(gt_label) != len(gPos): 207 | print("len(gt_label) != len(gPos)") 208 | exit(-1) 209 | 210 | self.ground_truth.append(gt_label) 211 | self.gPosList.append(gPos) 212 | 213 | 214 | 215 | 216 | def save(self): 217 | pass 218 | 219 | def load(self): 220 | pass 221 | 222 | def has_cache(self): 223 | pass 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | class Model(nn.Module): 232 | def __init__(self, gInitEmbMap, gid2dgmap, allDBGEmb): 233 | super(Model, self).__init__() 234 | 235 | self.him = 1024 236 | self.allDBGEmb = allDBGEmb 237 | 238 | # use gin 239 | self.conv1_for_g = GINConv(None, 'mean') 240 | self.conv2_for_g = GINConv(None, 'mean') 241 | 242 | self.gnn_bn = torch.nn.BatchNorm1d(self.him) 243 | self.gnn_bn2 = torch.nn.BatchNorm1d(self.him) 244 | 245 | self.gInitEmbMap = gInitEmbMap 246 | self.gid2dgMap = gid2dgmap 247 | 248 | 249 | self.fc = nn.Linear(self.him*2, 128, bias=True) 250 | self.fc2 = nn.Linear(128, 1, bias=True) 251 | 252 | self.bn = torch.nn.BatchNorm1d(128) 253 | 254 | self.RELU = torch.nn.ReLU(inplace=True) 255 | 256 | self.fc_init = nn.Linear(20, self.him, bias=True) 257 | 258 | 259 | def forward(self, qids, gPosList): 260 | preds = [] 261 | 262 | 263 | for idx in range(0, len(qids)): 264 | qid = qids[idx] 265 | gPos = gPosList[idx] 266 | 267 | dg_of_q = self.gid2dgMap[qid] 268 | 269 | dg_of_q.ndata['h2'] = self.fc_init(dg_of_q.ndata['h']) 270 | dg_of_q.ndata['h2'] = self.RELU(self.gnn_bn(self.conv1_for_g(dg_of_q, dg_of_q.ndata['h2']))) 271 | dg_of_q.ndata['h2'] = self.RELU(self.gnn_bn2(self.conv2_for_g(dg_of_q, dg_of_q.ndata['h2']))) 272 | 273 | qemb = dgl.mean_nodes(dg_of_q, 'h2').squeeze() 274 | gEmbList = self.allDBGEmb.index_select(0, torch.tensor(gPos).cuda()) 275 | qemb = 
qemb.repeat(gEmbList.shape[0]).view(-1, self.him) 276 | 277 | H = torch.cat([qemb, gEmbList], 1) 278 | H2 = self.RELU(self.bn(self.fc(H))) 279 | probs = torch.sigmoid(self.fc2(H2)).view(1,-1).squeeze() 280 | preds.append(probs) 281 | 282 | 283 | return preds 284 | 285 | 286 | def weighted_binary_cross_entropy(output, target, weights=None): 287 | output = torch.clamp(output, min=1e-6, max=1-1e-6) 288 | 289 | if weights is not None: 290 | assert len(weights) == 2 291 | loss = weights[1] * (target * torch.log(output)) + \ 292 | weights[0] * ((1 - target) * torch.log(1 - output)) 293 | else: 294 | loss = target * torch.log(output) + (1 - target) * torch.log(1 - output) 295 | 296 | return torch.neg(torch.mean(loss)) 297 | 298 | 299 | 300 | 301 | bceLoss = nn.BCELoss() 302 | 303 | 304 | 305 | def myloss(preds, gts): 306 | 307 | loss = [] 308 | for i in range(0, len(preds)): 309 | pred = preds[i] 310 | gt = torch.tensor(gts[i]).cuda() 311 | cur_loss= weighted_binary_cross_entropy(pred, gt, [1.0, 10.0]) 312 | loss.append(cur_loss) 313 | loss = torch.stack(loss) 314 | 315 | return torch.mean(loss) 316 | 317 | 318 | from sklearn.metrics import roc_auc_score 319 | 320 | def my_loss_for_test(preds, gts): 321 | avg_auc = 0 322 | for i in range(0, len(preds)): 323 | pred = preds[i] 324 | gt = gts[i] 325 | avg_auc += roc_auc_score(gt, pred.cpu().detach().numpy()) 326 | avg_auc /= len(preds) 327 | 328 | return avg_auc 329 | 330 | 331 | def check_recall(preds, gts): 332 | avg_precision = 0 333 | for i in range(0, len(preds)): 334 | pred = preds[i].cpu().detach().numpy().tolist() 335 | gt = gts[i] 336 | abc = [] 337 | for idx in range(0, len(pred)): 338 | abc.append( (gt[idx], pred[idx]) ) 339 | abc.sort(key = lambda x: -x[1]) 340 | 341 | precision = 0 342 | top_perc10 = 200 343 | for idx in range(0, top_perc10): 344 | if abc[idx][0] == 1: 345 | precision += 1 346 | precision = precision / top_perc10 347 | 348 | avg_precision += precision 349 | 350 | avg_precision = avg_precision / len(preds) 351 | 352 | print('avg precision', avg_precision) 353 | 354 | return avg_precision 355 | 356 | 357 | def read_initial_gemb(addr): 358 | gEmbMap = {} 359 | gfileList = os.listdir(addr) 360 | for gfile in gfileList: 361 | gID = gfile[1:-4] 362 | f = open(addr+"/"+gfile) 363 | lines = f.read() 364 | f.close() 365 | lines = lines.strip().split('\n') 366 | lines = lines[1:] 367 | nodeEmbList = [] 368 | for line in lines: 369 | tmp = line.strip().split(' ') 370 | tmp2 = [float(ele) for ele in tmp[1:]] 371 | nodeEmbList.append(tmp2) 372 | nodeEmbList = torch.tensor(nodeEmbList) 373 | gEmb = torch.mean(nodeEmbList, 0) 374 | gEmbMap[gID] = gEmb 375 | return gEmbMap 376 | 377 | 378 | 379 | def make_a_dglgraph(g): 380 | 381 | max_deg = 20 # largest node degree of q 382 | ones = torch.eye(max_deg) 383 | 384 | edges = [[],[]] 385 | for edge in g.edges(): 386 | end1 = edge[0] 387 | end2 = edge[1] 388 | edges[0].append(int(end1)) 389 | edges[1].append(int(end2)) 390 | edges[0].append(int(end2)) 391 | edges[1].append(int(end1)) 392 | dg = dgl.graph((torch.tensor(edges[0]), torch.tensor(edges[1]))) 393 | 394 | h0 = dg.in_degrees().view(1,-1).squeeze() 395 | h0 = ones.index_select(0, h0).float() 396 | dg.ndata['h'] = h0 397 | 398 | return dg.to(torch.device('cuda:'+str(GPUID))) 399 | 400 | 401 | 402 | 403 | def readQ2GDistBook(fname, validNodeIDSet=None): 404 | ''' 405 | store distance from the query to a data graph 406 | ''' 407 | f = open(fname) 408 | lines = f.read() 409 | f.close() 410 | lines = lines.strip() 411 | lines = 
lines.split('\n') 412 | distBook = {} 413 | for line in lines: 414 | tmp = line.split(" ") 415 | if validNodeIDSet != None and tmp[1] in validNodeIDSet: 416 | if tmp[0] in distBook: 417 | distBook[tmp[0]][tmp[1]] = float(tmp[2]) 418 | else: 419 | distBook[tmp[0]] = {} 420 | distBook[tmp[0]][tmp[1]] = float(tmp[2]) 421 | return distBook 422 | 423 | 424 | 425 | 426 | def get_exact_answer(topk, Q2GDistBook): 427 | answer = {} 428 | for query in Q2GDistBook.keys(): 429 | distToGList = list(Q2GDistBook[query].items()) 430 | distToGList.sort(key=lambda x: x[1]) 431 | dist_thr = -1 432 | if topk-1 < len(distToGList): 433 | dist_thr = distToGList[topk-1][1] 434 | else: 435 | dist_thr = 1000000.0 436 | a = [] 437 | for ele in distToGList: 438 | if ele[1] <= dist_thr: 439 | a.append(ele) 440 | else: 441 | break 442 | answer[query] = a 443 | return answer 444 | 445 | 446 | 447 | 448 | def collate(samples): 449 | qids, gPosList, gtlabels = map(list, zip(*samples)) 450 | 451 | return qids, gPosList, gtlabels 452 | 453 | 454 | ##################################################################################################### 455 | ### do the job as follows 456 | ##################################################################################################### 457 | 458 | # read in proximity graph 459 | pgTmp = read_and_split_to_individual_graph("PG.aids.nx", 0, 10000000000) 460 | pgTmp = pgTmp[0] 461 | 462 | 463 | gInitEmbMap = read_initial_gemb('data/AIDS/emb/aids.emb') 464 | print('read g init emb done.') 465 | 466 | entire_dataset = read_and_split_to_individual_graph("aids.txt", 0, 10000000) 467 | print('entire_dataset len: ', len(entire_dataset)) 468 | gid2gmap = {} 469 | for g in entire_dataset: 470 | gid2gmap[g.graph.get("id")] = g 471 | gid2dgmap = {} 472 | for g in entire_dataset: 473 | dg = make_a_dglgraph(g) 474 | dg = dgl.add_self_loop(dg) 475 | gid2dgmap[g.graph.get('id')] = dg.to(torch.device('cuda:'+str(GPUID))) 476 | 477 | 478 | 479 | database = entire_dataset[0:40000] 480 | database_ids = set([ele.graph.get('id') for ele in database]) 481 | 482 | 483 | databaseGEmb = [] 484 | for g in database: 485 | gEmb = gInitEmbMap[g.graph.get('id')] 486 | databaseGEmb.append(gEmb) 487 | databaseGEmb = torch.stack(databaseGEmb).cuda() 488 | 489 | train_queries_ids = [] 490 | f = open('data/AIDS/query_train.txt') 491 | lines = f.read() 492 | f.close() 493 | lines = lines.strip().split('\n') 494 | for line in lines: 495 | qid = line.strip() 496 | train_queries_ids.append(qid) 497 | 498 | 499 | test_queries_ids = [] 500 | f = open('data/AIDS/query_test.txt') 501 | lines = f.read() 502 | f.close() 503 | lines = lines.strip().split('\n') 504 | for line in lines: 505 | qid = line.strip() 506 | test_queries_ids.append(qid) 507 | 508 | 509 | 510 | 511 | q2GDistBook = readQ2GDistBook("data/AIDS/aids.txt", database_ids) 512 | print('readQ2GDistBook done') 513 | exact_ans = get_exact_answer(100000000, q2GDistBook) 514 | 515 | 516 | train_data = GINDataset("aids", database, train_queries_ids, exact_ans, isTrain=True) 517 | test_data = GINDataset("aids", database, test_queries_ids, exact_ans, isTrain=False) 518 | 519 | 520 | 521 | 522 | dataloader = DataLoader( 523 | train_data, 524 | batch_size=200, 525 | collate_fn=collate, 526 | drop_last=False, 527 | shuffle=True) 528 | 529 | testdataloader = DataLoader( 530 | test_data, 531 | batch_size=20, 532 | collate_fn=collate, 533 | drop_last=False, 534 | shuffle=True) 535 | 536 | 537 | n_epochs = 8000 # epochs to train 538 | lr = 0.01 # learning rate 539 | l2norm 
= 0 # L2 norm coefficient 540 | 541 | 542 | model = Model(gInitEmbMap, gid2dgmap, databaseGEmb) 543 | model.cuda() 544 | 545 | optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay = l2norm) 546 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.9) 547 | 548 | 549 | 550 | 551 | with torch.autograd.set_detect_anomaly(True): 552 | for epoch in range(n_epochs): 553 | model.train() 554 | start_train_time = time.time() 555 | batch_count = 0 556 | for qids, gPosList, gts in dataloader: 557 | print('='*40+" epoch ", epoch, "batch ", batch_count) 558 | preds = model(qids, gPosList) 559 | loss = myloss(preds, gts) 560 | print('loss ', loss) 561 | auc = my_loss_for_test(preds, gts) 562 | print('auc ', auc) 563 | optimizer.zero_grad() 564 | loss.backward() 565 | optimizer.step() 566 | batch_count += 1 567 | 568 | 569 | # do a test after an epoch 570 | print("do test ....") 571 | model.eval() 572 | with torch.no_grad(): 573 | # do valid 574 | # print("start valid ................") 575 | # for valid_graph, valid_label, graphIDs, nodeMaps in validdataloader: 576 | # valid_pred = model(valid_graph, valid_label, graphIDs, nodeMaps) 577 | # tmp = my_loss_for_test(valid_pred, valid_label) 578 | # valid_loss = tmp[0] 579 | # print('valid loss ', valid_loss) 580 | # avg_mse_of_valid = avg_mse_of_valid + valid_loss.item() 581 | # valid_count = valid_count + 1 582 | # print("+++++++"*5, "valid finish") 583 | 584 | 585 | # do test 586 | print("start test ................") 587 | print("epoch", epoch) 588 | for qids, gPosList, gts in testdataloader: 589 | preds = model(qids, gPosList) 590 | auc = my_loss_for_test(preds, gts) 591 | check_recall(preds, gts) 592 | print('auc ', auc) 593 | 594 | print("+++++++"*5, 'test finish') 595 | 596 | 597 | 598 | -------------------------------------------------------------------------------- /neigh_pruning_model_training.py: -------------------------------------------------------------------------------- 1 | # neighborhood prediction model 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from dgl import DGLGraph 8 | import dgl.function as fn 9 | from functools import partial 10 | import dgl 11 | from torch.utils.data import DataLoader 12 | import numpy as np 13 | import time 14 | import random 15 | from dgl.nn.pytorch.conv import GINConv, SGConv, SAGEConv 16 | import os 17 | import networkx as nx 18 | from dgl.data import DGLDataset 19 | import jpype 20 | 21 | 22 | 23 | 24 | def read_and_split_to_individual_graph(fname, gsizeNoLessThan=15, gsizeLessThan=30, writeTo=None, prefix=None, fileformat=None, removeEdgeLabel=True, removeNodeLabel=True, graph_num=100000000): 25 | ''' 26 | aids graphs are in one file. write each graph into a single file 27 | :parm fname: the file storing all graphs in the format as follows: 28 | t # 1 29 | v 1 30 | v 2 31 | e 1 2 32 | :parm gsize: some graphs are too large to compute GED, just skip them at the current stage 33 | :parm writeTo: path to write 34 | :parm prefix: give a prefix to each file, e.g., "g" or "q4", etc 35 | :parm fileformat: None means following the original format of aids; 36 | :parm removeEdgeLabel: all edges are given label "0" 37 | :parm removeNodeLabel: all nodes are given label "0" 38 | ''' 39 | if writeTo is not None: 40 | if prefix is None: 41 | print("You want to write each graph into a single file.") 42 | print("You need to give the prefix of the filename to store each graph. 
For example, g, q4, q8") 43 | exit(-1) 44 | else: 45 | if writeTo[-1] == '/': 46 | writeTo = writeTo+prefix 47 | else: 48 | writeTo = writeTo+"/"+prefix 49 | if fileformat is None: 50 | print("please specify fileformat: aids, gexf") 51 | exit(-1) 52 | 53 | 54 | f = open(fname) 55 | lines = f.read() 56 | f.close() 57 | 58 | lines2 = lines.split("t # ") 59 | 60 | lines3 = [g.strip().split("\n") for g in lines2] 61 | 62 | glist = [] 63 | for idx in range(1, len(lines3)): 64 | cur_g = lines3[idx] 65 | 66 | gid_line = cur_g[0].strip().split(' ') 67 | gid = gid_line[0] 68 | if len(gid_line) == 4: 69 | glabel = gid_line[3] 70 | g = nx.Graph(id = gid, label = glabel) 71 | elif len(gid_line) == 6: 72 | glabel = gid_line[3] 73 | g = nx.Graph(id = gid, label = glabel) 74 | else: 75 | g = nx.Graph(id = gid) 76 | 77 | 78 | for idy in range(1, len(cur_g)): 79 | tmp = cur_g[idy].split(' ') 80 | if tmp[0] == 'v': 81 | if removeNodeLabel == False: 82 | g.add_node(tmp[1], att=tmp[2]) 83 | else: 84 | g.add_node(tmp[1], att="0") 85 | if tmp[0] == 'e': 86 | if removeEdgeLabel == False: 87 | g.add_edge(tmp[1], tmp[2], att=tmp[3]) 88 | else: 89 | g.add_edge(tmp[1], tmp[2], att="0") 90 | 91 | 92 | if g.number_of_nodes() >= gsizeNoLessThan and g.number_of_nodes() < gsizeLessThan: 93 | if writeTo is not None: 94 | if fileformat == "aids": 95 | f2 = open(writeTo+g.graph.get('id')+".txt", "w") 96 | f2.write("t # "+g.graph.get('id')+"\n") 97 | 98 | if removeNodeLabel: 99 | for i in range(0, len(g.nodes())): 100 | f2.write("v "+str(i)+" 0\n") 101 | else: 102 | for i in range(0, len(g.nodes())): 103 | f2.write("v "+str(i)+" "+g.nodes[str(i)].get("att")+"\n") 104 | 105 | if removeEdgeLabel: 106 | for e in g.edges(): 107 | f2.write("e "+e[0]+" "+e[1]+" 0\n") 108 | else: 109 | for e in g.edges(): 110 | f2.write("e "+e[0]+" "+e[1]+" "+g[e[0]][e[1]].get("att")+"\n") 111 | f2.close() 112 | if fileformat == "gexf": 113 | nx.write_gexf(g, writeTo+g.graph.get("id")+".gexf") 114 | 115 | 116 | glist.append(g) 117 | if len(glist) > graph_num: 118 | return glist 119 | 120 | return glist 121 | 122 | 123 | 124 | 125 | class GINDataset(DGLDataset): 126 | 127 | def __init__(self, name, trainFileName, gid2dgmap, gID2InitEmbMap, gID2InitTensorIndexMap, neighNum, margin, isTrain, self_loop=False, degree_as_nlabel=False, 128 | raw_dir=None, force_reload=False, verbose=False): 129 | 130 | self._name = name 131 | gin_url = "" 132 | 133 | self.gid2dgmap = gid2dgmap 134 | self.gID2InitEmbMap = gID2InitEmbMap 135 | self.gID2InitTensorIndexMap = gID2InitTensorIndexMap 136 | self.isTrain = isTrain 137 | self.neighNum = neighNum 138 | self.margin = margin 139 | self.trainFileName = trainFileName 140 | 141 | self.qList = [] 142 | self.pgNodeEmbList = [] 143 | self.ground_truth = [] 144 | self.neighInitEmbIndexList = [] 145 | self.mask_of_1_list = [] 146 | self.mask_of_0_list = [] 147 | self.class_weight_list = [] 148 | 149 | super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel), 150 | raw_dir=raw_dir, force_reload=force_reload, verbose=verbose) 151 | 152 | @property 153 | def raw_path(self): 154 | return os.path.join(".", self.raw_dir) 155 | 156 | 157 | def download(self): 158 | pass 159 | 160 | 161 | def __len__(self): 162 | return len(self.qList) 163 | 164 | 165 | def __getitem__(self, idx): 166 | return self.qList[idx], self.pgNodeEmbList[idx], self.neighInitEmbIndexList[idx], self.ground_truth[idx], self.mask_of_1_list[idx], self.mask_of_0_list[idx], self.class_weight_list[idx] 167 | 168 | 169 | 
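# Format of the train/test file parsed by process() below, inferred from the parsing
# code (no sample file ships with the repository):
#   <query id> <pg node id> <query-to-node GED> <neighNum 0/1 keep labels> <neighbor ids> <neighbor dists>
# The two trailing lists have equal length and describe the neighbors of the pg node.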
def _file_path(self): 170 | return self.file 171 | 172 | 173 | def process(self): 174 | if self.isTrain: 175 | f = open(self.trainFileName) # do not forget to set this address 176 | else: 177 | f = open('test.data') # do not forget to set this address 178 | 179 | lines = f.read() 180 | f.close() 181 | lines = lines.strip().split('\n') 182 | 183 | for line in lines: 184 | tmp = line.strip().split(' ') 185 | 186 | self.qList.append(self.gid2dgmap[tmp[0]]) 187 | self.pgNodeEmbList.append(self.gID2InitEmbMap[tmp[1]]) 188 | q_g_ged = float(tmp[2]) 189 | gt = [int(ele) for ele in tmp[3:3+self.neighNum]] 190 | print(sum(gt)/self.neighNum) 191 | gt = gt[st:ed] 192 | print('st', st, 'ed', ed) 193 | self.ground_truth.append(gt) 194 | mask_of_1 = [] 195 | mask_of_0 = [] 196 | for i in range(0, len(gt)): 197 | if gt[i] == 1: 198 | mask_of_1.append(1) 199 | mask_of_0.append(0) 200 | else: 201 | mask_of_0.append(1) 202 | mask_of_1.append(0) 203 | 204 | mask_of_1 = torch.tensor(mask_of_1).bool() 205 | mask_of_0 = torch.tensor(mask_of_0).bool() 206 | 207 | self.mask_of_1_list.append(mask_of_1) 208 | self.mask_of_0_list.append(mask_of_0) 209 | 210 | neighIDs_and_neighDists = tmp[3+self.neighNum:] 211 | neighIDs = neighIDs_and_neighDists[0 : int(len(neighIDs_and_neighDists)/2)] 212 | neighDists = neighIDs_and_neighDists[int(len(neighIDs_and_neighDists)/2) : ] 213 | if len(neighIDs) != len(neighDists): 214 | print('ERROR! len(neighIDs) != len(neighDists)') 215 | print("neighIDs", len(neighIDs)) 216 | print("neighDists", len(neighDists)) 217 | exit(-1) 218 | 219 | 220 | neighIDs = neighIDs[st:ed] 221 | posInInitTensor = [] 222 | for neighID in neighIDs: 223 | posInInitTensor.append(self.gID2InitTensorIndexMap[neighID]) 224 | while len(posInInitTensor) < (ed-st): 225 | # pad zero 226 | posInInitTensor.append(40000) # aids dataset size is 40000, do not forget to re-set it for different dataset 227 | 228 | class_weight = [] 229 | neighDists = neighDists[st:ed] 230 | # print(neighDists) 231 | # print(q_g_ged) 232 | for idx in range(0, len(neighDists)): 233 | neighDist = float(neighDists[idx]) 234 | dist_diff = neighDist - q_g_ged - self.margin # the self.margin here needs to be consistent with the margin in train.data generation 235 | class_weight.append(abs(dist_diff)) 236 | while len(class_weight) < (ed-st): 237 | class_weight.append(0.0) 238 | class_weight = np.array(class_weight) 239 | # print(class_weight) 240 | # class_weight = class_weight/(np.max(class_weight)+0.00000001) 241 | class_weight = class_weight/(100.0) 242 | # print(class_weight) 243 | class_weight = np.exp(class_weight) 244 | 245 | 246 | self.class_weight_list.append(class_weight) 247 | self.neighInitEmbIndexList.append(posInInitTensor) 248 | print(len(self.ground_truth)) 249 | print('-------') 250 | 251 | 252 | 253 | def save(self): 254 | pass 255 | 256 | def load(self): 257 | pass 258 | 259 | def has_cache(self): 260 | pass 261 | 262 | 263 | 264 | 265 | import torch.autograd.profiler as profiler 266 | 267 | class Model(nn.Module): 268 | def __init__(self, hdim, outputNum): 269 | super(Model, self).__init__() 270 | 271 | self.hdim = hdim 272 | self.outputNum = outputNum 273 | 274 | self.RELU = torch.nn.ReLU(inplace=True) 275 | 276 | self.fc_init = nn.Linear(20, self.hdim, bias=True) 277 | self.conv1_for_g = GINConv(None, 'mean') 278 | self.conv2_for_g = GINConv(None, 'mean') 279 | # self.conv1_for_g = GINConv(nn.Linear(hdim, hdim, bias=True), 'mean') 280 | # self.conv2_for_g = GINConv(nn.Linear(hdim, hdim, bias=True), 'mean') 281 | 
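# GINConv(None, 'mean') uses the identity as its apply function, so each layer just
# combines a node's feature with the mean of its neighbours' features; the commented-out
# lines above show the variant with a learnable nn.Linear as the apply function.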
self.gnn_bn = torch.nn.BatchNorm1d(hdim) 282 | self.gnn_bn2 = torch.nn.BatchNorm1d(hdim) 283 | 284 | self.fc = nn.Linear(self.hdim*3, 256, bias=True) 285 | self.fc2 = nn.Linear(256, 256, bias=True) 286 | self.fc3 = nn.Linear(256, 256, bias=True) 287 | self.fc4 = nn.Linear(256, 1, bias=True) 288 | self.bn = torch.nn.BatchNorm1d(256) 289 | self.bn2 = torch.nn.BatchNorm1d(256) 290 | self.bn3 = torch.nn.BatchNorm1d(256) 291 | 292 | self.dp = torch.nn.Dropout(0.5) 293 | 294 | 295 | 296 | def forward(self, qList, pgNodeEmbList, neighEmbList, classWeightList): 297 | 298 | batch_size = len(pgNodeEmbList) 299 | number_of_outputs = self.outputNum # do not forget to set it for different dataset 300 | 301 | qList.ndata['h'] = self.fc_init(qList.ndata['h']) 302 | qList.ndata['h'] = self.RELU(self.gnn_bn(self.conv1_for_g(qList, qList.ndata['h']))) 303 | qList.ndata['h'] = self.RELU(self.gnn_bn2(self.conv2_for_g(qList, qList.ndata['h']))) 304 | qemb = dgl.mean_nodes(qList, 'h') 305 | 306 | a = torch.cat([qemb, pgNodeEmbList], 1) 307 | a = a.repeat(1, number_of_outputs).view(-1, self.hdim*2) 308 | 309 | b = torch.cat([a, neighEmbList], 1) 310 | 311 | H = self.RELU(self.bn(self.fc(b))) 312 | H2 = self.RELU(self.bn2(self.fc2(H))) 313 | H3 = self.RELU(self.bn3(self.fc3(H2))) 314 | preds = torch.sigmoid(self.fc4(H3)) 315 | 316 | preds = preds.view(batch_size, number_of_outputs) 317 | 318 | return preds, H3 319 | 320 | 321 | bceLoss = nn.BCELoss() 322 | mseLoss = nn.MSELoss() 323 | 324 | 325 | def weighted_binary_cross_entropy(output, target, weights=None): 326 | output = torch.clamp(output, min=1e-6, max=1-1e-6) 327 | 328 | if weights is not None: 329 | assert len(weights) == 2 330 | loss = weights[1] * (target * torch.log(output)) + \ 331 | weights[0] * ((1 - target) * torch.log(1 - output)) 332 | else: 333 | loss = target * torch.log(output) + (1 - target) * torch.log(1 - output) 334 | 335 | return torch.neg(torch.mean(loss)) 336 | 337 | 338 | def myloss(epoch, preds, gts, mask_of_1_list, mask_of_0_list, classWeightList): 339 | bce = weighted_binary_cross_entropy(preds, gts, [10.0, 1.0]) 340 | return bce 341 | 342 | 343 | def perf_measure(y_actual, y_hat): 344 | TP = 0 345 | FP = 0 346 | TN = 0 347 | FN = 0 348 | 349 | for i in range(len(y_hat)): 350 | if y_actual[i]==y_hat[i]==1: 351 | TP += 1 352 | if y_hat[i]==1 and y_actual[i]!=y_hat[i]: 353 | FP += 1 354 | if y_actual[i]==y_hat[i]==0: 355 | TN += 1 356 | if y_hat[i]==0 and y_actual[i]!=y_hat[i]: 357 | FN += 1 358 | 359 | return TP, FP, TN, FN 360 | 361 | 362 | from sklearn.metrics import confusion_matrix 363 | 364 | def myloss_for_test(preds, gts, thr): 365 | fpList = [] 366 | fnList = [] 367 | tpList = [] 368 | tnList = [] 369 | fprList = [] 370 | fnrList = [] 371 | tprList = [] 372 | for i in range(0, preds.shape[0]): 373 | pred = preds[i] 374 | gt = gts[i] 375 | pred = pred.view(-1,1).cpu().detach().numpy() 376 | gt = gt.view(-1,1).cpu().detach().numpy() 377 | y = (pred > thr) 378 | y = y.astype(int) 379 | 380 | TP, FP, TN, FN = perf_measure(gt, y) 381 | fpList.append(FP) 382 | fnList.append(FN) 383 | tpList.append(TP) 384 | tnList.append(TN) 385 | 386 | # Sensitivity, hit rate, recall, or true positive rate 387 | TPR = TP/(TP+FN+0.000001) 388 | # Specificity or true negative rate 389 | TNR = TN/(TN+FP+0.000001) 390 | # Precision or positive predictive value 391 | PPV = TP/(TP+FP+0.000001) 392 | # Negative predictive value 393 | NPV = TN/(TN+FN+0.000001) 394 | # Fall out or false positive rate 395 | FPR = FP/(FP+TN+0.000001) 396 | # False negative 
rate 397 | FNR = FN/(TP+FN+0.000001) 398 | # False discovery rate 399 | FDR = FP/(TP+FP+0.000001) 400 | 401 | # Overall accuracy 402 | ACC = (TP+TN)/(TP+FP+FN+TN) 403 | 404 | fprList.append(FPR) 405 | fnrList.append(FNR) 406 | tprList.append(TPR) 407 | 408 | fpList = np.array(fpList) 409 | fnList = np.array(fnList) 410 | tpList = np.array(tpList) 411 | tnList = np.array(tnList) 412 | fprList = np.array(fprList) 413 | fnrList = np.array(fnrList) 414 | tprList = np.array(tprList) 415 | 416 | return np.mean(fprList), np.mean(fnrList), np.mean(tprList), np.mean(fpList), np.mean(fnList), np.mean(tpList), np.mean(tnList) 417 | 418 | 419 | 420 | def read_initial_gemb(addr): 421 | gEmbMap = {} 422 | gfileList = os.listdir(addr) 423 | for gfile in gfileList: 424 | gID = gfile[1:-4] 425 | f = open(addr+"/"+gfile) 426 | lines = f.read() 427 | f.close() 428 | lines = lines.strip().split('\n') 429 | lines = lines[1:] 430 | nodeEmbList = [] 431 | for line in lines: 432 | tmp = line.strip().split(' ') 433 | tmp2 = [float(ele) for ele in tmp[1:]] 434 | nodeEmbList.append(tmp2) 435 | nodeEmbList = torch.tensor(nodeEmbList) 436 | gEmb = torch.mean(nodeEmbList, 0) 437 | gEmbMap[gID] = gEmb#.cuda() 438 | return gEmbMap 439 | 440 | 441 | 442 | def make_a_dglgraph(g): 443 | 444 | max_deg = 20 # largest node degree of q 445 | ones = torch.eye(max_deg) 446 | 447 | edges = [[],[]] 448 | for edge in g.edges(): 449 | end1 = edge[0] 450 | end2 = edge[1] 451 | edges[0].append(int(end1)) 452 | edges[1].append(int(end2)) 453 | edges[0].append(int(end2)) 454 | edges[1].append(int(end1)) 455 | dg = dgl.graph((torch.tensor(edges[0]), torch.tensor(edges[1]))) 456 | 457 | h0 = dg.in_degrees().view(1,-1).squeeze() 458 | h0 = ones.index_select(0, h0).float() 459 | dg.ndata['h'] = h0 460 | 461 | return dg#.to(torch.device('cuda:'+str(GPUID))) 462 | 463 | 464 | 465 | def readQ2GDistBook(fname, validNodeIDSet=None): 466 | ''' 467 | store distance from the query to a data graph 468 | ''' 469 | f = open(fname) 470 | lines = f.read() 471 | f.close() 472 | lines = lines.strip() 473 | lines = lines.split('\n') 474 | distBook = {} 475 | for line in lines: 476 | tmp = line.split(" ") 477 | if validNodeIDSet != None and tmp[1] in validNodeIDSet: 478 | if tmp[0] in distBook: 479 | distBook[tmp[0]][tmp[1]] = float(tmp[2]) 480 | else: 481 | distBook[tmp[0]] = {} 482 | distBook[tmp[0]][tmp[1]] = float(tmp[2]) 483 | return distBook 484 | 485 | 486 | def make_big_init_emb_tensor(gID2InitEmbMap, hdim): 487 | gid2posMap = {} 488 | embList = [] 489 | for k,v in gID2InitEmbMap.items(): 490 | embList.append(v) 491 | gid2posMap[k] = len(embList)-1 492 | embList.append(torch.zeros(hdim)) 493 | 494 | return torch.stack(embList), gid2posMap 495 | 496 | 497 | def get_exact_answer(topk, Q2GDistBook): 498 | answer = {} 499 | for query in Q2GDistBook.keys(): 500 | distToGList = list(Q2GDistBook[query].items()) 501 | distToGList.sort(key=lambda x: x[1]) 502 | dist_thr = -1 503 | if topk-1 < len(distToGList): 504 | dist_thr = distToGList[topk-1][1] 505 | else: 506 | dist_thr = 1000000.0 507 | a = [] 508 | for ele in distToGList: 509 | if ele[1] <= dist_thr: 510 | a.append(ele) 511 | else: 512 | break 513 | answer[query] = a 514 | return answer 515 | 516 | 517 | 518 | from functools import partial 519 | 520 | 521 | def my_collate_fn(samples, gInitEmbTensor): 522 | qEmbs, pgNodes, neighIndexLists, gts, mask_of_1_list, mask_of_0_list, classWeightList = map(list, zip(*samples)) 523 | 524 | neighIndexLists = torch.tensor(neighIndexLists).view(1,-1).squeeze() 525 
| neighInitEmbs = torch.index_select(gInitEmbTensor, 0, neighIndexLists) 526 | 527 | return dgl.batch(qEmbs), torch.stack(pgNodes), neighInitEmbs, torch.tensor(gts), torch.stack(mask_of_1_list), torch.stack(mask_of_0_list), torch.tensor(classWeightList) 528 | 529 | 530 | 531 | 532 | if __name__ == '__main__': 533 | 534 | D_of_pg = 80 # this is the max degree of pg 535 | model_prediction_num = 10 # 80/8=20 536 | prune_margin = 1 537 | 538 | GPUID=7 539 | st = GPUID*10 # 8 GPU cards 540 | ed = (GPUID+1)*10 541 | 542 | GPUID = GPUID % 4 543 | 544 | 545 | print("st ", st) 546 | print("ed ", ed) 547 | print('GPUID ', GPUID) 548 | 549 | if ed > 80: 550 | print('ERROR! ed > 80') 551 | exit(-1) 552 | 553 | torch.cuda.set_device(GPUID) 554 | 555 | 556 | ##################################################################################################### 557 | ### do the job as follows 558 | ##################################################################################################### 559 | 560 | 561 | emb_dim = 512 # dim of embedding 562 | gID2InitEmbMap = read_initial_gemb('data/AIDS/emb/aids'+str(emb_dim)) # it is pre-computed by node2vec on csr 563 | gInitEmbBigTensor, gID2InitTensorIndexMap = make_big_init_emb_tensor(gID2InitEmbMap, emb_dim) 564 | print('read g init emb done.') 565 | print("gInitEmbBigTensor.shape", gInitEmbBigTensor.shape) 566 | 567 | 568 | 569 | entire_dataset = read_and_split_to_individual_graph("data/AIDS/aids.txt", 0, 10000000) 570 | print('entire_dataset len: ', len(entire_dataset)) 571 | gid2dgmap = {} 572 | for g in entire_dataset: 573 | dg = make_a_dglgraph(g) 574 | dg = dgl.add_self_loop(dg) 575 | gid2dgmap[g.graph.get('id')] = dg 576 | 577 | 578 | train_data = GINDataset("aids", 'aids_train.perc20.data', gid2dgmap, gID2InitEmbMap, gID2InitTensorIndexMap, D_of_pg, prune_margin, isTrain=True) 579 | test_data = GINDataset("aids", 'aids_train.perc20.data', gid2dgmap, gID2InitEmbMap, gID2InitTensorIndexMap, D_of_pg, prune_margin, isTrain=False) 580 | 581 | 582 | dataloader = DataLoader( 583 | train_data, 584 | batch_size=1000, 585 | collate_fn=partial(my_collate_fn, gInitEmbTensor=gInitEmbBigTensor), 586 | num_workers=6, 587 | drop_last=False, 588 | shuffle=True) 589 | 590 | testdataloader = DataLoader( 591 | test_data, 592 | batch_size=1, 593 | collate_fn=partial(my_collate_fn, gInitEmbTensor=gInitEmbBigTensor), 594 | drop_last=False, 595 | shuffle=False) 596 | 597 | 598 | n_epochs = 1000 # epochs to train 599 | lr = 0.05 # learning rate 600 | l2norm = 0 # L2 norm coefficient 601 | 602 | 603 | # create model 604 | model = Model(hdim=emb_dim, outputNum=model_prediction_num) 605 | model.cuda() 606 | 607 | 608 | optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay = l2norm) 609 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 2, gamma = 0.95) 610 | # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 611 | 612 | torch.backends.cudnn.enabled = True 613 | torch.backends.cudnn.benchmark = True 614 | 615 | start_train_time = time.time() 616 | with torch.autograd.set_detect_anomaly(True): 617 | for epoch in range(n_epochs): 618 | old_lr = optimizer.param_groups[0]['lr'] 619 | print("cur lr: ", old_lr) 620 | if (epoch+1) % 10 == 0: 621 | torch.save(model.state_dict(), "aids.D"+str(D_of_pg)+".perc20_model_save/prune_ged"+str(st)+"_"+str(ed)+'.e'+str(epoch)+".pkl") 622 | model.train() 623 | batch_count = 0 624 | for qids, pgNodes, neighEmbList, gts, mask_of_1_list, mask_of_0_list, classWeightList in dataloader: 
615 | start_train_time = time.time()
616 | with torch.autograd.set_detect_anomaly(True):
617 | for epoch in range(n_epochs):
618 | old_lr = optimizer.param_groups[0]['lr']
619 | print("cur lr: ", old_lr)
620 | if (epoch+1) % 10 == 0:
621 | torch.save(model.state_dict(), "aids.D"+str(D_of_pg)+".perc20_model_save/prune_ged"+str(st)+"_"+str(ed)+'.e'+str(epoch)+".pkl")
622 | model.train()
623 | batch_count = 0
624 | for qids, pgNodes, neighEmbList, gts, mask_of_1_list, mask_of_0_list, classWeightList in dataloader:
625 | preds, H3 = model(qids.to(torch.device('cuda:'+str(GPUID))), pgNodes.cuda(), neighEmbList.cuda(), classWeightList.cuda())
626 | loss = myloss(epoch, preds, gts.float().cuda(), mask_of_1_list.cuda(), mask_of_0_list.cuda(), classWeightList.cuda())
627 | if batch_count % 10 == 0:
628 | print('='*40+" epoch ", epoch, "batch ", batch_count)
629 | print('loss ', loss.item())
630 | print('st ', st, ' ed ', ed, ' GPUID ', GPUID)
631 | optimizer.zero_grad()
632 | loss.backward()
633 | optimizer.step()
634 | torch.cuda.empty_cache()
635 | batch_count += 1
636 | 
637 | 
638 | # do a test every 10 epochs
639 | if epoch % 10 == 0:
640 | print("do test ....")
641 | model.eval()
642 | with torch.no_grad():
643 | # do test
644 | for m in model.modules():
645 | if isinstance(m, nn.BatchNorm1d):
646 | m.track_running_stats=False
647 | print("start test ................")
648 | print("epoch", epoch)
649 | # for qids, pgNodes, gts in testdataloader:
650 | for qids, pgNodes, neighEmbList, gts, index_of_1_list, index_of_0_list, classWeightList in testdataloader: # the collate fn yields 7 fields
651 | preds, _ = model(qids.to(torch.device('cuda:'+str(GPUID))), pgNodes.cuda(), neighEmbList.cuda(), classWeightList.cuda()) # same call convention as in training
652 | fpr, fnr, tpr, FP, FN, TP, TN = myloss_for_test(preds, gts, 0.5)
653 | print('fpr ', fpr, 'fnr ', fnr, 'tpr ', tpr)
654 | print('preds', preds)
655 | print('gts', gts)
656 | print('index_of_1_list', index_of_1_list)
657 | print('index_of_0_list', index_of_0_list)
658 | pred_1_probs = torch.masked_select(preds, index_of_1_list)
659 | pred_0_probs = torch.masked_select(preds, index_of_0_list)
660 | print('pred_1_probs', pred_1_probs)
661 | print('pred_0_probs', pred_0_probs)
662 | fnList = (pred_1_probs < 0.5).int().view(1,-1).squeeze(dim=0).cpu().detach().numpy().tolist()
663 | fpList = (pred_0_probs > 0.5).int().view(1,-1).squeeze(dim=0).cpu().detach().numpy().tolist()
664 | print('fnList', fnList)
665 | print('fpList', fpList)
666 | fnRatio = sum(fnList)/(len(fnList)+0.0000001)
667 | fpRatio = sum(fpList)/(len(fpList)+0.0000001)
668 | print(' fpR ', fpRatio, 'fnR ', fnRatio)
669 | print("+++++++"*5, 'test finish')
670 | 
671 | 
672 | # do a test after each epoch
673 | print("do test ....")
674 | model.eval()
675 | with torch.no_grad():
676 | # do test
677 | for m in model.modules():
678 | if isinstance(m, nn.BatchNorm1d):
679 | m.track_running_stats=False
680 | 
681 | 
682 | print("start test ................")
683 | print("epoch", epoch)
684 | for qids, pgNodes, neighEmbList, gts, index_of_1_list, index_of_0_list, classWeightList in testdataloader: # the collate fn yields 7 fields
685 | preds, _ = model(qids.to(torch.device('cuda:'+str(GPUID))), pgNodes.cuda(), neighEmbList.cuda(), classWeightList.cuda())
686 | fpr, fnr, tpr, FP, FN, TP, TN = myloss_for_test(preds, gts, 0.5)
687 | print('fpr ', fpr, 'fnr ', fnr, 'tpr ', tpr)
688 | print('preds', preds)
689 | print('gts', gts)
690 | print('index_of_1_list', index_of_1_list)
691 | print('index_of_0_list', index_of_0_list)
692 | pred_1_probs = torch.masked_select(preds, index_of_1_list)
693 | pred_0_probs = torch.masked_select(preds, index_of_0_list)
694 | print('pred_1_probs', pred_1_probs)
695 | print('pred_0_probs', pred_0_probs)
696 | fnList = (pred_1_probs < 0.5).int().view(1,-1).squeeze(dim=0).cpu().detach().numpy().tolist()
697 | fpList = (pred_0_probs > 0.5).int().view(1,-1).squeeze(dim=0).cpu().detach().numpy().tolist()
698 | print('fnList', fnList)
699 | print('fpList', fpList)
700 | fnRatio = sum(fnList)/(len(fnList)+0.0000001)
701 | fpRatio = sum(fpList)/(len(fpList)+0.0000001)
702 | print(' fpR ', fpRatio, 'fnR ', fnRatio)
703 | print("+++++++"*5, 'test finish')
704 | 
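# Reloading one of the checkpoints written above -- a sketch; the epoch number is
# illustrative, and the directory matches the torch.save call in the training loop:
# m = Model(hdim=emb_dim, outputNum=model_prediction_num)
# m.load_state_dict(torch.load("aids.D80.perc20_model_save/prune_ged70_80.e319.pkl"))
# m.cuda().eval()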
705 | end_train_time = time.time()
706 | print('train time (s) ', (end_train_time - start_train_time))
707 | 
708 | 
709 | print("st ", st)
710 | print("ed ", ed)
711 | print('GPUID ', GPUID)
712 | 
713 | 
714 | torch.save(model.state_dict(), "aids.D"+str(D_of_pg)+".perc20_model_save/prune_ged"+str(st)+"_"+str(ed)+".pkl") # final checkpoint, kept in the same directory as the periodic ones
715 | 
--------------------------------------------------------------------------------
/routing_with_neigh_pruning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import networkx as nx
3 | import random
4 | import time
5 | import subprocess
6 | import heapq
7 | import logging
8 | import os
9 | import numpy as np
10 | import jpype
11 | # from Properties import *
12 | 
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from dgl import DGLGraph
17 | import dgl.function as fn
18 | from functools import partial
19 | import dgl
20 | from torch.utils.data import DataLoader
24 | from dgl.nn.pytorch.conv import GINConv, SGConv, SAGEConv
27 | from dgl.data import DGLDataset
28 | from sklearn.cluster import KMeans
30 | 
31 | 
32 | logging.basicConfig(level=logging.ERROR)
33 | DEBUG = False
34 | 
35 | 
36 | jarpath = os.path.join(os.path.abspath('.'), 'graph-matching-toolkit/graph-matching-toolkit.jar')
37 | jpype.startJVM(jpype.getDefaultJVMPath(), "-ea", "-Djava.class.path=%s" % jarpath)
38 | javaClass = jpype.JClass('algorithms.GraphMatching')
39 | 
40 | 
41 | GPUID = 2
42 | 
43 | modelMap = {}
44 | 
45 | 
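# The JVM bridge above loads the graph-matching-toolkit; getDist() below uses it as
# the approximate-GED fallback. Usage as in this script (paths follow its data/AIDS layout):
# est = javaClass.runApp("data/AIDS/g0.txt", "data/AIDS/g1.txt")
# est = est * 2.0  # the script doubles the toolkit's score before using it as a GED estimate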
46 | def read_and_split_to_individual_graph(fname, gsizeNoLessThan=15, gsizeLessThan=30, writeTo=None, prefix=None, fileformat=None, removeEdgeLabel=True, removeNodeLabel=True, graph_num=100000000):
47 | '''
48 | aids graphs are in one file. write each graph into a single file
49 | :param fname: the file storing all graphs in the format as follows:
50 | t # 1
51 | v 1
52 | v 2
53 | e 1 2
54 | :param gsizeNoLessThan, gsizeLessThan: keep only graphs with gsizeNoLessThan <= #nodes < gsizeLessThan; some graphs are too large to compute GED, so skip them at the current stage
55 | :param writeTo: path to write
56 | :param prefix: give a prefix to each file, e.g., "g" or "q4", etc
57 | :param fileformat: None means following the original format of aids
58 | :param removeEdgeLabel: all edges are given label "0"
59 | :param removeNodeLabel: all nodes are given label "0"
60 | '''
61 | if writeTo is not None:
62 | if prefix is None:
63 | print("You want to write each graph into a single file.")
64 | print("You need to give the prefix of the filename to store each graph. For example, g, q4, q8")
65 | exit(-1)
66 | else:
67 | if writeTo[-1] == '/':
68 | writeTo = writeTo+prefix
69 | else:
70 | writeTo = writeTo+"/"+prefix
71 | if fileformat is None:
72 | print("please specify fileformat: aids, gexf")
73 | exit(-1)
74 | 
75 | 
76 | f = open(fname)
77 | lines = f.read()
78 | f.close()
79 | 
80 | lines2 = lines.split("t # ")
81 | 
82 | lines3 = [g.strip().split("\n") for g in lines2]
83 | 
84 | glist = []
85 | for idx in range(1, len(lines3)):
86 | cur_g = lines3[idx]
87 | 
88 | gid_line = cur_g[0].strip().split(' ')
89 | gid = gid_line[0]
90 | if len(gid_line) in (4, 6): # the header line also carries a graph label in field 3
91 | glabel = gid_line[3]
92 | g = nx.Graph(id = gid, label = glabel)
96 | else:
97 | g = nx.Graph(id = gid)
98 | 
99 | 
100 | for idy in range(1, len(cur_g)):
101 | tmp = cur_g[idy].split(' ')
102 | if tmp[0] == 'v':
103 | if removeNodeLabel == False:
104 | g.add_node(tmp[1], att=tmp[2])
105 | else:
106 | g.add_node(tmp[1], att="0")
107 | if tmp[0] == 'e':
108 | if removeEdgeLabel == False:
109 | g.add_edge(tmp[1], tmp[2], att=tmp[3])
110 | else:
111 | g.add_edge(tmp[1], tmp[2], att="0")
112 | 
113 | 
114 | if g.number_of_nodes() >= gsizeNoLessThan and g.number_of_nodes() < gsizeLessThan:
115 | if writeTo is not None:
116 | if fileformat == "aids":
117 | f2 = open(writeTo+g.graph.get('id')+".txt", "w")
118 | f2.write("t # "+g.graph.get('id')+"\n")
119 | 
120 | if removeNodeLabel:
121 | for i in range(0, len(g.nodes())):
122 | f2.write("v "+str(i)+" 0\n")
123 | else:
124 | for i in range(0, len(g.nodes())):
125 | f2.write("v "+str(i)+" "+g.nodes[str(i)].get("att")+"\n")
126 | 
127 | if removeEdgeLabel:
128 | for e in g.edges():
129 | f2.write("e "+e[0]+" "+e[1]+" 0\n")
130 | else:
131 | for e in g.edges():
132 | f2.write("e "+e[0]+" "+e[1]+" "+g[e[0]][e[1]].get("att")+"\n")
133 | f2.close()
134 | if fileformat == "gexf":
135 | nx.write_gexf(g, writeTo+g.graph.get("id")+".gexf")
136 | 
137 | 
138 | glist.append(g)
139 | if len(glist) > graph_num:
140 | return glist
141 | 
142 | return glist
143 | 
144 | 
145 | 
146 | 
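# Hedged usage sketch: this is how the per-graph files that getDist() later expects
# (data/AIDS/g<ID>.txt) can be produced from the one-file AIDS dump. The size window
# 0..10000000 keeps every graph; the values are illustrative:
# glist = read_and_split_to_individual_graph("data/AIDS/aids.txt", 0, 10000000,
#                                            writeTo="data/AIDS", prefix="g", fileformat="aids")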
147 | def read_PG(fname):
148 | f = open(fname)
149 | lines = f.read()
150 | f.close()
151 | 
152 | lines2 = lines.split("t # ")
153 | 
154 | lines3 = [g.strip().split("\n") for g in lines2]
155 | 
156 | glist = []
157 | for idx in range(1, len(lines3)):
158 | cur_g = lines3[idx]
159 | 
160 | gid_line = cur_g[0].strip().split(' ')
161 | gid = gid_line[0]
162 | g = nx.Graph(id = gid)
163 | 
164 | for idy in range(1, len(cur_g)):
165 | tmp = cur_g[idy].split(' ')
166 | if tmp[0] == 'v':
167 | g.add_node(tmp[1], att="0")
168 | if tmp[0] == 'e':
169 | g.add_edge(tmp[1], tmp[2], ged=float(tmp[3])) # edge weight = GED between the two endpoint graphs
170 | 
171 | glist.append(g)
172 | 
173 | return glist
174 | 
175 | 
176 | 
177 | 
178 | def readG2GDistBook(fname):
179 | '''
180 | store distance between two graphs, from G1 to G2 and from G2 to G1
181 | '''
182 | f = open(fname)
183 | lines = f.read()
184 | f.close()
185 | lines = lines.strip()
186 | lines = lines.split('\n')
187 | distBook = {}
188 | for line in lines:
189 | tmp = line.split(" ")
190 | 
191 | if tmp[0] in distBook:
192 | distBook[tmp[0]][tmp[1]] = float(tmp[2])
193 | else:
194 | distBook[tmp[0]] = {}
195 | distBook[tmp[0]][tmp[1]] = float(tmp[2])
196 | if tmp[1] in distBook:
197 | distBook[tmp[1]][tmp[0]] = float(tmp[2])
198 | else:
199 | distBook[tmp[1]] = {}
200 | distBook[tmp[1]][tmp[0]] = float(tmp[2])
201 | 
202 | # dist from self to self is zero
203 | distBook[tmp[0]][tmp[0]] = 0
204 | distBook[tmp[1]][tmp[1]] = 0
205 | 
206 | return distBook
207 | 
208 | 
209 | 
210 | def readQ2GDistBook(fname, validNodeIDSet=None):
211 | '''
212 | store the distance from a query to a data graph; each line of fname is "qid gid dist"
213 | '''
214 | f = open(fname)
215 | lines = f.read()
216 | f.close()
217 | lines = lines.strip()
218 | lines = lines.split('\n')
219 | distBook = {}
220 | for line in lines:
221 | tmp = line.split(" ")
222 | if validNodeIDSet is None or tmp[1] in validNodeIDSet: # no filter set means keep every pair
223 | if tmp[0] in distBook:
224 | distBook[tmp[0]][tmp[1]] = float(tmp[2])
225 | else:
226 | distBook[tmp[0]] = {}
227 | distBook[tmp[0]][tmp[1]] = float(tmp[2])
228 | return distBook
229 | 
230 | 
231 | def getSubG2GDistBookGivenQueries(G2GDistBookFileName, qGList):
232 | # we only need the distance from qG in qGList to others
233 | f = open(G2GDistBookFileName+".sub4q.txt", 'w')
234 | allDistBook = readG2GDistBook(G2GDistBookFileName)
235 | for q in qGList:
236 | distList_of_q = allDistBook[q.graph.get('id')]
237 | for key,value in distList_of_q.items():
238 | f.write(q.graph.get('id')+' '+key+' '+str(value)+"\n")
239 | f.close()
240 | 
241 | 
242 | 
243 | 
244 | 
245 | def make_a_dglgraph(g):
246 | 
247 | max_deg = 20 # largest node degree of q; a degree >= 20 would make the index_select below fail
248 | ones = torch.eye(max_deg)
249 | 
250 | edges = [[],[]]
251 | for edge in g.edges():
252 | end1 = edge[0]
253 | end2 = edge[1]
254 | edges[0].append(int(end1))
255 | edges[1].append(int(end2))
256 | edges[0].append(int(end2))
257 | edges[1].append(int(end1))
258 | dg = dgl.graph((torch.tensor(edges[0]), torch.tensor(edges[1])))
259 | 
260 | h0 = dg.in_degrees().view(1,-1).squeeze()
261 | h0 = ones.index_select(0, h0).float() # one-hot encode each node's in-degree as its initial feature
262 | dg.ndata['h'] = h0
263 | 
264 | return dg.to(torch.device('cuda:'+str(GPUID)))
265 | 
266 | 
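# Shape sketch of the one-hot degree features above (toy path graph 0-1-2-3):
# with edges added in both directions, dg.in_degrees() is [1, 2, 2, 1], so
# index_select picks rows 1, 2, 2, 1 of torch.eye(20) -> a (4, 20) feature matrix 'h'.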
267 | 
268 | 
269 | def getProdGraphV4(g1, g2):
270 | small_g = None
271 | large_g = None
272 | if g1.number_of_nodes() <= g2.number_of_nodes():
273 | small_g = g1
274 | large_g = g2
275 | else:
276 | small_g = g2
277 | large_g = g1
278 | 
279 | smallG_nodeList = list(small_g.nodes())
280 | largeG_nodeList = list(large_g.nodes())
281 | smallG_nodeList.extend([str(-ele) for ele in range(1, len(largeG_nodeList)-len(smallG_nodeList)+1)]) # pad with dummy nodes (negative ids) so both lists have equal length
282 | 
283 | # sampled_smallG_nodes = random.sample(smallG_nodeList, max([5, int(len(smallG_nodeList)/10)]))
284 | # sampled_largeG_nodes = random.sample(largeG_nodeList, max([5, int(len(largeG_nodeList)/10)]))
285 | sampled_smallG_nodes = random.sample(smallG_nodeList, max([min([len(smallG_nodeList), 10]), int(len(smallG_nodeList)/10)]))
286 | sampled_largeG_nodes = random.sample(largeG_nodeList, max([min([len(largeG_nodeList), 10]), int(len(largeG_nodeList)/10)]))
287 | 
288 | prod_g = nx.Graph()
289 | 
290 | for n1 in sampled_smallG_nodes:
291 | if int(n1) >= 0:
292 | n1_neighs = list(small_g[n1])
293 | else:
294 | n1_neighs = [] # dummy nodes have no neighbors
295 | for n2 in sampled_largeG_nodes:
296 | n2_nonNeighs = random.sample(largeG_nodeList, max([3, int(len(largeG_nodeList)/10)]))
297 | for n1_neigh in n1_neighs:
298 | for n2_nonNeigh in n2_nonNeighs:
299 | if large_g.has_edge(n2,n2_nonNeigh) == False:
300 | prod_g.add_edge(n1+"|"+n2, n1_neigh+"|"+n2_nonNeigh)
301 | 
302 | 
303 | for n2 in sampled_largeG_nodes:
304 | n2_neighs = list(large_g[n2])
305 | for n1 in sampled_smallG_nodes:
306 | n1_nonNeighs = random.sample(smallG_nodeList, max([3, int(len(smallG_nodeList)/10)]))
307 | for n2_neigh in n2_neighs:
308 | for n1_nonNeigh in n1_nonNeighs:
309 | if small_g.has_edge(n1,n1_nonNeigh) == False:
310 | prod_g.add_edge(n1+"|"+n2, n1_nonNeigh+"|"+n2_neigh)
311 | 
312 | 
313 | if prod_g.number_of_edges() == 0:
314 | print(smallG_nodeList)
315 | print(largeG_nodeList)
316 | exit(-1)
317 | 
318 | # print("prod_g nodes: ", prod_g.number_of_nodes())
319 | # print("prod_g edges: ", prod_g.number_of_edges())
320 | largest_deg = -1
321 | for node in prod_g:
322 | if len(prod_g[node]) > largest_deg:
323 | largest_deg = len(prod_g[node])
324 | 
325 | return prod_g, largest_deg
326 | 
327 | 
328 | 
329 | estGEDBuffer = {}
330 | 
331 | def getDist(q, g, distBook, isQuery=False):
332 | '''
333 | q: the query graph
334 | g: the data graph
335 | distBook: pre-computed exact GED
336 | isQuery: if True, compute the distance on the fly (exact GED with a time limit, then the
337 | Java estimator as fallback); if False, look up distBook and fall back to the estimator.
338 | Estimates are memoized symmetrically in the global estGEDBuffer.
339 | '''
340 | qid = q.graph.get("id")
341 | gid = g.graph.get("id")
342 | # print("qid: ", qid)
343 | # print("gid: ", gid)
344 | 
345 | if qid == gid:
346 | return 0
347 | 
348 | 
349 | if isQuery is False:
350 | if qid not in distBook or gid not in distBook[qid]:
351 | 
352 | if qid in estGEDBuffer:
353 | if gid in estGEDBuffer[qid]:
354 | return estGEDBuffer[qid][gid]
355 | 
356 | # estimate GED
357 | distance = javaClass.runApp("data/AIDS/g"+qid+".txt", "data/AIDS/g"+gid+".txt")
358 | distance = distance * 2.0
359 | 
360 | if qid in estGEDBuffer:
361 | if gid not in estGEDBuffer[qid]:
362 | estGEDBuffer[qid][gid] = distance
363 | else:
364 | estGEDBuffer[qid] = {}
365 | estGEDBuffer[qid][gid] = distance
366 | 
367 | if gid in estGEDBuffer:
368 | if qid not in estGEDBuffer[gid]:
369 | estGEDBuffer[gid][qid] = distance
370 | else:
371 | estGEDBuffer[gid] = {}
372 | estGEDBuffer[gid][qid] = distance
373 | 
374 | print(qid, gid, "no exact GED, estimate it", distance)
375 | 
376 | 
377 | return distance
378 | else:
379 | # have pre-computed the exact GED
380 | return distBook[qid][gid]
381 | else:
382 | # compute distance on-the-fly for runtime testing
383 | 
384 | if qid in estGEDBuffer:
385 | if gid in estGEDBuffer[qid]:
386 | return estGEDBuffer[qid][gid]
387 | 
388 | distance = getExactDist("data/AIDS/g"+qid+".txt", "data/AIDS/g"+gid+".txt", 10000000, 10) # 10-second limit for exact GED computation
389 | if distance < 0:
390 | distance = javaClass.runApp("data/AIDS/g"+qid+".txt", "data/AIDS/g"+gid+".txt")
391 | distance = distance * 2.0
392 | 
393 | if qid in estGEDBuffer:
394 | if gid not in estGEDBuffer[qid]:
395 | estGEDBuffer[qid][gid] = distance
396 | else:
397 | estGEDBuffer[qid] = {}
398 | estGEDBuffer[qid][gid] = distance
399 | 
400 | if gid in estGEDBuffer:
401 | if qid not in estGEDBuffer[gid]:
402 | estGEDBuffer[gid][qid] = distance
403 | else:
404 | estGEDBuffer[gid] = {}
405 | estGEDBuffer[gid][qid] = distance
406 | 
407 | return distance
408 | 
409 | 
410 | def perf_measure(y_actual, y_hat):
411 | TP = 0
412 | FP = 0
413 | TN = 0
414 | FN = 0
415 | 
416 | for i in range(len(y_hat)):
417 | if y_actual[i]==y_hat[i]==1:
418 | TP += 1
419 | if y_hat[i]==1 and y_actual[i]!=y_hat[i]:
420 | FP += 1
421 | if y_actual[i]==y_hat[i]==0:
422 | TN += 1
423 | if y_hat[i]==0 and y_actual[i]!=y_hat[i]:
424 | FN += 1
425 | 
426 | return TP, FP, TN, FN
427 | 
428 | 
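# Worked micro-example of perf_measure with binary labels:
# y_actual = [1, 0, 1, 0], y_hat = [1, 1, 0, 0]  ->  TP=1, FP=1, FN=1, TN=1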
429 | def myloss_for_test(preds, gts):
430 | 
431 | 
432 | TP, FP, TN, FN = perf_measure(gts, preds)
433 | 
434 | # Sensitivity, hit rate, recall, or true positive rate
435 | TPR = TP/(TP+FN+0.000001)
436 | # Specificity or true negative rate
437 | TNR = TN/(TN+FP+0.000001)
438 | # Precision or positive predictive value
439 | PPV = TP/(TP+FP+0.000001)
440 | # Negative predictive value
441 | NPV = TN/(TN+FN+0.000001)
442 | # Fall out or false positive rate
443 | FPR = FP/(FP+TN+0.000001)
444 | # False negative rate
445 | FNR = FN/(TP+FN+0.000001)
446 | # False discovery rate
447 | FDR = FP/(TP+FP+0.000001)
448 | 
449 | # Overall accuracy
450 | ACC = (TP+TN)/(TP+FP+FN+TN)
451 | 
452 | 
453 | return FN, FP, FNR, FPR
454 | 
455 | 
456 | def greedy_search(proximityGraph, queryGraph, k, ep, ef, distBook, gid2gmap):
457 | '''
458 | k: return top-k from ef candidates
459 | ep: list of start nodes
460 | ef: number of candidates
461 | '''
462 | cand, stat = search_layer(proximityGraph, queryGraph, ep, ef, distBook, gid2gmap)
463 | while len(cand) > k:
464 | heapq.heappop(cand)
465 | 
466 | return cand, stat
467 | 
468 | 
469 | exact_ans = None
470 | 
471 | def search_layer(proxG, q, ep, ef, G2GDistBook, gid2gmap=None): # gid2gmap defaults to None: insert() calls this without one
472 | '''
473 | At each hop, use the GED ranking neural network to rank the neighbor nodes
474 | param: proxG: current proximity graph. each node in proxG is a networkx graph
475 | param: q: query, which is a networkx graph
476 | param: ep: entry points for the search in proxG. should be a list
477 | param: ef: number of results
478 | '''
479 | bar = 1000
480 | 
481 | aaa = get_topkAll_in_a_list(200, exact_ans[qid]) if exact_ans is not None else [] # instrumentation only; exact_ans and qid are module-level globals set by the query loop below
482 | aaa = set([ele[0] for ele in aaa])
483 | 
484 | if len(ep) == 0:
485 | logging.error("Error: no entry point")
486 | exit(-1)
487 | 
488 | 
489 | DCS = 0.0 # count of distance computations
490 | hop_count = 0.0 # the number of hops
491 | visited = set()
492 | C = [] # search frontier, stored in a min-heap
493 | idx_C = 0 # tie-breaker index required by the min-heap
494 | W = [] # results found so far: a max-heap, realized by storing negative distances in a min-heap
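# The two heaps follow the standard HNSW bookkeeping trick: C pops its closest
# candidate first, while W stores (-dist, idx, graph) so heappop evicts the
# current furthest result. E.g.:
# heapq.heappush(W, (-3.0, 0, gA)); heapq.heappush(W, (-1.0, 1, gB))
# heapq.heappop(W)  # -> (-3.0, 0, gA): the furthest of the two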
495 | idx_W = 0 # required by max-heap 496 | 497 | hop_enter_neighorhood = None 498 | time_used_before_enter_neighorhood = 0 499 | time_used_after_enter_neighorhood = 0 500 | DCS_before_enter_neighorhood = 0 501 | 502 | model_predicted_nodes = set() 503 | 504 | for ele in ep: 505 | dist = getDist(q, ele, G2GDistBook) 506 | DCS = DCS + 1 507 | 508 | if hop_enter_neighorhood is None: 509 | DCS_before_enter_neighorhood += 1 510 | 511 | heapq.heappush(C, (dist, idx_C, ele)) 512 | idx_C = idx_C + 1 513 | heapq.heappush(W, (-dist, idx_W, ele)) 514 | idx_W = idx_W + 1 515 | visited.add(ele.graph.get("id")) 516 | 517 | while len(C) > 0: 518 | if DEBUG: 519 | print("-"*60) 520 | print("C:") 521 | for ele in C: 522 | print(str(ele[0])+" g"+ele[2].graph.get("id")) 523 | print("+++++++"*3) 524 | print("W:") 525 | for ele in W: 526 | print(str(ele[0])+" g"+ele[2].graph.get("id")) 527 | print("+++++++"*3) 528 | print("visited:") 529 | print(visited) 530 | print("+++++++"*3) 531 | 532 | c = heapq.heappop(C) # extract the nearest in C to q 533 | f = min(W) # get the furthest from W to q 534 | c_dist = getDist(q, c[2], G2GDistBook) 535 | f_dist = getDist(q, f[2], G2GDistBook) 536 | # print([c[0], c[2].graph.get('id')]) 537 | 538 | logging.debug("c is g"+c[2].graph.get("id")+ " dist=" + str(c_dist)) 539 | logging.debug("f is g"+f[2].graph.get("id")+ " dist=" + str(f_dist)) 540 | 541 | 542 | if c_dist > f_dist: 543 | break 544 | 545 | 546 | # list_of_W = list(W) 547 | # list_of_W.sort(key = lambda x: -x[0]) 548 | 549 | # if len(list_of_W) < bar: 550 | # if c_dist > f_dist: 551 | # break 552 | # else: 553 | # if c_dist > -list_of_W[bar-1][0]: 554 | # break 555 | 556 | 557 | 558 | 559 | logging.debug("ok to go on ...") 560 | hop_count = hop_count + 1 561 | 562 | 563 | if c[2].graph.get('id') in aaa: 564 | if hop_enter_neighorhood is None: 565 | hop_enter_neighorhood = hop_count 566 | 567 | 568 | neighbors_of_c = list(proxG[c[2]]) 569 | logging.debug("neighbors of c: " + str(["g"+ele.graph.get("id") for ele in neighbors_of_c])) 570 | 571 | 572 | 573 | 574 | if hop_count >= 0: 575 | # print('c_dist', c_dist) 576 | neighIDs = [neigh.graph.get('id') for neigh in neighbors_of_c] 577 | neighIDs.sort() 578 | # print(neighIDs) 579 | # neighDists = [getDist(q, gid2gmap[neighID], G2GDistBook) for neighID in neighIDs] 580 | # print(neighDists) 581 | 582 | for neighID in neighIDs: 583 | model_predicted_nodes.add(neighID) 584 | 585 | preds = [] 586 | for i in range(0, 8): 587 | curGEDPruneModel = modelMap[i] 588 | curGEDPruneModel.eval() 589 | with torch.no_grad(): 590 | for m in curGEDPruneModel.modules(): 591 | if isinstance(m, nn.BatchNorm1d): 592 | m.track_running_stats=False 593 | 594 | subNeighIDs = neighIDs[i*10 : (i+1)*10] 595 | preds.append(curGEDPruneModel([q.graph.get('id')], [c[2].graph.get('id')], subNeighIDs).view(1,-1).squeeze()) 596 | 597 | preds = torch.stack(preds) 598 | # print(preds) 599 | prune_decision = (preds > 0.5).int().view(1,-1).squeeze().cpu().detach().numpy() 600 | 601 | preds = preds.view(1,-1).squeeze().cpu().detach().numpy().tolist() 602 | 603 | neighID_and_preds = [] 604 | for idx in range(0, len(neighIDs)): 605 | neighID_and_preds.append( (neighIDs[idx], preds[idx]) ) 606 | 607 | neighID_and_preds.sort(key=lambda x: x[1]) 608 | 609 | topPercNeighIDs = set() 610 | for ele in neighID_and_preds[ 0 : int(len(neighID_and_preds)*0.2) ]: 611 | topPercNeighIDs.add(ele[0]) 612 | 613 | 614 | else: 615 | curGEDPruneModel = None 616 | prune_decision = None 617 | 618 | 619 | 620 | 
neighbors_of_c.sort(key=lambda x: x.graph.get('id'))
621 | neighIDs = [neigh.graph.get('id') for neigh in neighbors_of_c]
622 | 
623 | 
624 | 
625 | 
626 | for iii in range(0, len(neighbors_of_c)):
627 | neigh = neighbors_of_c[iii]
628 | 
629 | if prune_decision is not None:
630 | if neigh.graph.get('id') not in topPercNeighIDs: # pruned: the model ranks this neighbor outside the 20% with the lowest pruning scores
631 | continue
632 | 
633 | logging.debug("cur_neigh: g"+neigh.graph.get("id"))
634 | # neigh_dist = getDist(q, neigh, G2GDistBook) # redundant: computed again below, after the visited check
635 | 
636 | 
637 | if neigh.graph.get("id") not in visited:
638 | logging.debug("= not visited")
639 | 
640 | visited.add(neigh.graph.get("id"))
641 | f = min(W) # entry with the most negative key, i.e. the current furthest result
642 | f_dist = getDist(q, f[2], G2GDistBook)
643 | logging.debug("= f is g"+f[2].graph.get("id") + " dist=" + str(f_dist))
644 | neigh_dist = getDist(q, neigh, G2GDistBook)
645 | DCS = DCS + 1
646 | logging.debug("= cur_neigh g"+neigh.graph.get("id") + " dist=" + str(neigh_dist))
647 | 
648 | 
649 | # for PG construction
650 | if neigh_dist < f_dist or len(W) < ef:
651 | if neigh_dist < f_dist:
652 | logging.debug("= neigh_dist < f_dist")
653 | if len(W) < ef:
654 | logging.debug("= len(W) < ef")
655 | heapq.heappush(C, (neigh_dist, idx_C, neigh))
656 | idx_C = idx_C + 1
657 | heapq.heappush(W, (-neigh_dist, idx_W, neigh))
658 | idx_W = idx_W + 1
659 | 
660 | logging.debug("= push " + "g"+neigh.graph.get("id") + " to C and W")
661 | if len(W) > ef:
662 | deleted = heapq.heappop(W)
663 | logging.debug("= W's size "+ str(len(W)) + " is too large, delete " + "g"+deleted[2].graph.get("id"))
664 | 
665 | logging.debug("********")
666 | 
667 | 
668 | stat = {}
669 | stat["hop_count"] = hop_count
670 | stat["DCS"] = DCS
671 | stat['visited'] = visited
672 | stat['model_pred_count'] = len(model_predicted_nodes)
673 | print("DCS", DCS)
674 | print("model_pred_count", len(model_predicted_nodes))
675 | print("hop_enter_neighorhood", hop_enter_neighorhood)
676 | print("time_used_before_enter_neighorhood", time_used_before_enter_neighorhood)
677 | print("time_used_after_enter_neighorhood", time_used_after_enter_neighorhood)
678 | print("DCS_before_enter_neighorhood", DCS_before_enter_neighorhood)
679 | 
680 | return W, stat
681 | 
682 | 
683 | 
684 | 
685 | def select_neighbors_simple(q, cand, M):
686 | deleted = []
687 | while len(cand) > M:
688 | x = heapq.heappop(cand) # pops the most negative key, i.e. the furthest candidate
689 | deleted.append(x[2])
690 | return [ele[2] for ele in cand], deleted
691 | 
692 | 
693 | 
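# select_neighbors_simple pops from the (-dist, idx, graph) heap, so the furthest
# candidates are deleted first and the M closest survive. E.g. with M=2 and
# candidates at distances {gA: 3.0, gB: 1.0, gC: 2.0}, gA is deleted and gB, gC kept.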
694 | def insert(proxG, q, ep, M, maxDeg0, efConst, G2GDistBook):
695 | cand, _ = search_layer(proxG, q, ep, efConst, G2GDistBook) # no gid2gmap needed during construction
696 | neighs, _ = select_neighbors_simple(q, cand, M)
697 | 
698 | logging.debug("insert g"+q.graph.get("id"))
699 | logging.debug("insert edges: "+str([ele.graph.get("id") for ele in neighs]))
700 | # insert q to proxG
701 | proxG.add_node(q)
702 | # insert edges for q
703 | for neigh in neighs:
704 | proxG.add_edge(q, neigh)
705 | 
706 | # shrink connections if needed
707 | for neigh in neighs:
708 | eConn = list(proxG[neigh])
709 | if len(eConn) > maxDeg0:
710 | # print("shrink")
711 | logging.debug("shrink g"+neigh.graph.get("id")+ " degree " + str(len(eConn)) + " exceeds maxDeg0 " + str(maxDeg0))
712 | eConn_heap = []
713 | idx_eConn = 0
714 | for ele in eConn:
715 | ele_dist = getDist(neigh, ele, G2GDistBook)
716 | heapq.heappush(eConn_heap, (-ele_dist, idx_eConn, ele))
717 | idx_eConn = idx_eConn + 1
718 | _, deleted = select_neighbors_simple(neigh, eConn_heap, maxDeg0)
719 | for ele in deleted:
720 | logging.debug("delete edge "+ neigh.graph.get("id") + " " + ele.graph.get("id"))
721 | if len(proxG[neigh]) == 1 or len(proxG[ele]) == 1: # keep the edge if removing it would isolate either endpoint
722 | pass
723 | else:
724 | proxG.remove_edge(neigh, ele)
725 | 
726 | 
727 | 
728 | 
729 | 
730 | 
731 | def build_proximity_graph(graphList, M, maxDeg0, efConst, G2GDistBook):
732 | '''
733 | The construction algorithm of the HNSW paper (Hierarchical Navigable Small World graphs, PAMI 2018),
734 | except that we only construct the bottom layer of HNSW.
735 | - graphList: the graphs to insert into the proximity graph. Each node in the proximity graph is a graph in graphList
736 | - maxDeg0: the max degree of a node in the proximity graph. '0' means the bottom layer of HNSW.
737 | - M, efConst: suppose you are inserting g into the proximity graph. You first find efConst nodes in the proximity graph as the
738 | candidate neighbors of g. Then, you pick M candidates to connect with g. Note that the M nodes may not be among the efConst candidates if
739 | you use the select_neighbors_heuristic function of the HNSW paper to pick the M nodes, as select_neighbors_heuristic may check the neighbors
740 | of the nodes in the efConst candidates.
741 | - G2GDistBook: all-pair distances between graphs in graphList. Just for fast construction.
742 | '''
743 | proxG = nx.Graph()
744 | proxG.add_node(graphList[0])
745 | 
746 | for i in range(1, len(graphList)):
747 | if i % 1 == 0: # progress: print every insertion
748 | print(i)
749 | if DEBUG:
750 | print("====================================" + str(i))
751 | print("cur proxG is: ")
752 | print(str(["g"+ele.graph.get("id") for ele in proxG.nodes()]))
753 | for edge in proxG.edges():
754 | print(edge[0].graph.get("id"), ' ', edge[1].graph.get("id"))
755 | print("inserting ", "g"+graphList[i].graph.get("id"))
756 | # insert(proxG, graphList[i], [graphList[0]], M, maxDeg0, efConst, G2GDistBook)
757 | rand = np.random.randint(i)
758 | insert(proxG, graphList[i], [graphList[rand]], M, maxDeg0, efConst, G2GDistBook) # start each insertion from a random existing node
759 | 
760 | return proxG
761 | 
762 | 
763 | 
764 | 
765 | 
766 | def hnsw_const(G2GDistBook, data_graphs, M, maxDeg0, efConst):
767 | distBook = G2GDistBook
768 | proxG = build_proximity_graph(data_graphs, M, maxDeg0, efConst, distBook)
769 | print("node has: ", proxG.number_of_nodes())
770 | print("edge has: ", proxG.number_of_edges())
771 | print("cc has: ", nx.number_connected_components(proxG))
772 | return proxG
773 | 
774 | 
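# Hedged construction sketch (parameter values are illustrative, not tuned ones;
# the distance file is a hypothetical pre-computed "gid1 gid2 ged" triple list):
# distBook = readG2GDistBook("data/AIDS/g2g_dist.txt")
# pg = hnsw_const(distBook, database, M=8, maxDeg0=80, efConst=100)
# save_proxG("hnsw.aids.M8.D80.ef100.nx", pg)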
775 | 
776 | def save_proxG(fname, proxG):
777 | # write proxG into file
778 | f = open(fname, "w")
779 | f.write("t # 0\n")
780 | for n in proxG.nodes():
781 | f.write("v "+n.graph.get("id")+"\n")
782 | for e in proxG.edges():
783 | f.write("e "+e[0].graph.get("id")+" "+e[1].graph.get("id")+"\n")
784 | f.close()
785 | 
786 | 
787 | def scan_db_and_comp_ged(q, database, distBook):
788 | res = []
789 | for g in database:
790 | dist = getDist(q, g, distBook)
791 | res.append([ g.graph.get('id'), dist])
792 | #print(len(res))
793 | res.sort(key = lambda x: x[1])
794 | #print(res)
795 | return res
796 | 
797 | 
798 | def get_topkAll_in_a_list(topk, x):
799 | kth = x[topk-1]
800 | res = x[0:topk]
801 | for i in range(topk, len(x)):
802 | if x[i][1] == kth[1]: # keep ties with the k-th distance
803 | res.append(x[i])
804 | return res
805 | 
806 | 
807 | 
808 | def get_exact_answer(topk, Q2GDistBook):
809 | answer = {}
810 | for query in Q2GDistBook.keys():
811 | distToGList = list(Q2GDistBook[query].items())
812 | distToGList.sort(key=lambda x: x[1])
813 | dist_thr = -1
814 | if topk-1 < len(distToGList):
815 | dist_thr = distToGList[topk-1][1]
816 | else:
817 | dist_thr = 1000000.0
818 | a = []
819 | for ele in distToGList:
820 | if ele[1] <= dist_thr:
821 | a.append(ele)
822 | else:
823 | break
824 | answer[query] = a
825 | return answer
826 | 
827 | 
828 | def pgBuild(database, G2GDistBookFileName, M, maxDeg0, efConst):
829 | G2GDistBook = readG2GDistBook(G2GDistBookFileName)
830 | pg = hnsw_const(G2GDistBook, database, M, maxDeg0, efConst)
831 | save_proxG("hnsw.aids.M"+str(M)+".D"+str(maxDeg0)+".ef"+str(efConst)+".nx", pg)
832 | 
833 | 
834 | 
835 | def reassignNodeID(graph, fname):
836 | oldID2newIDMap = {}
837 | newG = nx.Graph()
838 | newNodeID = 0
839 | for node in graph.nodes():
840 | oldID2newIDMap[node] = newNodeID
841 | newG.add_node(newNodeID)
842 | newNodeID += 1
843 | for edge in graph.edges():
844 | end1 = edge[0]
845 | end2 = edge[1]
846 | new_end1 = oldID2newIDMap[end1]
847 | new_end2 = oldID2newIDMap[end2]
848 | newG.add_edge(new_end1, new_end2)
849 | nx.write_edgelist(newG, fname, data=False)
850 | 
851 | 
852 | 
853 | def getExactDist(gfile, qfile, thr, timelimit):
854 | """
855 | invoke the code of Lijun Chang (ICDE’20 paper)
856 | :param gfile: data graph file name
857 | :param qfile: query graph file name
858 | :param thr: check whether GED <= thr
859 | :param timelimit: stop on reaching the time limit (in seconds)
860 | :return: if timeout, return -2.0; if GED > thr, return -1.0; if GED <= thr, return GED
861 | """
862 | dist = -2.0
863 | st = time.time()
864 | try:
865 | abc = subprocess.check_output([".~/Graph_Edit_Distance/ged_debian", gfile, qfile, "astar", "LSa", str(thr)], timeout=timelimit) # path to the exact-GED binary, adjust to the local build; timeout is in seconds
866 | abc = abc.decode('utf-8')
867 | abc2 = abc.split('\n')
868 | abc3 = abc2[1] # the second output line holds the result
869 | abc3 = abc3.strip()
870 | abc4 = abc3.split(',')
871 | abc5 = abc4[0]
872 | abc6 = abc5.split(' ')
873 | dist = float(abc6[2])
874 | # print(abc)
875 | # print(dist)
876 | # print('gfile ', gfile)
877 | # print('qfile ', qfile)
878 | # print(dist)
879 | except Exception:
880 | # print('gfile ', gfile)
881 | # print('qfile ', qfile)
882 | # print("time out!")
883 | # print(dist)
884 | pass
885 | et = time.time()
886 | # print("clock time (sec.) ", (et-st))
887 | return dist
888 | 
889 | 
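# Reading getExactDist's return value, per its docstring:
# d = getExactDist("data/AIDS/g0.txt", "data/AIDS/g1.txt", thr=10, timelimit=10)
# d == -2.0 -> solver timed out; d == -1.0 -> verified GED > thr; d >= 0 -> the exact GED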
", (et-st)) 887 | return dist 888 | 889 | 890 | 891 | def read_initial_gemb(addr): 892 | gEmbMap = {} 893 | gfileList = os.listdir(addr) 894 | for gfile in gfileList: 895 | gID = gfile[1:-4] 896 | f = open(addr+"/"+gfile) 897 | lines = f.read() 898 | f.close() 899 | lines = lines.strip().split('\n') 900 | lines = lines[1:] 901 | nodeEmbList = [] 902 | for line in lines: 903 | tmp = line.strip().split(' ') 904 | tmp2 = [float(ele) for ele in tmp[1:]] 905 | nodeEmbList.append(tmp2) 906 | nodeEmbList = torch.tensor(nodeEmbList) 907 | gEmb = torch.mean(nodeEmbList, 0) 908 | gEmbMap[gID] = gEmb.cuda() 909 | return gEmbMap 910 | 911 | 912 | 913 | ######################################################################################################## 914 | #### GNN model 915 | ######################################################################################################## 916 | 917 | 918 | 919 | class Model(nn.Module): 920 | def __init__(self, gID2InitEmbMap, gid2dgmap): 921 | super(Model, self).__init__() 922 | 923 | self.gid2dgMap = gid2dgmap 924 | self.gID2InitEmbMap = gID2InitEmbMap 925 | 926 | 927 | 928 | self.RELU = torch.nn.ReLU(inplace=True) 929 | 930 | self.fc_init = nn.Linear(20, 512, bias=True) 931 | self.conv1_for_g = GINConv(None, 'mean') 932 | self.conv2_for_g = GINConv(None, 'mean') 933 | # self.conv1_for_g = GINConv(nn.Linear(hdim, hdim, bias=True), 'mean') 934 | # self.conv2_for_g = GINConv(nn.Linear(hdim, hdim, bias=True), 'mean') 935 | self.gnn_bn = torch.nn.BatchNorm1d(512) 936 | self.gnn_bn2 = torch.nn.BatchNorm1d(512) 937 | 938 | self.fc = nn.Linear(512*3, 256, bias=True) 939 | self.fc2 = nn.Linear(256, 256, bias=True) 940 | self.fc3 = nn.Linear(256, 256, bias=True) 941 | self.fc4 = nn.Linear(256, 1, bias=True) 942 | self.bn = torch.nn.BatchNorm1d(256) 943 | self.bn2 = torch.nn.BatchNorm1d(256) 944 | self.bn3 = torch.nn.BatchNorm1d(256) 945 | 946 | self.dp = torch.nn.Dropout(0.5) 947 | 948 | 949 | 950 | def forward(self, qIDs, pgNodeIDs, neighIDs): 951 | 952 | outputNum = 10 953 | 954 | batch_dg = dgl.batch([self.gid2dgMap[qid] for qid in qIDs]) 955 | batch_dg.ndata['h'] = self.fc_init(batch_dg.ndata['h']) 956 | batch_dg.ndata['h'] = self.RELU(self.gnn_bn(self.conv1_for_g(batch_dg, batch_dg.ndata['h']))) 957 | batch_dg.ndata['h'] = self.RELU(self.gnn_bn2(self.conv2_for_g(batch_dg, batch_dg.ndata['h']))) 958 | qemb = dgl.mean_nodes(batch_dg, 'h') 959 | qemb = qemb.view(1,-1) 960 | 961 | 962 | neighEmbList = torch.zeros(outputNum, 512).cuda() 963 | for idx in range(0, len(neighIDs)): 964 | neighID = neighIDs[idx] 965 | neighEmbList[idx] = self.gID2InitEmbMap[neighID] 966 | 967 | pgNode_embList = [self.gID2InitEmbMap[pgNodeID] for pgNodeID in pgNodeIDs] 968 | pgNode_embList = torch.stack(pgNode_embList).cuda() 969 | pgNode_embList = pgNode_embList.view(1,-1) 970 | 971 | 972 | a = torch.cat([qemb, pgNode_embList], 1) 973 | a = a.repeat(1, outputNum).view(-1, 512*2) 974 | 975 | b = torch.cat([a, neighEmbList], 1) 976 | 977 | 978 | H = self.RELU(self.bn(self.fc(b))) 979 | H2 = self.RELU(self.bn2(self.fc2(H))) 980 | H3 = self.RELU(self.bn3(self.fc3(H2))) 981 | pred = torch.sigmoid(self.fc4(H3)) 982 | pred = pred.view(1, outputNum) 983 | 984 | return pred 985 | 986 | 987 | 988 | ######################################################################################################## 989 | 990 | entire_dataset = read_and_split_to_individual_graph("data/AIDS/aids.txt", 0, 10000000) 991 | print(len(entire_dataset)) 992 | 993 | gid2gmap = {} 994 | for g in entire_dataset: 995 | 
gid2gmap[g.graph.get("id")] = g 996 | gid2dgmap = {} 997 | for g in entire_dataset: 998 | dg = make_a_dglgraph(g) 999 | dg = dgl.add_self_loop(dg) 1000 | gid2dgmap[g.graph.get('id')] = dg#.to(torch.device('cuda:'+str(GPUID))) 1001 | 1002 | 1003 | gID2InitEmbMap = read_initial_gemb('data/AIDS/emb/aids.emb') 1004 | print('read g init emb done.') 1005 | 1006 | 1007 | database = entire_dataset[0:40000] # aids db size = 4000 1008 | 1009 | 1010 | pgTmp = read_and_split_to_individual_graph("PG.aids.nx", 0, 10000000000) 1011 | pgTmp = pgTmp[0] 1012 | 1013 | 1014 | pgNodeIDSet = pgTmp.nodes() 1015 | pg = nx.Graph() 1016 | for nID in pgTmp.nodes(): 1017 | pg.add_node(gid2gmap[nID], deg=len(pgTmp[nID])) 1018 | for edge in pgTmp.edges(): 1019 | edge_weight = pgTmp.get_edge_data(*edge) 1020 | pg.add_edge(gid2gmap[edge[0]], gid2gmap[edge[1]]) 1021 | 1022 | 1023 | queries = [] 1024 | f = open('data/AIDS/query_test.txt') 1025 | lines = f.read() 1026 | f.close() 1027 | lines = lines.strip().split('\n') 1028 | for line in lines: 1029 | qid = line.strip() 1030 | queries.append(gid2gmap[qid]) 1031 | queryIDs = set() 1032 | for q in queries: 1033 | queryIDs.add(q.graph.get('id')) 1034 | 1035 | 1036 | 1037 | q2GDistBook = readQ2GDistBook("data/AIDS/aids.txt", pgNodeIDSet) 1038 | exact_ans = get_exact_answer(100000000, q2GDistBook) 1039 | 1040 | 1041 | 1042 | 1043 | ep = 319 # model of which epoch you want to use 1044 | 1045 | model0 = Model(gID2InitEmbMap, gid2dgmap) 1046 | model0.load_state_dict(torch.load("aids.perc20_model_save/prune_ged0_10.e"+str(ep)+".pkl")) 1047 | modelMap[0] = model0.cuda() 1048 | 1049 | model10 = Model(gID2InitEmbMap, gid2dgmap) 1050 | model10.load_state_dict(torch.load("aids.perc20_model_save/prune_ged10_20.e"+str(ep)+".pkl")) 1051 | modelMap[1] = model10.cuda() 1052 | 1053 | model20 = Model(gID2InitEmbMap, gid2dgmap) 1054 | model20.load_state_dict(torch.load("aids.perc20_model_save/prune_ged20_30.e"+str(ep)+".pkl")) 1055 | modelMap[2] = model20.cuda() 1056 | 1057 | model30 = Model(gID2InitEmbMap, gid2dgmap) 1058 | model30.load_state_dict(torch.load("aids.perc20_model_save/prune_ged30_40.e"+str(ep)+".pkl")) 1059 | modelMap[3] = model30.cuda() 1060 | 1061 | model40 = Model(gID2InitEmbMap, gid2dgmap) 1062 | model40.load_state_dict(torch.load("aids.perc20_model_save/prune_ged40_50.e"+str(ep)+".pkl")) 1063 | modelMap[4] = model40.cuda() 1064 | 1065 | model50 = Model(gID2InitEmbMap, gid2dgmap) 1066 | model50.load_state_dict(torch.load("aids.perc20_model_save/prune_ged50_60.e"+str(ep)+".pkl")) 1067 | modelMap[5] = model50.cuda() 1068 | 1069 | model60 = Model(gID2InitEmbMap, gid2dgmap) 1070 | model60.load_state_dict(torch.load("aids.perc20_model_save/prune_ged60_70.e"+str(ep)+".pkl")) 1071 | modelMap[6] = model60.cuda() 1072 | 1073 | model70 = Model(gID2InitEmbMap, gid2dgmap) 1074 | model70.load_state_dict(torch.load("aids.perc20_model_save/prune_ged70_80.e"+str(ep)+".pkl")) 1075 | modelMap[7] = model70.cuda() 1076 | 1077 | 1078 | def set_bn_eval(m): 1079 | classname = m.__class__.__name__ 1080 | print(classname) 1081 | if classname.find('BatchNorm') != -1: 1082 | print('make it eval()') 1083 | m.eval() 1084 | 1085 | for i in range(0, 8): 1086 | modelMap[i].apply(set_bn_eval) 1087 | 1088 | 1089 | 1090 | 1091 | topk = 50 1092 | 1093 | avg_recall = 0.0 1094 | avg_precision = 0.0 1095 | avg_DCS = 0.0 1096 | avg_hops = 0.0 1097 | counter = 0 1098 | start_time = time.time() 1099 | for q in queries[0:10]: 1100 | qid = q.graph.get("id") 1101 | print("qid: ", qid) 1102 | 1103 | exact_ans_of_q = 
1093 | avg_recall = 0.0
1094 | avg_precision = 0.0
1095 | avg_DCS = 0.0
1096 | avg_hops = 0.0
1097 | counter = 0
1098 | start_time = time.time()
1099 | for q in queries[0:10]:
1100 | qid = q.graph.get("id")
1101 | print("qid: ", qid)
1102 | 
1103 | exact_ans_of_q = get_topkAll_in_a_list(topk, exact_ans[qid])
1104 | print("exact_ans_of_q: ", exact_ans_of_q)
1105 | 
1106 | 
1107 | message = 'random_initial_node'
1108 | rand = np.random.randint(len(database))
1109 | start_nodes = [database[rand]]
1110 | cand, stat = greedy_search(pg, q, 50, start_nodes, 50, q2GDistBook, gid2gmap) # k=50 results out of ef=50 candidates
1111 | 
1112 | 
1113 | pred_ans_of_q = [(ele[2].graph.get('id'), ele[0]) for ele in cand]
1114 | pred_ans_of_q.sort(key = lambda x: -x[1]) # heap keys are negative distances, so this sorts nearest first
1115 | print("pred_ans_of_q", pred_ans_of_q)
1116 | print("DCS: ", stat['DCS'])
1117 | 
1118 | avg_DCS += stat['DCS']
1119 | avg_hops += stat['hop_count']
1120 | 
1121 | recall = set([ele[0] for ele in pred_ans_of_q]) & set([ele[0] for ele in exact_ans_of_q]) # retrieved ids that are in the exact answer
1122 | recall_perc = min(1.0, len(recall)/topk)
1123 | print("recall perc: ", recall_perc)
1124 | avg_recall += recall_perc
1125 | 
1126 | precision = len(recall)/len(exact_ans_of_q)
1127 | print('precision', precision)
1128 | avg_precision += precision
1129 | 
1130 | f.write(qid+" "+str(recall_perc)+" "+str(precision)+" "+str(stat['DCS'])+" "+str(stat['hop_count'])+"\n")
1131 | 
1132 | counter += 1
1133 | print('---------------------------------------------')
1134 | end_time = time.time()
1135 | f.close()
1136 | print("avg_recall: ", (avg_recall/counter))
1137 | print("avg_precision: ", (avg_precision/counter))
1138 | print("avg_DCS: ", (avg_DCS/counter))
1139 | print('avg_hops: ', (avg_hops/counter))
1140 | print('avg_time: (s)', (end_time - start_time)/counter)
1141 | print("counter: ", counter)
1142 | print('msg: ', message)
1143 | print("ep ", ep)
1144 | 
1145 | 
1146 | 
1147 | 
1148 | 
1149 | jpype.shutdownJVM()
1150 | 
--------------------------------------------------------------------------------