├── data ├── all_grids.pt ├── dataset.py ├── load_data.py ├── orderings.py ├── tokens.py ├── mol_utils.py └── data_utils.py ├── assets └── k2_tree_4.png ├── evaluation ├── orca │ ├── orca │ └── orca.h ├── orcamodule.cpp ├── ablation.py ├── evaluation_spectre.py └── evaluation.py ├── .gitignore ├── requirements.txt ├── test.py ├── script └── trans │ ├── enz.sh │ ├── zinc.sh │ ├── planar.sh │ ├── grid.sh │ ├── qm9.sh │ └── com_small.sh ├── generate_string.py ├── README.md ├── plot.py ├── eval_vun.py ├── draw_samples.py ├── trainer ├── train_trans_generator.py └── train_generator.py └── model └── trans_generator.py /data/all_grids.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunhuijang/HGGT/HEAD/data/all_grids.pt -------------------------------------------------------------------------------- /assets/k2_tree_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunhuijang/HGGT/HEAD/assets/k2_tree_4.png -------------------------------------------------------------------------------- /evaluation/orca/orca: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunhuijang/HGGT/HEAD/evaluation/orca/orca -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | resource/ 2 | samples/ 3 | *.pyc 4 | *.ckpt 5 | *.yaml 6 | *.log 7 | *.json 8 | wandb/ 9 | k2g/ 10 | *.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | molsets 3 | pandas==1.5.3 4 | matplotlib==3.7.0 5 | numpy==1.23.5 6 | scipy==1.10.1 7 | pyemd==1.0.0 8 | wandb -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from data.load_data import generate_string 5 | import os 6 | import networkx as nx 7 | from data.data_utils import get_max_len 8 | 9 | 10 | for data in ['point']: 11 | print(get_max_len(data, 'C-M', 2)) 12 | -------------------------------------------------------------------------------- /script/trans/enz.sh: -------------------------------------------------------------------------------- 1 | python trainer/train_trans_generator.py \ 2 | --order C-M \ 3 | --dataset_name GDSS_enz \ 4 | --max_epochs 500 \ 5 | --check_sample_every_n_epoch 501 \ 6 | --replicate 1 \ 7 | --max_len 238 \ 8 | --wandb_on online \ 9 | --string_type group-red \ 10 | --lr 0.0002 \ 11 | --batch_size 64 \ 12 | --sample_batch_size 64 \ 13 | --num_samples 64 \ 14 | --dropout 0.1 \ 15 | --input_dropout 0 \ 16 | --k 2 -------------------------------------------------------------------------------- /script/trans/zinc.sh: -------------------------------------------------------------------------------- 1 | python trainer/train_trans_generator.py \ 2 | --order C-M \ 3 | --dataset_name zinc \ 4 | --max_epochs 500 \ 5 | --check_sample_every_n_epoch 20 \ 6 | --replicate 0 \ 7 | --max_len 87 \ 8 | --wandb_on online \ 9 | --string_type zinc-red \ 10 | --lr 0.0005 \ 11 | --batch_size 32 \ 12 | --num_samples 10000 \ 13 | --sample_batch_size 50 \ 14 | --dropout 0.1 \ 15 | --input_dropout 0 \ 16 | --k 2 17 | 
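The `--max_len` flag in these training scripts bounds the length of the token sequences produced at generation time; the README sets it to the longest sequence representation among the training and test graphs. Below is a minimal sketch for re-deriving a value after regenerating the string files, reusing `remove_redundant` with the same arguments as `data/dataset.py`; the helper name and the `+2` headroom for the `[bos]`/`[eos]` tokens are assumptions, not repository code:

```python
# Hypothetical helper (not part of the repo): estimate --max_len from generated string files.
from pathlib import Path

from data.data_utils import remove_redundant  # same pruning helper data/dataset.py applies


def estimate_max_len(data_name, order="C-M", k=2, is_mol=False, data_dir="resource"):
    lengths = []
    for split in ["train", "val", "test"]:
        suffix = f"_{k}" if k > 2 else ""  # file naming used by data/dataset.py
        path = Path(data_dir) / data_name / order / f"{data_name}_str_{split}{suffix}.txt"
        for line in path.read_text(encoding="utf-8").splitlines():
            # length of the redundancy-pruned token sequence the model actually sees
            lengths.append(len(remove_redundant(line, is_mol, k)))
    return max(lengths) + 2  # assumed headroom for the [bos]/[eos] tokens


print(estimate_max_len("GDSS_enz"))  # compare with --max_len 238 in enz.sh above
```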
-------------------------------------------------------------------------------- /script/trans/planar.sh: -------------------------------------------------------------------------------- 1 | python trainer/train_trans_generator.py \ 2 | --order C-M \ 3 | --dataset_name planar \ 4 | --max_epochs 500 \ 5 | --check_sample_every_n_epoch 501 \ 6 | --replicate 0 \ 7 | --max_len 230 \ 8 | --wandb_on online \ 9 | --string_type group-red \ 10 | --lr 0.001 \ 11 | --batch_size 32 \ 12 | --sample_batch_size 50 \ 13 | --num_samples 100 \ 14 | --tree_pos \ 15 | --pos_type emb \ 16 | --dropout 0 \ 17 | --k 2 -------------------------------------------------------------------------------- /script/trans/grid.sh: -------------------------------------------------------------------------------- 1 | python trainer/train_trans_generator.py \ 2 | --order C-M \ 3 | --dataset_name GDSS_grid \ 4 | --max_epochs 500 \ 5 | --check_sample_every_n_epoch 501 \ 6 | --replicate 0 \ 7 | --max_len 706 \ 8 | --wandb_on online \ 9 | --string_type group-red \ 10 | --lr 0.0005 \ 11 | --batch_size 8 \ 12 | --num_samples 200 \ 13 | --tree_pos \ 14 | --pos_type emb \ 15 | --dropout 0.1 \ 16 | --input_dropout 0 \ 17 | --sample_batch_size 50 \ 18 | --k 2 -------------------------------------------------------------------------------- /script/trans/qm9.sh: -------------------------------------------------------------------------------- 1 | python trainer/train_trans_generator.py \ 2 | --order C-M \ 3 | --dataset_name qm9 \ 4 | --max_epochs 500 \ 5 | --check_sample_every_n_epoch 20 \ 6 | --replicate 0 \ 7 | --max_len 23 \ 8 | --wandb_on online \ 9 | --string_type qm9-red \ 10 | --lr 0.0005 \ 11 | --batch_size 1024 \ 12 | --num_samples 10000 \ 13 | --sample_batch_size 400 \ 14 | --tree_pos \ 15 | --pos_type emb \ 16 | --dropout 0.1 \ 17 | --input_dropout 0.5 \ 18 | --k 2 -------------------------------------------------------------------------------- /script/trans/com_small.sh: -------------------------------------------------------------------------------- 1 | python trainer/train_trans_generator.py \ 2 | --order C-M \ 3 | --dataset_name GDSS_com \ 4 | --max_epochs 500 \ 5 | --check_sample_every_n_epoch 501 \ 6 | --replicate 0 \ 7 | --max_len 48 \ 8 | --wandb_on online \ 9 | --string_type group-red \ 10 | --lr 0.001 \ 11 | --batch_size 128 \ 12 | --num_samples 128 \ 13 | --sample_batch_size 128 \ 14 | --tree_pos \ 15 | --pos_type emb \ 16 | --dropout 0.1 \ 17 | --input_dropout 0 \ 18 | --k 2 -------------------------------------------------------------------------------- /generate_string.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from data.load_data import generate_string 4 | from data.load_data import generate_mol_string 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | # dataset_name: GDSS_com, GDSS_enz, GDSS_grid, planar, qm9, zinc, planar, sbm 10 | parser.add_argument("dataset_name", type=str, default='traffic') 11 | # order: C-M, BFS, DFS 12 | parser.add_argument("order", type=str, default='C-M') 13 | # k: 2, 3 14 | parser.add_argument("k", type=int, default=2) 15 | 16 | args, _ = parser.parse_known_args() 17 | 18 | if args.dataset_name in ['GDSS_com', 'GDSS_enz', 'GDSS_grid', 'planar', 'sbm', 'planar', 'traffic', 'ego', 'lobster', 'point']: 19 | generate_string(dataset_name=args.dataset_name, order=args.order, k=args.k) 20 | else: 21 | generate_mol_string(dataset_name=args.dataset_name, order=args.order) 
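For reference, the two entry points above can also be invoked directly from Python; a minimal sketch, assuming the raw dataset files have already been placed under `resource/<dataset_name>/` and the `<order>` subdirectory exists (see the README):

```python
# Programmatic equivalent of the CLI above.
from data.load_data import generate_string, generate_mol_string

# Generic graph datasets (pickle / .pt files under resource/<dataset_name>/).
generate_string(dataset_name="GDSS_com", order="C-M", k=2)

# Molecular datasets (CSV files of SMILES under resource/<dataset_name>/).
generate_mol_string(dataset_name="qm9", order="C-M")
```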
-------------------------------------------------------------------------------- /evaluation/orcamodule.cpp: -------------------------------------------------------------------------------- 1 | #include <Python.h> 2 | #include <stdio.h> 3 | #include <stdlib.h> 4 | #include <string.h> 5 | 6 | #include "orca/orca.h" 7 | 8 | static PyObject * 9 | orca_motifs(PyObject *self, PyObject *args) 10 | { 11 | const char *orbit_type; 12 | int graphlet_size; 13 | const char *input_filename; 14 | const char *output_filename; 15 | int sts; 16 | 17 | if (!PyArg_ParseTuple(args, "siss", &orbit_type, &graphlet_size, &input_filename, &output_filename)) 18 | return NULL; 19 | sts = system(orbit_type); 20 | motif_counts(orbit_type, graphlet_size, input_filename, output_filename); 21 | return PyLong_FromLong(sts); 22 | } 23 | 24 | static PyMethodDef OrcaMethods[] = { 25 | {"motifs", orca_motifs, METH_VARARGS, 26 | "Compute motif counts."}, {NULL, NULL, 0, NULL} 27 | }; 28 | 29 | static struct PyModuleDef orcamodule = { 30 | PyModuleDef_HEAD_INIT, 31 | "orca", /* name of module */ 32 | NULL, /* module documentation, may be NULL */ 33 | -1, /* size of per-interpreter state of the module, 34 | or -1 if the module keeps state in global variables. */ 35 | OrcaMethods 36 | }; 37 | 38 | PyMODINIT_FUNC 39 | PyInit_orca(void) 40 | { 41 | return PyModule_Create(&orcamodule); 42 | } 43 | 44 | int main(int argc, char *argv[]) { 45 | 46 | wchar_t *program = Py_DecodeLocale(argv[0], NULL); 47 | if (program == NULL) { 48 | fprintf(stderr, "Fatal error: cannot decode argv[0]\n"); 49 | exit(1); 50 | } 51 | 52 | /* Add a built-in module, before Py_Initialize */ 53 | PyImport_AppendInittab("orca", PyInit_orca); 54 | 55 | /* Pass argv[0] to the Python interpreter */ 56 | Py_SetProgramName(program); 57 | 58 | /* Initialize the Python interpreter. Required. */ 59 | Py_Initialize(); 60 | 61 | /* Optionally import the module; alternatively, 62 | import can be deferred until the embedded script 63 | imports it. 
*/ 64 | PyImport_ImportModule("orca"); 65 | 66 | PyMem_RawFree(program); 67 | 68 | } 69 | 70 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from pathlib import Path 4 | import os 5 | from tqdm import tqdm 6 | 7 | from data.data_utils import remove_redundant 8 | from data.tokens import tokenize 9 | 10 | 11 | DATA_DIR = "resource" 12 | 13 | class GenericDataset(Dataset): 14 | is_mol = False 15 | def __init__(self, split, string_type='group-red', order='C-M', is_tree=False, k=2): 16 | self.string_type = string_type 17 | self.is_tree = is_tree 18 | self.order = order 19 | self.k = k 20 | if k > 2: 21 | string_path = os.path.join(self.raw_dir, f"{self.order}/{self.data_name}_str_{split}_{self.k}.txt") 22 | else: 23 | string_path = os.path.join(self.raw_dir, f"{self.order}/{self.data_name}_str_{split}.txt") 24 | self.strings = Path(string_path).read_text(encoding="utf=8").splitlines() 25 | 26 | # remove redundant 27 | self.strings = [remove_redundant(string, self.is_mol, self.k) for string in tqdm(self.strings, 'Removing redundancy')] 28 | 29 | def __len__(self): 30 | return len(self.strings) 31 | 32 | def __getitem__(self, idx: int): 33 | return torch.LongTensor(tokenize(self.strings[idx], self.string_type, self.k)) 34 | 35 | class ComDataset(GenericDataset): 36 | data_name = 'GDSS_com' 37 | raw_dir = f'{DATA_DIR}/GDSS_com' 38 | 39 | class EnzDataset(GenericDataset): 40 | data_name = 'GDSS_enz' 41 | raw_dir = f'{DATA_DIR}/GDSS_enz' 42 | 43 | class GridDataset(GenericDataset): 44 | data_name = 'GDSS_grid' 45 | raw_dir = f'{DATA_DIR}/GDSS_grid' 46 | 47 | 48 | class QM9Dataset(GenericDataset): 49 | data_name = "qm9" 50 | raw_dir = f"{DATA_DIR}/qm9" 51 | is_mol = True 52 | 53 | class ZINCDataset(GenericDataset): 54 | data_name = 'zinc' 55 | raw_dir = f'{DATA_DIR}/zinc' 56 | is_mol = True 57 | 58 | class PlanarDataset(GenericDataset): 59 | data_name = 'planar' 60 | raw_dir = f'{DATA_DIR}/planar' 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Generation with $K^2$-trees (ICLR 2024) 2 | 3 | In this repository, we implement the paper: [Graph Generation with $K^{2}$-Tree (HGGT)](https://openreview.net/pdf?id=RIEW6M9YoV). 4 | 5 |

6 | ![K^2-tree illustration](assets/k2_tree_4.png)

8 | 9 | ## Contribution 10 | 11 | + We propose a new graph generative model based on adopting the $K^2$-tree as a compact, hierarchical, and domain-agnostic representation of graphs. 12 | + We introduce a novel, compact sequential $K^2$-tree representation obtained by pruning, flattening, and tokenizing the $K^2$-tree. 13 | + We propose an autoregressive model to generate the sequential $K^2$-tree representation using a Transformer architecture with a specialized positional encoding scheme. 14 | + We validate the efficacy of our framework by demonstrating state-of-the-art performance on five out of six graph generation benchmarks. 15 | 16 | ## Dependencies 17 | 18 | HGGT is built with Python 3.10.0, PyTorch 1.12.1, and PyTorch Geometric 2.2.0. Use the following commands to install the required Python packages. 19 | 20 | ```sh 21 | pip install Cython 22 | pip install -r requirements.txt 23 | pip install rdkit==2020.9.5 24 | pip install git+https://github.com/fabriziocosta/EDeN.git 25 | pip install pytorch-lightning==1.9.3 26 | pip install treelib 27 | pip install networkx==2.8.7 28 | ``` 29 | 30 | ## Running experiments 31 | 32 | ### 1. Data preparation 33 | 34 | We provide four generic datasets (community-small, enzymes, grid, and planar) and two molecular datasets (ZINC250k and QM9). You can download the pickle files of five datasets (community-small, enzymes, grid, ZINC250k, and QM9) from https://github.com/harryjo97/GDSS/tree/master and the planar dataset from https://github.com/KarolisMart/SPECTRE. 35 | 36 | After downloading the pickle files into the `resource/${dataset_name}/` directory, make a new directory `resource/${dataset_name}/${order}/` to store the sequence representations. Then you can generate the sequence representations of the $K^2$-trees by running: 37 | 38 | ```sh 39 | python generate_string.py --dataset_name ${dataset_name} --order ${order} --k ${k} 40 | ``` 41 | 42 | For example, 43 | ```sh 44 | python generate_string.py --dataset_name GDSS_com --order C-M --k 2 45 | ``` 46 | 47 | ### 2. Configurations 48 | 49 | The configurations are given in the `script/trans/` directory. Note that `max_len` denotes the maximum length of a sequence representation during generation. We set `max_len` to the maximum sequence-representation length over the training and test graphs. 50 | 51 | ### 3. Training and evaluation 52 | 53 | You can train the HGGT model and generate samples by running: 54 | ```sh 55 | CUDA_VISIBLE_DEVICES=${gpu_id} bash script/trans/{script_name}.sh 56 | ``` 57 | 58 | For example, 59 | ```sh 60 | CUDA_VISIBLE_DEVICES=0 bash script/trans/com_small.sh 61 | ``` 62 | 63 | Then the generated samples are saved in the `samples/` directory and the metrics are reported to WandB.
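For a quick sanity check of the data pipeline, the tokenized $K^2$-tree sequences consumed by the trainer can be inspected directly. Below is a minimal sketch (not part of the training scripts), assuming the `GDSS_com` strings from step 1 are present under `resource/GDSS_com/C-M/` with the file names expected by `data/dataset.py`; it mirrors the dataset class and collate function used in `trainer/train_trans_generator.py`.

```python
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from data.dataset import ComDataset

# Each item is a LongTensor of token ids ([bos] ... [eos]) for one graph.
dataset = ComDataset(split="train", string_type="group-red", order="C-M", k=2)

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    # same right-padding collate as the trainer (token id 0 is [pad])
    collate_fn=lambda seqs: pad_sequence(seqs, batch_first=True, padding_value=0),
)

batch = next(iter(loader))
print(batch.shape)  # (batch_size, longest sequence in the batch)
```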
64 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import math 2 | import networkx as nx 3 | import numpy as np 4 | import os 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import pickle 8 | import warnings 9 | warnings.filterwarnings("ignore", category=matplotlib.cbook.MatplotlibDeprecationWarning) 10 | 11 | 12 | options = { 13 | 'node_size': 5, 14 | 'edge_color' : 'black', 15 | 'linewidths': 0.8, 16 | 'width': 1.2 17 | } 18 | 19 | options_one = { 20 | 'node_size': 15, 21 | 'edge_color' : 'black', 22 | 'linewidths': 1, 23 | 'width': 1.2 24 | } 25 | 26 | def plot_graphs_list(graphs, title='title', max_num=16, save_dir=None, N=0): 27 | batch_size = len(graphs) 28 | max_num = min(batch_size, max_num) 29 | img_c = int(math.ceil(np.sqrt(max_num))) 30 | figure = plt.figure() 31 | 32 | for i in range(max_num): 33 | # idx = i * (batch_size // max_num) 34 | idx = i + max_num*N 35 | if not isinstance(graphs[idx], nx.Graph): 36 | G = graphs[idx].g.copy() 37 | else: 38 | G = graphs[idx].copy() 39 | assert isinstance(G, nx.Graph) 40 | G.remove_nodes_from(list(nx.isolates(G))) 41 | e = G.number_of_edges() 42 | v = G.number_of_nodes() 43 | l = nx.number_of_selfloops(G) 44 | 45 | ax = plt.subplot(img_c, img_c, i + 1) 46 | title_str = f'e={e - l}, n={v}' 47 | # if 'lobster' in save_dir.split('/')[0]: 48 | # if is_lobster_graph(graphs[idx]): 49 | # title_str += f' [L]' 50 | pos = nx.spring_layout(G) 51 | nx.draw(G, pos, with_labels=False, **options) 52 | ax.title.set_text(title_str) 53 | # figure.suptitle(title) 54 | 55 | save_fig(save_dir=save_dir, title=title) 56 | 57 | return 58 | 59 | def plot_one_graph(graph, title, save_dir): 60 | figure = plt.figure() 61 | 62 | G = graph 63 | G.remove_nodes_from(list(nx.isolates(G))) 64 | e = G.number_of_edges() 65 | v = G.number_of_nodes() 66 | l = nx.number_of_selfloops(G) 67 | 68 | fig, ax = plt.subplots(1,1, figsize=(3,3)) 69 | plt.figure(figsize=(5,5)) 70 | pos = nx.spring_layout(G) 71 | nx.draw(G, pos, with_labels=False, **options_one) 72 | # ax.title.set_text(title_str) 73 | # figure.suptitle(title) 74 | 75 | save_fig(save_dir=save_dir, title=title) 76 | 77 | def save_fig(save_dir=None, title='fig', dpi=400): 78 | plt.tight_layout() 79 | plt.subplots_adjust(top=0.85) 80 | if save_dir is None: 81 | plt.show() 82 | else: 83 | fig_dir = os.path.join(*['samples', 'fig', save_dir]) 84 | if not os.path.exists(fig_dir): 85 | os.makedirs(fig_dir) 86 | plt.savefig(os.path.join(fig_dir, title), 87 | bbox_inches='tight', 88 | dpi=dpi, 89 | transparent=False) 90 | plt.close() 91 | return 92 | 93 | 94 | def save_graph_list(log_folder_name, exp_name, gen_graph_list): 95 | 96 | if not(os.path.isdir('./samples/pkl/{}'.format(log_folder_name))): 97 | os.makedirs(os.path.join('./samples/pkl/{}'.format(log_folder_name))) 98 | with open('./samples/pkl/{}/{}.pkl'.format(log_folder_name, exp_name), 'wb') as f: 99 | pickle.dump(obj=gen_graph_list, file=f, protocol=pickle.HIGHEST_PROTOCOL) 100 | save_dir = './samples/pkl/{}/{}.pkl'.format(log_folder_name, exp_name) 101 | return save_dir 102 | -------------------------------------------------------------------------------- /evaluation/ablation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | import pickle 5 | import networkx as nx 6 | from statistics import mean 7 | import pandas as pd 8 | 9 | 
from data.data_utils import remove_redundant 10 | from data.load_data import load_proteins_data 11 | from data.orderings import bw_from_adj, ORDER_FUNCS, order_graphs 12 | 13 | 14 | DATA_DIR = "resource" 15 | 16 | def compute_compression_rate(data_name, order, model, is_red=False): 17 | raw_dir = f"{DATA_DIR}/{data_name}/{order}" 18 | total_strings = [] 19 | # load k2 tree string 20 | for split in ['train', 'val', 'test']: 21 | string_path = os.path.join(raw_dir, f"{data_name}_str_{split}.txt") 22 | strings = Path(string_path).read_text(encoding="utf=8").splitlines() 23 | # for string length 24 | if is_red: 25 | strings = [''.join(remove_redundant(string)) for string in strings] 26 | 27 | total_strings.extend(strings) 28 | 29 | 30 | # load data 31 | if data_name in ['planar', 'sbm']: 32 | adjs, _, _, _, _, _, _, _ = torch.load(f'{DATA_DIR}/{data_name}/{data_name}.pt') 33 | adjs = [adj.numpy() for adj in adjs] 34 | elif data_name == 'proteins': 35 | adjs = load_proteins_data("../resource") 36 | else: 37 | with open (f'{DATA_DIR}/{data_name}/{data_name}.pkl', 'rb') as f: 38 | graphs = pickle.load(f) 39 | 40 | order_func = ORDER_FUNCS['BFS'] 41 | ordered_graphs = order_graphs(graphs, num_repetitions=1, order_func=order_func, is_mol=False, seed=0) 42 | ordered_graphs = [nx.from_numpy_array(ordered_graph.to_adjacency().numpy()) for ordered_graph in ordered_graphs] 43 | adjs = [nx.adjacency_matrix(graph).toarray() for graph in ordered_graphs] 44 | 45 | n_nodes = [adj.shape[0]*adj.shape[1] for adj in adjs] 46 | if model == 'hggt': 47 | pass 48 | # len_strings = [len(string) for string in total_strings] 49 | 50 | else: 51 | # GraphRNN 52 | len_strings = [] 53 | n_squares = [] 54 | for adj in adjs: 55 | b = bw_from_adj(adj) 56 | n = adj.shape[0] 57 | len_strings.append(n*b-((b*b+b)/2)) 58 | n_squares.append(n*n) 59 | 60 | compression_rates = [length / n_square for length, n_square in zip(len_strings, n_squares)] 61 | 62 | return mean(compression_rates) 63 | 64 | 65 | datas = ['proteins'] 66 | orders = ['BFS', 'DFS', 'C-M'] 67 | result_df = pd.DataFrame(columns = datas, index=orders) 68 | 69 | def get_max_len(data_name, order='C-M', k=2): 70 | total_strings = [] 71 | k_square = k**2 72 | for split in ['train', 'test', 'val']: 73 | if k > 2: 74 | string_path = os.path.join(DATA_DIR, f"{data_name}/{order}/{data_name}_str_{split}_{k}.txt") 75 | else: 76 | string_path = os.path.join(DATA_DIR, f"{data_name}/{order}/{data_name}_str_{split}.txt") 77 | 78 | # string_path = os.path.join(DATA_DIR, f"{data_name}/{order}/{data_name}_str_{split}_{k}.txt") 79 | strings = Path(string_path).read_text(encoding="utf=8").splitlines() 80 | 81 | total_strings.extend(strings) 82 | 83 | # red_strings = [''.join(red) for red in red_list] 84 | 85 | max_len = max([len(string) for string in total_strings]) 86 | group_max_len = max_len / k_square 87 | red_len = [len(remove_redundant(string)) for string in total_strings] 88 | 89 | return max_len, group_max_len, red_len 90 | 91 | for data in ['GDSS_com', 'planar', 'GDSS_enz', 'GDSS_grid']: 92 | # rnn_string = compute_compression_rate(data, 'BFS', 'graphrnn', True) 93 | _, rnn_string, red_len = get_max_len(data) 94 | print(round(mean(red_len), 3)) 95 | 96 | 97 | # result_df.to_csv('compression.csv') 98 | 99 | 100 | -------------------------------------------------------------------------------- /eval_vun.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import networkx as nx 4 | 5 | from 
evaluation.evaluation_spectre import eval_fraction_isomorphic, eval_fraction_unique_non_isomorphic_valid, eval_acc_grid_graph, eval_acc_planar_graph, eval_acc_sbm_graph, is_grid_graph, is_planar_graph, is_sbm_graph, eval_fraction_unique 6 | from evaluation.evaluation_spectre import eval_fraction_unique 7 | from data.data_utils import train_val_test_split, adj_to_graph 8 | from data.load_data import load_proteins_data 9 | 10 | gcg_dict = {'GDSS_com': 'May02-00:49:13', 'GDSS_grid': 'May02-15:56:32', 'GDSS_ego': "May14-07:56:26", 11 | 'GDSS_enz': 'May02-06:39:38' , 'planar': "May14-11:34:11", 'sbm': 'Jun23-07:39:09'} 12 | 13 | graphgen_dict = {'GDSS_com': 'DFScodeRNN_com_small_2023-05-14 01:40:56', 14 | 'GDSS_grid': 'DFScodeRNN_grid_2023-05-14 01:43:04', 15 | 'GDSS_ego': "DFScodeRNN_ego_small_2023-05-14 01:39:50", 16 | 'GDSS_enz': 'DFScodeRNN_enz_2023-05-14 02:11:21', 17 | 'planar': "DFScodeRNN_planar_2023-05-14 01:53:01"} 18 | 19 | digress_dict = {'GDSS_com': '2023-05-13/14-54-28', 20 | 'GDSS_grid': '2023-05-13/15-02-58', 21 | 'GDSS_ego': '2023-05-13/14-53-38', 22 | 'GDSS_enz': '2023-05-13/14-56-45', 23 | 'planar': "2023-05-13/13-49-30"} 24 | 25 | DATA_DIR = "resource" 26 | 27 | def load_generated_graphs(data_name, method): 28 | if method == 'train': 29 | # for train 30 | with open(f'resource/{data_name}/C-M/{data_name}_test_graphs.pkl', 'rb') as f: 31 | graphs = pickle.load(f) 32 | 33 | elif method == 'digress': 34 | # digress 35 | file_name = digress_dict[data_name] 36 | with open(f'../DiGress/src/outputs/{file_name}/graphs/generated_graphs.pkl', 'rb') as f: 37 | graphs = pickle.load(f) 38 | adjs = [graph[1] for graph in graphs] 39 | graphs = [nx.from_numpy_array(adj.numpy()) for adj in adjs] 40 | 41 | elif method == 'graphgen': 42 | file_name = graphgen_dict[data_name] 43 | with open(f'../graphgen/graphs/{file_name}/generated_graphs.pkl', 'rb') as f: 44 | graphs = pickle.load(f) 45 | 46 | elif method == 'gcg': 47 | file_name = gcg_dict[data_name] 48 | with open(f'samples/graphs/{data_name}/{file_name}.pkl', 'rb') as f: 49 | graphs = pickle.load(f) 50 | 51 | elif method == 'gdss': 52 | data_dict = {'GDSS_com': 'community_small', 'GDSS_enz': 'ENZYMES', 'GDSS_grid': 'grid'} 53 | with open(f'../GDSS/samples/pkl/{data_dict[data_name]}/test/{data_name}_sample.pkl', 'rb') as f: 54 | graphs = pickle.load(f) 55 | 56 | return graphs 57 | 58 | def load_train_graphs(data_name): 59 | if data_name in ['planar', 'sbm']: 60 | adjs, _, _, _, _, _, _, _ = torch.load(f'{DATA_DIR}/{data_name}/{data_name}.pt') 61 | graphs = [adj_to_graph(adj.numpy()) for adj in adjs] 62 | 63 | elif data_name == 'proteins': 64 | adjs = load_proteins_data(DATA_DIR) 65 | graphs = [adj_to_graph(adj.numpy()) for adj in adjs] 66 | else: 67 | with open (f'{DATA_DIR}/{data_name}/{data_name}.pkl', 'rb') as f: 68 | graphs = pickle.load(f) 69 | train_graphs, val_graphs, test_graphs = train_val_test_split(graphs, data_name) 70 | return train_graphs 71 | 72 | data_name = 'planar' 73 | gen_graphs = load_generated_graphs(data_name, 'gcg') 74 | 75 | print(eval_fraction_unique(gen_graphs)) 76 | 77 | train_graphs = load_train_graphs(data_name) 78 | 79 | if data_name == 'GDSS_grid': 80 | val = eval_acc_grid_graph(gen_graphs) 81 | validity_func = is_grid_graph 82 | elif data_name == 'sbm': 83 | acc = eval_acc_sbm_graph(gen_graphs, refinement_steps=1000, strict=False) 84 | validity_func = is_sbm_graph 85 | elif data_name == 'planar': 86 | acc = eval_acc_planar_graph(gen_graphs) 87 | validity_func = is_planar_graph 88 | else: 89 | validity_func = 
lambda x: True 90 | 91 | n = len(gen_graphs) 92 | val = len([is_grid_graph(graph) for graph in gen_graphs])/n 93 | unique, un, vun = eval_fraction_unique_non_isomorphic_valid(gen_graphs, train_graphs, validity_func=validity_func) 94 | novel = eval_fraction_isomorphic(gen_graphs, train_graphs) 95 | 96 | print(f'validity: {round(val, 3)}') 97 | print(f'unique: {round(unique, 3)}') 98 | print(f'novel: {round(1 - novel, 3)}') 99 | if data_name == 'proteins': 100 | print(f'un: {round(un, 3)}') 101 | else: 102 | print(f'vun: {round(vun, 3)}') 103 | -------------------------------------------------------------------------------- /draw_samples.py: -------------------------------------------------------------------------------- 1 | from plot import plot_graphs_list, plot_one_graph, save_fig 2 | import pickle 3 | import networkx as nx 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | from scipy.interpolate import make_interp_spline 7 | import numpy as np 8 | 9 | from data.mol_utils import smiles_to_mols 10 | from rdkit.Chem import Draw 11 | 12 | 13 | data_name = 'GDSS_com' 14 | method = 'digress' 15 | 16 | gcg_dict = {'GDSS_com': 'May02-00:49:13', 'GDSS_grid': 'May02-15:56:32', 'GDSS_ego': "May14-07:56:26", 17 | 'GDSS_enz': 'May02-06:39:38' , 'planar': "May14-11:34:11"} 18 | 19 | graphgen_dict = {'GDSS_com': 'DFScodeRNN_com_small_2023-05-14 01:40:56', 20 | 'GDSS_grid': 'DFScodeRNN_grid_2023-05-14 01:43:04', 21 | 'GDSS_ego': "DFScodeRNN_ego_small_2023-05-14 01:39:50", 22 | 'GDSS_enz': 'DFScodeRNN_enz_2023-05-14 02:11:21', 23 | 'planar': "DFScodeRNN_planar_2023-05-14 01:53:01"} 24 | 25 | digress_dict = {'GDSS_com': '2023-05-13/14-54-28', 26 | 'GDSS_grid': '2023-05-13/15-02-58', 27 | 'GDSS_ego': '2023-05-13/14-53-38', 28 | 'GDSS_enz': '2023-05-13/14-56-45', 29 | 'planar': "2023-05-13/13-49-30"} 30 | 31 | def draw_generated_graphs(data_name, method, index=0): 32 | if method == 'train': 33 | # for train 34 | with open(f'resource/{data_name}/C-M/{data_name}_test_graphs.pkl', 'rb') as f: 35 | graphs = pickle.load(f) 36 | 37 | elif method == 'digress': 38 | # digress 39 | file_name = digress_dict[data_name] 40 | with open(f'../DiGress/src/outputs/{file_name}/graphs/generated_graphs.pkl', 'rb') as f: 41 | graphs = pickle.load(f) 42 | adjs = [graph[1] for graph in graphs] 43 | graphs = [nx.from_numpy_array(adj.numpy()) for adj in adjs] 44 | 45 | elif method == 'graphgen': 46 | file_name = graphgen_dict[data_name] 47 | with open(f'../graphgen/graphs/{file_name}/generated_graphs.pkl', 'rb') as f: 48 | graphs = pickle.load(f) 49 | 50 | elif method == 'gcg': 51 | file_name = gcg_dict[data_name] 52 | with open(f'samples/graphs/{data_name}/{file_name}.pkl', 'rb') as f: 53 | graphs = pickle.load(f) 54 | 55 | elif method == 'gdss': 56 | data_dict = {'GDSS_com': 'community_small', 'GDSS_enz': 'ENZYMES', 'GDSS_grid': 'grid'} 57 | with open(f'../GDSS/samples/pkl/{data_dict[data_name]}/test/{data_name}_sample.pkl', 'rb') as f: 58 | graphs = pickle.load(f) 59 | print(len(graphs)) 60 | 61 | 62 | # plot_graphs_list(graphs[:4], title=f'{data_name}-{method}', save_dir=f'figure/{data_name}', max_num=9) 63 | # plot_one_graph(graphs[index], title=f'{method}-one', save_dir=f'figure/{data_name}') 64 | 65 | def draw_generated_molecules(data_name): 66 | if data_name == 'qm9': 67 | with open("samples/smiles/qm9/May09-07:00:25.txt", 'r') as f: 68 | smiles = f.readlines() 69 | elif data_name == 'zinc': 70 | with open("samples/smiles/zinc/May06-13:30:46.txt", 'r') as f: 71 | smiles = f.readlines() 72 | 73 | mols = 
smiles_to_mols(smiles[:24]) 74 | img = Draw.MolsToGridImage(mols, molsPerRow=8) 75 | img.save(f"samples/fig/figure/{data_name}/{data_name}.png") 76 | 77 | 78 | def draw_loss_plot(): 79 | df = pd.read_csv('resource/planar_ab_pe.csv') 80 | fig, ax = plt.subplots() 81 | x = np.arange(0,500,1/5) 82 | x_y_spline_tpe = make_interp_spline(x, df['tpe'].dropna()) 83 | x_tpe = np.linspace(x.min(), x.max(), 100) 84 | y_tpe = x_y_spline_tpe(x_tpe) 85 | 86 | x_y_spline_ape = make_interp_spline(x, df['ape'].dropna()) 87 | x_ape = np.linspace(x.min(), x.max(), 100) 88 | y_ape = x_y_spline_ape(x_ape) 89 | 90 | x_y_spline_rpe = make_interp_spline(x, df['rpe'].dropna()) 91 | x_rpe = np.linspace(x.min(), x.max(), 100) 92 | y_rpe = x_y_spline_rpe(x_rpe) 93 | 94 | ax.plot(x_tpe, y_tpe, label='TPE', color='#F8C159', linewidth=3) 95 | ax.plot(x_ape, y_ape, label='APE', color='#4384C2', linewidth=3) 96 | ax.plot(x_rpe, y_rpe, label='RPE', color='#EF4C56', linewidth=3) 97 | ax.legend(fontsize=13) 98 | ax.set_xlabel('Epochs', fontsize=13) 99 | ax.set_ylabel('Training loss', fontsize=13) 100 | ax.grid(linestyle='dotted') 101 | for pos in ['left', 'right', 'top', 'bottom']: 102 | ax.spines[pos].set_linewidth(1.5) 103 | 104 | for data in ['GDSS_com', 'GDSS_grid']: 105 | # for data in ['GDSS_grid']: 106 | print(data) 107 | for method in ['train', 'graphgen', 'gdss']: 108 | # for method in ['gdss']: 109 | print(method) 110 | draw_generated_graphs(data, method, 0) 111 | 112 | # draw_generated_molecules('zinc') -------------------------------------------------------------------------------- /data/load_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import networkx as nx 3 | import torch 4 | from tqdm import tqdm 5 | import pandas as pd 6 | 7 | from data.orderings import ORDER_FUNCS, order_graphs 8 | from data.data_utils import train_val_test_split, adj_to_k2_tree, map_new_ordered_graph, adj_to_graph, tree_to_bfs_string 9 | from data.mol_utils import canonicalize_smiles, smiles_to_mols, add_self_loop, tree_to_bfs_string_mol, mols_to_nx 10 | 11 | 12 | DATA_DIR = "resource" 13 | 14 | def generate_string(dataset_name, order='C-M', k=2): 15 | ''' 16 | Generate strings for each dataset / split (without degree (only 0-1)) 17 | ''' 18 | # load molecule graphs 19 | if dataset_name == 'planar': 20 | adjs, _, _, _, _, _, _, _ = torch.load(f'{DATA_DIR}/{dataset_name}/{dataset_name}.pt') 21 | graphs = [adj_to_graph(adj.numpy()) for adj in adjs] 22 | else: 23 | with open (f'{DATA_DIR}/{dataset_name}/{dataset_name}.pkl', 'rb') as f: 24 | graphs = pickle.load(f) 25 | train_graphs, val_graphs, test_graphs = train_val_test_split(graphs, dataset_name) 26 | graph_list = [] 27 | for graphs in train_graphs, val_graphs, test_graphs: 28 | num_rep = 1 29 | # order graphs 30 | order_func = ORDER_FUNCS[order] 31 | total_ordered_graphs = order_graphs(graphs, num_repetitions=num_rep, order_func=order_func, seed=0, is_mol=True) 32 | new_ordered_graphs = [map_new_ordered_graph(graph) for graph in tqdm(total_ordered_graphs, 'Map new ordered graphs')] 33 | graph_list.append(new_ordered_graphs) 34 | 35 | # write graphs 36 | splits = ['train', 'val', 'test'] 37 | 38 | for graphs, split in zip(graph_list, splits): 39 | adjs = [nx.adjacency_matrix(graph, range(len(graph))) for graph in graphs] 40 | trees = [adj_to_k2_tree(torch.Tensor(adj.todense()), return_tree=True, k=k, is_mol=False) for adj in tqdm(adjs, 'Generating tree from adj')] 41 | strings = [tree_to_bfs_string(tree, 
string_type='group') for tree in tqdm(trees, 'Generating strings from tree')] 42 | file_name = f'{dataset_name}_str_{split}' 43 | with open(f'{DATA_DIR}/{dataset_name}/{order}/{file_name}_{k}.txt', 'w') as f: 44 | for string in strings: 45 | f.write(f'{string}\n') 46 | if split == 'test': 47 | with open(f'{DATA_DIR}/{dataset_name}/{order}/{dataset_name}_test_graphs.pkl', 'wb') as f: 48 | pickle.dump(graphs, f) 49 | return graph_list 50 | 51 | def generate_mol_string(dataset_name, order='C-M', is_small=False): 52 | ''' 53 | Generate strings for each dataset / split (without degree (only 0-1)) 54 | ''' 55 | # load molecule graphs 56 | col_dict = {'qm9': 'SMILES1', 'zinc': 'smiles'} 57 | df = pd.read_csv(f'{DATA_DIR}/{dataset_name}/{dataset_name}.csv') 58 | smiles = list(df[col_dict[dataset_name]]) 59 | if is_small: 60 | smiles = smiles[:100] 61 | smiles = [s for s in smiles if len(s)>1] 62 | smiles = canonicalize_smiles(smiles) 63 | splits = ['train', 'val', 'test'] 64 | train_smiles, val_smiles, test_smiles = train_val_test_split(smiles, dataset_name) 65 | for s, split in zip([train_smiles, val_smiles, test_smiles], splits): 66 | if is_small: 67 | with open(f'{DATA_DIR}/{dataset_name}/{order}/{dataset_name}_small_smiles_{split}.txt', 'w') as f: 68 | for string in s: 69 | f.write(f'{string}\n') 70 | else: 71 | with open(f'{DATA_DIR}/{dataset_name}/{order}/{dataset_name}_smiles_{split}.txt', 'w') as f: 72 | for string in s: 73 | f.write(f'{string}\n') 74 | graph_list = [] 75 | for smiles in train_smiles, val_smiles, test_smiles: 76 | mols = smiles_to_mols(smiles) 77 | graphs = mols_to_nx(mols) 78 | graphs = [add_self_loop(graph) for graph in tqdm(graphs, 'Adding self-loops')] 79 | num_rep = 1 80 | # order graphs 81 | order_func = ORDER_FUNCS[order] 82 | total_graphs = graphs 83 | total_ordered_graphs = order_graphs(total_graphs, num_repetitions=num_rep, order_func=order_func, seed=0, is_mol=True) 84 | new_ordered_graphs = [map_new_ordered_graph(graph) for graph in tqdm(total_ordered_graphs, 'Map new ordered graphs')] 85 | graph_list.append(new_ordered_graphs) 86 | 87 | # write graphs 88 | 89 | for graphs, split in zip(graph_list, splits): 90 | weighted_adjs = [nx.attr_matrix(graph, edge_attr='label', rc_order=range(len(graph))) for graph in graphs] 91 | trees = [adj_to_k2_tree(torch.Tensor(adj), return_tree=True, is_mol=True) for adj in tqdm(weighted_adjs, 'Generating tree from adj')] 92 | strings = [tree_to_bfs_string_mol(tree) for tree in tqdm(trees, 'Generating strings from tree')] 93 | if is_small: 94 | file_name = f'{dataset_name}_small_str_{split}' 95 | else: 96 | file_name = f'{dataset_name}_str_{split}' 97 | with open(f'{DATA_DIR}/{dataset_name}/{order}/{file_name}.txt', 'w') as f: 98 | for string in strings: 99 | f.write(f'{string}\n') 100 | if split == 'test': 101 | if is_small: 102 | with open(f'{DATA_DIR}/{dataset_name}/{order}/{dataset_name}_small_test_graphs.pkl', 'wb') as f: 103 | pickle.dump(graphs, f) 104 | else: 105 | with open(f'{DATA_DIR}/{dataset_name}/{order}/{dataset_name}_test_graphs.pkl', 'wb') as f: 106 | pickle.dump(graphs, f) -------------------------------------------------------------------------------- /data/orderings.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import random 3 | from collections import deque 4 | from operator import itemgetter 5 | from dataclasses import dataclass 6 | from itertools import chain 7 | import numpy as np 8 | import networkx as nx 9 | import torch 10 | from 
torch_geometric.data import Data 11 | from torch_geometric.utils import from_scipy_sparse_matrix 12 | 13 | # Code adpated from https://github.com/Genentech/bandwidth-graph-generation 14 | def bw_from_adj(A: np.ndarray) -> int: 15 | """calculate bandwidth from adjacency matrix""" 16 | band_sizes = np.arange(A.shape[0]) - A.argmax(axis=1) 17 | return band_sizes.max() 18 | 19 | 20 | def random_BFS_order(G: nx.Graph, seed=0) -> tuple: 21 | """ 22 | :param G: Graph 23 | :return: random BFS order, maximum queue length (equal to bandwidth of ordering) 24 | """ 25 | connected_components = list(nx.connected_components(G)) 26 | if len(connected_components) > 1: 27 | graphs = [G.subgraph(cc) for cc in connected_components] 28 | else: 29 | graphs = [G] 30 | 31 | order_list = [] 32 | for graph in graphs: 33 | start = random.choice(list(graph)) 34 | edges = nx.bfs_edges(graph, start) 35 | nodes = [start] + [v for u, v in edges] 36 | order_list.append(nodes) 37 | 38 | if len(order_list) == 1: 39 | return order_list[0], 0 40 | else: 41 | return list(chain(*order_list)), 0 42 | 43 | 44 | def bw_from_order(G: nx.Graph, order: list) -> int: 45 | return bw_from_adj(nx.to_numpy_array(G, nodelist=order)) 46 | 47 | 48 | def random_DFS_order(G: nx.Graph, seed=0): 49 | """ 50 | :param G: Graph 51 | :return: random DFS order, maximum queue length (equal to bandwidth of ordering) 52 | """ 53 | connected_components = list(nx.connected_components(G)) 54 | if len(connected_components) > 1: 55 | graphs = [G.subgraph(cc) for cc in connected_components] 56 | else: 57 | graphs = [G] 58 | 59 | order_list = [] 60 | for graph in graphs: 61 | start = random.choice(list(graph)) 62 | edges = nx.dfs_edges(graph, start) 63 | nodes = [start] + [v for u, v in edges] 64 | order_list.append(nodes) 65 | 66 | if len(order_list) == 1: 67 | return order_list[0], 0 68 | else: 69 | return list(chain(*order_list)), 0 70 | 71 | 72 | def uniform_random_order(G: nx.Graph) -> tuple: 73 | order = list(G.nodes()) 74 | random.shuffle(order) 75 | bw = bw_from_order(G, order) 76 | return order, bw 77 | 78 | def random_connected_cuthill_mckee_ordering(G: nx.Graph, seed=0, heuristic=None) -> tuple: 79 | """ 80 | adapted from NX source. 
81 | :return: node order, bandwidth 82 | """ 83 | # the cuthill mckee algorithm for connected graphs 84 | random.seed(seed) 85 | connected_components = list(nx.connected_components(G)) 86 | if len(connected_components) > 1: 87 | graphs = [G.subgraph(cc) for cc in connected_components] 88 | else: 89 | graphs = [G] 90 | 91 | order_list = [] 92 | for graph in graphs: 93 | if heuristic is None: 94 | start = pseudo_peripheral_node(graph, seed) 95 | else: 96 | start = heuristic(graph) 97 | visited = {start} 98 | queue = deque([start]) 99 | max_q_len = 1 100 | i = 0 101 | order = [] 102 | while queue: 103 | parent = queue.popleft() 104 | order.append(parent) 105 | random.seed(seed+i) 106 | key = random.random() 107 | nd = sorted(list(G.degree(set(G[parent]) - visited)), key=lambda x: (x[1], key)) 108 | children = [n for n, d in nd] 109 | visited.update(children) 110 | queue.extend(children) 111 | max_q_len = max(len(queue), max_q_len) 112 | i+=1 113 | order_list.append(order) 114 | 115 | if len(order_list) == 1: 116 | return order_list[0], 0 117 | else: 118 | return list(chain(*order_list)), 0 119 | 120 | 121 | def pseudo_peripheral_node(G: nx.Graph, seed=0) -> int: 122 | """adapted from NX source""" 123 | # helper for cuthill-mckee to find a node in a "pseudo peripheral pair" 124 | # to use as good starting node 125 | random.seed(seed) 126 | u = random.choice(list(G)) 127 | lp = 0 128 | v = u 129 | while True: 130 | spl = dict(nx.shortest_path_length(G, v)) 131 | l = max(spl.values()) 132 | if l <= lp: 133 | break 134 | lp = l 135 | farthest = (n for n, dist in spl.items() if dist == l) 136 | v, deg = min(G.degree(farthest), key=itemgetter(1)) 137 | return v 138 | 139 | 140 | @dataclass 141 | class OrderedGraph: 142 | graph: nx.Graph 143 | seed: int 144 | ordering: list 145 | bw: int 146 | 147 | def to_data(self) -> Data: 148 | A = nx.to_scipy_sparse_matrix(self.graph, nodelist=self.ordering) 149 | edge_index = from_scipy_sparse_matrix(A)[0] 150 | return Data(edge_index=edge_index) 151 | 152 | def to_adjacency(self) -> torch.Tensor: 153 | return torch.tensor( 154 | nx.to_numpy_array(self.graph, nodelist=self.ordering), 155 | dtype=torch.float32, 156 | ) 157 | 158 | 159 | def order_graphs( 160 | graphs: list, 161 | order_func, 162 | num_repetitions: int = 1, seed: int = 0, is_mol=False 163 | ): 164 | ordered_graphs = [] 165 | for i, graph in enumerate(graphs): 166 | for j in range(num_repetitions): 167 | graph = graph.copy() 168 | # seed = i * (j + 1) + j 169 | random.seed(seed) 170 | np.random.seed(seed) 171 | if not is_mol: 172 | graph.remove_edges_from(nx.selfloop_edges(graph)) 173 | graph = nx.convert_node_labels_to_integers(graph) 174 | order, bw = order_func(graph, seed) 175 | ordered_graphs.append(OrderedGraph( 176 | graph=graph, seed=seed, 177 | ordering=order, bw=bw, 178 | )) 179 | return ordered_graphs 180 | 181 | 182 | ORDER_FUNCS = { 183 | "C-M": random_connected_cuthill_mckee_ordering, 184 | "BFS": random_BFS_order, 185 | "DFS": random_DFS_order, 186 | } 187 | -------------------------------------------------------------------------------- /trainer/train_trans_generator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from torch.utils.data import DataLoader 3 | import pytorch_lightning as pl 4 | import wandb 5 | import os 6 | from pytorch_lightning.loggers import WandbLogger 7 | from torch.nn.utils.rnn import pad_sequence 8 | from pytorch_lightning.callbacks import ModelCheckpoint, Timer 9 | 10 | from evaluation.evaluation 
import compute_sequence_cross_entropy 11 | from model.trans_generator import TransGenerator 12 | from train_generator import BaseGeneratorLightningModule 13 | 14 | from signal import signal, SIGPIPE, SIG_DFL 15 | signal(SIGPIPE,SIG_DFL) 16 | 17 | 18 | class TransGeneratorLightningModule(BaseGeneratorLightningModule): 19 | def __init__(self, hparams): 20 | super().__init__(hparams) 21 | 22 | def setup_model(self, hparams): 23 | self.model = TransGenerator( 24 | num_layers=hparams.num_layers, 25 | emb_size=hparams.emb_size, 26 | nhead=hparams.nhead, 27 | dim_feedforward=hparams.dim_feedforward, 28 | input_dropout=hparams.input_dropout, 29 | dropout=hparams.dropout, 30 | max_len=hparams.max_len, 31 | string_type=hparams.string_type, 32 | tree_pos=hparams.tree_pos, 33 | pos_type=hparams.pos_type, 34 | learn_pos=hparams.learn_pos, 35 | abs_pos=hparams.abs_pos, 36 | k=hparams.k 37 | ) 38 | 39 | ### 40 | def train_dataloader(self): 41 | return DataLoader( 42 | self.train_dataset, 43 | batch_size=self.hparams.batch_size, 44 | shuffle=True, 45 | collate_fn=lambda sequences: pad_sequence(sequences, batch_first=True, padding_value=0), 46 | num_workers=self.hparams.num_workers, 47 | ) 48 | 49 | def test_dataloader(self): 50 | return DataLoader( 51 | self.test_dataset, 52 | batch_size=self.hparams.batch_size, 53 | shuffle=False, 54 | collate_fn=lambda sequences: pad_sequence(sequences, batch_first=True, padding_value=0), 55 | num_workers=self.hparams.num_workers, 56 | ) 57 | 58 | def val_dataloader(self): 59 | return DataLoader( 60 | self.val_dataset, 61 | batch_size=self.hparams.batch_size, 62 | shuffle=False, 63 | collate_fn=lambda sequences: pad_sequence(sequences, batch_first=True, padding_value=0), 64 | num_workers=self.hparams.num_workers, 65 | ) 66 | 67 | ### Main steps 68 | def shared_step(self, batched_data): 69 | loss, statistics = 0.0, dict() 70 | logits = self.model(batched_data) 71 | loss = compute_sequence_cross_entropy(logits, batched_data, self.hparams.string_type) 72 | statistics["loss/total"] = loss 73 | # statistics["acc/total"] = compute_sequence_accuracy(logits, batched_data, ignore_index=0)[0] 74 | return loss, statistics 75 | 76 | 77 | @staticmethod 78 | def add_args(parser): 79 | 80 | parser.add_argument("--dataset_name", type=str, default="qm9") 81 | parser.add_argument("--batch_size", type=int, default=32) 82 | parser.add_argument("--num_workers", type=int, default=6) 83 | parser.add_argument("--ckpt_path", type=str, default='no') 84 | 85 | 86 | parser.add_argument("--order", type=str, default="C-M") 87 | parser.add_argument("--replicate", type=int, default=0) 88 | # 89 | parser.add_argument("--emb_size", type=int, default=512) 90 | parser.add_argument("--dropout", type=float, default=0.1) 91 | parser.add_argument("--lr", type=float, default=0.002) 92 | 93 | parser.add_argument("--check_sample_every_n_epoch", type=int, default=2) 94 | parser.add_argument("--num_samples", type=int, default=50) 95 | parser.add_argument("--sample_batch_size", type=int, default=50) 96 | parser.add_argument("--max_epochs", type=int, default=20) 97 | parser.add_argument("--wandb_on", type=str, default='disabled') 98 | 99 | parser.add_argument("--group", type=str, default='string') 100 | parser.add_argument("--model", type=str, default='trans') 101 | parser.add_argument("--max_len", type=int, default=87) 102 | parser.add_argument("--string_type", type=str, default='zinc-red-high') 103 | parser.add_argument("--max_depth", type=int, default=20) 104 | 105 | # transformer 106 | 
parser.add_argument("--num_layers", type=int, default=3) 107 | parser.add_argument("--nhead", type=int, default=8) 108 | parser.add_argument("--dim_feedforward", type=int, default=512) 109 | parser.add_argument("--input_dropout", type=float, default=0.0) 110 | parser.add_argument("--tree_pos", action="store_true") 111 | parser.add_argument("--pos_type", type=str, default='emb') 112 | parser.add_argument("--gradient_clip_val", type=float, default=1.0) 113 | parser.add_argument("--learn_pos", action="store_true") 114 | parser.add_argument("--abs_pos", action="store_true") 115 | 116 | parser.add_argument("--k", type=int, default=2) 117 | 118 | return parser 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser() 123 | TransGeneratorLightningModule.add_args(parser) 124 | hparams = parser.parse_args() 125 | 126 | pos_type_dict = {'emb': 'tpe', 'group-emb': 'gtpe'} 127 | 128 | wandb_logger = WandbLogger(name=f'{hparams.dataset_name}-{hparams.model}-{hparams.string_type}-{pos_type_dict[hparams.pos_type]}-{hparams.k}', 129 | project='k2g', group=f'{hparams.group}', mode=f'{hparams.wandb_on}') 130 | 131 | wandb.config.update(hparams) 132 | 133 | 134 | model = TransGeneratorLightningModule(hparams) 135 | 136 | checkpoint_callback = ModelCheckpoint( 137 | dirpath=os.path.join("resource/checkpoint/", wandb.run.id), 138 | ) 139 | 140 | wandb.watch(model) 141 | timer = Timer(duration="00:12:00:00") 142 | trainer = pl.Trainer( 143 | devices=1, 144 | accelerator='gpu', 145 | default_root_dir="/resource/log/", 146 | max_epochs=hparams.max_epochs, 147 | gradient_clip_val=hparams.gradient_clip_val, 148 | callbacks=[checkpoint_callback, timer], 149 | logger=wandb_logger 150 | ) 151 | 152 | trainer.fit(model) 153 | print(round(timer.time_elapsed("train"),3)) -------------------------------------------------------------------------------- /data/tokens.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import numpy as np 3 | import re 4 | import random 5 | 6 | 7 | PAD_TOKEN = "[pad]" 8 | BOS_TOKEN = "[bos]" 9 | EOS_TOKEN = "[eos]" 10 | UNK_TOKEN = "" 11 | 12 | standard_tokens = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN] 13 | 14 | TOKENS_GROUP = standard_tokens.copy() 15 | group_num_tokens = list(product([0,1], repeat=4)) 16 | TOKENS_GROUP.extend([''.join(str(token)).replace(', ', '')[1:-1] for token in group_num_tokens if token!=(0,0,0,0)]) 17 | 18 | TOKENS_GROUP_RED = TOKENS_GROUP.copy() 19 | group_num_tokens = list(product([0,1], repeat=3)) 20 | TOKENS_GROUP_RED.extend([''.join(str(token)).replace(', ', '')[1:-1] for token in group_num_tokens if token!=(0,0,0)]) 21 | 22 | TOKENS_GROUP_RED_DICT = {} 23 | TOKENS_GROUP_RED_DICT[2] = TOKENS_GROUP_RED 24 | TOKENS_GROUP_THREE = standard_tokens.copy() 25 | group_num_tokens = list(product([0,1], repeat=9)) 26 | TOKENS_GROUP_THREE.extend([''.join(str(token)).replace(', ', '')[1:-1] for token in group_num_tokens if token!=tuple(np.zeros(9, dtype=int))]) 27 | group_num_tokens = list(product([0,1], repeat=6)) 28 | TOKENS_GROUP_THREE.extend([''.join(str(token)).replace(', ', '')[1:-1] for token in group_num_tokens if token!=tuple(np.zeros(6, dtype=int))]) 29 | TOKENS_GROUP_RED_DICT[3] = TOKENS_GROUP_THREE 30 | 31 | 32 | def grouper_mol(string, k=2): 33 | non_group_tokens = [] 34 | k_square = k**2 35 | string_iter = iter(string) 36 | peek = None 37 | while True: 38 | char = peek if peek else next(string_iter, "") 39 | peek = None 40 | if not char: 41 | break 42 | if char in ['C', 
'B']: 43 | peek = next(string_iter, "") 44 | if char + peek in ['Cl', 'Br']: 45 | token = char + peek 46 | peek = None 47 | else: 48 | token = char 49 | else: 50 | token = char 51 | non_group_tokens.append(token) 52 | string_cut = [non_group_tokens[i:i+k_square] for i in range(0,len(non_group_tokens),k_square)] 53 | cut_list = [*string_cut] 54 | return cut_list 55 | 56 | TOKENS_MOL = TOKENS_GROUP.copy() 57 | # 5: single / 6: double / 7:triple / 8: aromatic 58 | bond_tokens = [0,5,6,7,8] 59 | # mol_bond_tokens: only edge type + 0s (without diagonal) 60 | mol_bond_tokens = list(product(bond_tokens, repeat=4)) 61 | TOKENS_MOL.extend([''.join(str(token)).replace(', ', '')[1:-1] for token in mol_bond_tokens]) 62 | 63 | node_tokens_dict = {'qm9': ['F', 'O', 'N', 'C'], 'zinc': ['F', 'O', 'N', 'C', 'P', 'I', 'Cl', 'Br', 'S']} 64 | additional_tokens_dict = dict() 65 | 66 | for data in ['qm9', 'zinc']: 67 | node_tokens = node_tokens_dict[data] 68 | additional_tokens = list(product(node_tokens, bond_tokens, bond_tokens, node_tokens)) 69 | additional_tokens.extend(list(product(node_tokens, [0], [0], [0]))) 70 | additional_tokens.extend(list(product(bond_tokens, node_tokens, [0], [0]))) 71 | additional_tokens.extend(list(product(bond_tokens, [0], node_tokens, [0]))) 72 | additional_tokens.extend(list(product(bond_tokens, [0], node_tokens, bond_tokens))) 73 | # additional_tokens.extend(list(product(bond_tokens, node_tokens, [0], bond_tokens))) 74 | additional_tokens_dict[data] = additional_tokens 75 | 76 | TOKENS_QM9 = TOKENS_MOL.copy() 77 | TOKENS_QM9.extend([''.join(str(token)).replace(', ', '').replace('\'', '')[1:-1] for token in additional_tokens_dict['qm9']]) 78 | TOKENS_ZINC = TOKENS_MOL.copy() 79 | TOKENS_ZINC.extend([''.join(str(token)).replace(', ', '').replace('\'', '')[1:-1] for token in additional_tokens_dict['zinc']]) 80 | 81 | TOKENS_MOL_RED = TOKENS_GROUP_RED.copy() 82 | TOKENS_MOL_RED.extend([''.join(str(token)).replace(', ', '')[1:-1] for token in mol_bond_tokens]) 83 | mol_bond_tokens_red = list(product(bond_tokens, repeat=3)) 84 | TOKENS_MOL_RED.extend([''.join(str(token)).replace(', ', '')[1:-1] for token in mol_bond_tokens_red]) 85 | 86 | TOKENS_QM9_RED = TOKENS_MOL_RED.copy() 87 | qm9_additional_tokens = [''.join(str(token)).replace(', ', '').replace('\'', '')[1:-1] for token in additional_tokens_dict['qm9']] 88 | TOKENS_QM9_RED.extend(qm9_additional_tokens) 89 | TOKENS_QM9_RED.extend([token[:1]+token[2:] for token in qm9_additional_tokens]) 90 | 91 | TOKENS_ZINC_RED = TOKENS_MOL_RED.copy() 92 | zinc_additional_tokens = [''.join(str(token)).replace(', ', '').replace('\'', '')[1:-1] for token in additional_tokens_dict['zinc']] 93 | TOKENS_ZINC_RED.extend(zinc_additional_tokens) 94 | group_ad_tokens = [grouper_mol(token, 2)[0] for token in zinc_additional_tokens] 95 | for group in group_ad_tokens: 96 | del group[1] 97 | 98 | TOKENS_ZINC_RED.extend(list(set([''.join(token) for token in group_ad_tokens]))) 99 | 100 | TOKENS_DICT = {'group-red': TOKENS_GROUP_RED, 101 | 'group-red-3': TOKENS_GROUP_THREE, 'qm9-red': TOKENS_QM9_RED, 'zinc-red': TOKENS_ZINC_RED} 102 | 103 | 104 | def token_list_to_dict(tokens): 105 | return {token: i for i, token in enumerate(tokens)} 106 | 107 | TOKENS_KEY_DICT = {key: token_list_to_dict(value) for key, value in TOKENS_DICT.items()} 108 | 109 | def token_to_id(string_type, k=2): 110 | return TOKENS_KEY_DICT[string_type] 111 | 112 | def id_to_token(tokens): 113 | return {idx: tokens[idx] for idx in range(len(tokens))} 114 | 115 | def tokenize(string, 
string_type, k): 116 | tokens = ["[bos]"] 117 | if 'red' in string_type: 118 | tokens.extend(string) 119 | else: 120 | tokens.extend([*string]) 121 | tokens.append("[eos]") 122 | TOKEN2ID = token_to_id(string_type, k) 123 | return [TOKEN2ID[token] for token in tokens] 124 | 125 | def map_one(token): 126 | mapping_dict = {'2': '1', '3': '1', '4': '1'} 127 | return ''.join([mapping_dict.get(x, x) for x in token]) 128 | 129 | def untokenize(sequence, string_type, k=2): 130 | ID2TOKEN = id_to_token(TOKENS_DICT[string_type]) 131 | tokens = [ID2TOKEN[id_] for id_ in sequence] 132 | org_tokens = tokens 133 | if tokens[0] != "[bos]": 134 | return "", "".join(org_tokens) 135 | elif "[eos]" not in tokens: 136 | return "", "".join(org_tokens) 137 | 138 | tokens = tokens[1 : tokens.index("[eos]")] 139 | if ("[bos]" in tokens) or ("[pad]" in tokens): 140 | return "", "".join(org_tokens) 141 | 142 | return tokens, org_tokens 143 | 144 | def replace_character_random(token): 145 | chars = token.split(' ') 146 | final_token = "" 147 | for i, char in enumerate(chars): 148 | random.seed(token+char+str(i)) 149 | add_int = random.randrange(0,10) 150 | if char.isnumeric(): 151 | final_token += char 152 | else: 153 | final_token += char + str(add_int) 154 | final_token += " " 155 | return final_token[:-1] 156 | -------------------------------------------------------------------------------- /data/mol_utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from rdkit import Chem 3 | from tqdm import tqdm 4 | import torch 5 | import numpy as np 6 | import re 7 | 8 | from data.data_utils import TYPE_NODE_DICT, NODE_TYPE_DICT, BOND_TYPE_DICT 9 | 10 | 11 | DATA_DIR = "resource" 12 | 13 | # codes adapted from https://github.com/harryjo97/GDSS 14 | 15 | def canonicalize_smiles(smiles): 16 | return [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in tqdm(smiles, 'Canonicalizing')] 17 | 18 | def mols_to_smiles(mols): 19 | return [Chem.MolToSmiles(mol) for mol in tqdm(mols, 'molecules to SMILES')] 20 | 21 | def smiles_to_mols(smiles): 22 | return [Chem.MolFromSmiles(s) for s in tqdm(smiles, 'SMILES to molecules')] 23 | 24 | def tree_to_bfs_string_mol(tree): 25 | bfs_node_list = [tree[node] for node in tree.expand_tree(mode=tree.WIDTH, 26 | key=lambda x: (int(x.identifier.split('-')[0]), int(x.identifier.split('-')[1])))][1:] 27 | bfs_value_list = [str(int(node.tag)) for node in bfs_node_list] 28 | 29 | final_value_list = [TYPE_NODE_DICT[token] if token in TYPE_NODE_DICT.keys() else token for token in bfs_value_list] 30 | 31 | return ''.join(final_value_list) 32 | 33 | def add_self_loop(graph): 34 | for node in graph.nodes: 35 | node_label = graph.nodes[node]['label'] 36 | graph.add_edge(node, node, label=node_label) 37 | return graph 38 | 39 | def mols_to_nx(mols): 40 | nx_graphs = [] 41 | for mol in tqdm(mols, 'Molecules to graph'): 42 | if not mol: 43 | continue 44 | G = nx.Graph() 45 | 46 | for atom in mol.GetAtoms(): 47 | G.add_node(atom.GetIdx(), 48 | label=NODE_TYPE_DICT[atom.GetSymbol()]) 49 | for bond in mol.GetBonds(): 50 | G.add_edge(bond.GetBeginAtomIdx(), 51 | bond.GetEndAtomIdx(), 52 | label=BOND_TYPE_DICT[bond.GetBondTypeAsDouble()]) 53 | nx_graphs.append(G) 54 | 55 | return nx_graphs 56 | 57 | def check_adj_validity_mol(adj): 58 | if adj.size == 0: 59 | return None 60 | non_padded_index = max(max(np.argwhere(adj.any(axis=0)))[0], max(np.argwhere(adj.any(axis=1)))[0])+1 61 | x = adj.diagonal()[:non_padded_index] 62 | # check if diagonal 
elements are all node types and all diagonal elements are full / not proper bond type 63 | if len([atom for atom in x if atom in NODE_TYPE_DICT.values()]) == non_padded_index: 64 | # not proper bond type 65 | check_bond_adj = adj.copy() 66 | np.fill_diagonal(check_bond_adj, 0) 67 | bond_type_set = set(check_bond_adj.flatten()) 68 | bond_type_set.remove(0) 69 | if len([bt for bt in bond_type_set if bt not in BOND_TYPE_DICT.values()]) == 0: 70 | return adj 71 | else: 72 | return None 73 | else: 74 | return None 75 | 76 | def adj_to_graph_mol(weighted_adj, is_cuda=False): 77 | if is_cuda: 78 | weighted_adj = weighted_adj.detach().cpu().numpy() 79 | 80 | non_padded_index = max(max(np.argwhere(weighted_adj.any(axis=0)))[0], max(np.argwhere(weighted_adj.any(axis=1)))[0])+1 81 | adj = weighted_adj[:non_padded_index, :non_padded_index] 82 | 83 | x = adj.diagonal().copy() 84 | np.fill_diagonal(adj, 0) 85 | 86 | mol, no_correct = generate_mol(x, adj) 87 | return mol, no_correct 88 | 89 | ATOM_VALENCY = {12: 4, 11: 3, 10: 2, 9: 1, 13: 3, 17: 2, 15: 1, 16: 1, 14: 1} 90 | bond_decoder = {5: Chem.rdchem.BondType.SINGLE, 6: Chem.rdchem.BondType.DOUBLE, 91 | 7: Chem.rdchem.BondType.TRIPLE, 8: Chem.rdchem.BondType.AROMATIC} 92 | NODE_TYPE_TO_ATOM_NUM = {9: 9, 10: 8, 11: 7, 12: 6, 13: 15, 14: 53, 15: 17, 16: 35, 17: 16} 93 | 94 | def generate_mol(x, adj): 95 | mol = construct_mol(x, adj) 96 | cmol, no_correct = correct_mol(mol) 97 | vcmol = valid_mol_can_with_seg(cmol, largest_connected_comp=True) 98 | return vcmol, no_correct 99 | 100 | def construct_mol(x, adj): 101 | mol = Chem.RWMol() 102 | for atom in x: 103 | mol.AddAtom(Chem.Atom(NODE_TYPE_TO_ATOM_NUM[atom])) 104 | 105 | for start, end in zip(*np.nonzero(adj)): 106 | if start > end: 107 | mol.AddBond(int(start), int(end), bond_decoder[adj[start, end]]) 108 | flag, atomid_valence = check_valency(mol) 109 | if flag: 110 | continue 111 | else: 112 | assert len(atomid_valence) == 2 113 | idx = atomid_valence[0] 114 | v = atomid_valence[1] 115 | an = mol.GetAtomWithIdx(idx).GetAtomicNum() 116 | if an in (10, 11, 17) and (v - ATOM_VALENCY[an]) == 1: 117 | mol.GetAtomWithIdx(idx).SetFormalCharge(1) 118 | return mol 119 | 120 | 121 | def check_valency(mol): 122 | try: 123 | Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) 124 | return True, None 125 | except ValueError as e: 126 | e = str(e) 127 | p = e.find('#') 128 | e_sub = e[p:] 129 | atomid_valence = list(map(int, re.findall(r'\d+', e_sub))) 130 | return False, atomid_valence 131 | 132 | # codes adapted from https://github.com/cvignac/DiGress 133 | def correct_mol(m): 134 | # xsm = Chem.MolToSmiles(x, isomericSmiles=True) 135 | mol = m 136 | 137 | ##### 138 | no_correct = False 139 | flag, _ = check_valency(mol) 140 | if flag: 141 | no_correct = True 142 | 143 | while True: 144 | flag, atomid_valence = check_valency(mol) 145 | if flag: 146 | break 147 | else: 148 | assert len(atomid_valence) == 2 149 | idx = atomid_valence[0] 150 | v = atomid_valence[1] 151 | queue = [] 152 | check_idx = 0 153 | for b in mol.GetAtomWithIdx(idx).GetBonds(): 154 | type = int(b.GetBondType()) 155 | queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx())) 156 | if type == 12: 157 | check_idx += 1 158 | queue.sort(key=lambda tup: tup[1], reverse=True) 159 | 160 | if queue[-1][1] == 12: 161 | return None, no_correct 162 | elif len(queue) > 0: 163 | start = queue[check_idx][2] 164 | end = queue[check_idx][3] 165 | t = queue[check_idx][1] - 1 166 | mol.RemoveBond(start, end) 167 | if t 
>= 1: 168 | mol.AddBond(start, end, bond_decoder[t+4]) 169 | return mol, no_correct 170 | 171 | def valid_mol_can_with_seg(m, largest_connected_comp=True): 172 | if m is None: 173 | return None 174 | sm = Chem.MolToSmiles(m, isomericSmiles=True) 175 | if largest_connected_comp and '.' in sm: 176 | vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.') 177 | vsm.sort(key=lambda tup: tup[1], reverse=True) 178 | mol = Chem.MolFromSmiles(vsm[0][0]) 179 | else: 180 | mol = Chem.MolFromSmiles(sm) 181 | return mol 182 | 183 | def fix_symmetry_mol(weighted_adj): 184 | sym_adj = torch.tril(weighted_adj) + torch.tril(weighted_adj).T 185 | sym_adj[range(len(sym_adj)), range(len(sym_adj))] = sym_adj.diagonal()/2 186 | return sym_adj -------------------------------------------------------------------------------- /evaluation/evaluation_spectre.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | # import graph_tool.all as gt 3 | from scipy.stats import chi2 4 | import os 5 | import numpy as np 6 | import torch 7 | import concurrent.futures 8 | import copy 9 | from tqdm import tqdm 10 | 11 | # codes adapted from https://github.com/KarolisMart/SPECTRE 12 | def is_planar_graph(G): 13 | return nx.is_connected(G) and nx.check_planarity(G)[0] 14 | 15 | def is_grid_graph(G): 16 | """ 17 | Check if the graph is grid, by comparing with all the real grids with the same node count 18 | """ 19 | all_grid_file = f"data/all_grids.pt" 20 | if os.path.isfile(all_grid_file): 21 | all_grids = torch.load(all_grid_file) 22 | else: 23 | all_grids = {} 24 | for i in range(2, 20): 25 | for j in range(2, 20): 26 | G_grid = nx.grid_2d_graph(i, j) 27 | n_nodes = f"{len(G_grid.nodes())}" 28 | all_grids[n_nodes] = all_grids.get(n_nodes, []) + [G_grid] 29 | torch.save(all_grids, all_grid_file) 30 | 31 | n_nodes = f"{len(G.nodes())}" 32 | if n_nodes in all_grids: 33 | for G_grid in all_grids[n_nodes]: 34 | if nx.faster_could_be_isomorphic(G, G_grid): 35 | if nx.is_isomorphic(G, G_grid): 36 | return True 37 | return False 38 | else: 39 | return False 40 | 41 | def is_sbm_graph(G, p_intra=0.3, p_inter=0.005, strict=True, refinement_steps=1000): 42 | """ 43 | Check if how closely given graph matches a SBM with given probabilites by computing mean probability of Wald test statistic for each recovered parameter 44 | """ 45 | 46 | adj = nx.adjacency_matrix(G).toarray() 47 | idx = adj.nonzero() 48 | g = gt.Graph() 49 | g.add_edge_list(np.transpose(idx)) 50 | try: 51 | state = gt.minimize_blockmodel_dl(g) 52 | except ValueError: 53 | if strict: 54 | return False 55 | else: 56 | return 0.0 57 | 58 | # Refine using merge-split MCMC 59 | for i in range(refinement_steps): 60 | state.multiflip_mcmc_sweep(beta=np.inf, niter=10) 61 | 62 | b = state.get_blocks() 63 | b = gt.contiguous_map(state.get_blocks()) 64 | state = state.copy(b=b) 65 | e = state.get_matrix() 66 | n_blocks = state.get_nonempty_B() 67 | node_counts = state.get_nr().get_array()[:n_blocks] 68 | edge_counts = e.todense()[:n_blocks, :n_blocks] 69 | if strict: 70 | if (node_counts > 40).sum() > 0 or (node_counts < 20).sum() > 0 or n_blocks > 5 or n_blocks < 2: 71 | return False 72 | 73 | max_intra_edges = node_counts * (node_counts - 1) 74 | est_p_intra = np.diagonal(edge_counts) / (max_intra_edges + 1e-6) 75 | 76 | max_inter_edges = node_counts.reshape((-1, 1)) @ node_counts.reshape((1, -1)) 77 | np.fill_diagonal(edge_counts, 0) 78 | est_p_inter = edge_counts / (max_inter_edges + 1e-6) 79 | 80 
| W_p_intra = (est_p_intra - p_intra)**2 / (est_p_intra * (1-est_p_intra) + 1e-6) 81 | W_p_inter = (est_p_inter - p_inter)**2 / (est_p_inter * (1-est_p_inter) + 1e-6) 82 | 83 | W = W_p_inter.copy() 84 | np.fill_diagonal(W, W_p_intra) 85 | p = 1 - chi2.cdf(abs(W), 1) 86 | p = p.mean() 87 | if strict: 88 | return p > 0.9 # p value < 10 % 89 | else: 90 | return p 91 | 92 | def eval_fraction_unique(fake_graphs, precise=False): 93 | count_non_unique = 0 94 | fake_evaluated = [] 95 | for fake_g in fake_graphs: 96 | unique = True 97 | if not fake_g.number_of_nodes() == 0: 98 | for fake_old in fake_evaluated: 99 | if precise: 100 | if nx.faster_could_be_isomorphic(fake_g, fake_old): 101 | if nx.is_isomorphic(fake_g, fake_old): 102 | count_non_unique += 1 103 | unique = False 104 | break 105 | else: 106 | if nx.faster_could_be_isomorphic(fake_g, fake_old): 107 | if nx.could_be_isomorphic(fake_g, fake_old): 108 | count_non_unique += 1 109 | unique = False 110 | break 111 | if unique: 112 | fake_evaluated.append(fake_g) 113 | 114 | frac_unique = (float(len(fake_graphs)) - count_non_unique) / float(len(fake_graphs)) # Fraction of distinct isomorphism classes in the fake graphs 115 | 116 | return frac_unique 117 | 118 | def eval_fraction_unique_ego(fake_graphs): 119 | gen_hash = [nx.weisfeiler_lehman_graph_hash(graph) for graph in tqdm(fake_graphs)] 120 | return len(set(gen_hash))/len(gen_hash) 121 | 122 | 123 | def eval_fraction_isomorphic(fake_graphs, train_graphs): 124 | count = 0 125 | for fake_g in tqdm(fake_graphs): 126 | for train_g in train_graphs: 127 | if nx.faster_could_be_isomorphic(fake_g, train_g): 128 | if nx.fast_could_be_isomorphic(fake_g, train_g): 129 | if nx.could_be_isomorphic(fake_g, train_g): 130 | if nx.is_isomorphic(fake_g, train_g): 131 | count += 1 132 | break 133 | return count / float(len(fake_graphs)) 134 | 135 | def eval_fraction_isomorphic_ego(fake_graphs, train_graphs): 136 | gen_hash = [nx.weisfeiler_lehman_graph_hash(graph) for graph in tqdm(fake_graphs)] 137 | train_hash = [nx.weisfeiler_lehman_graph_hash(graph) for graph in tqdm(train_graphs)] 138 | novel = 1-sum([h in set(gen_hash).intersection(train_hash) for h in gen_hash])/len(gen_hash) 139 | return novel 140 | 141 | def eval_fraction_unique_non_isomorphic_valid_ego(fake_graphs, train_graphs, validity_func = (lambda x: True)): 142 | count_valid = 0 143 | count_isomorphic = 0 144 | count_non_unique = 0 145 | fake_evaluated = [] 146 | gen_hash = [nx.weisfeiler_lehman_graph_hash(graph) for graph in tqdm(fake_graphs)] 147 | train_hash = [nx.weisfeiler_lehman_graph_hash(graph) for graph in tqdm(train_graphs)] 148 | for fake_g in fake_graphs: 149 | unique = True 150 | 151 | for fake_old in fake_evaluated: 152 | if nx.faster_could_be_isomorphic(fake_g, fake_old): 153 | if nx.is_isomorphic(fake_g, fake_old): 154 | count_non_unique += 1 155 | unique = False 156 | break 157 | if unique: 158 | fake_evaluated.append(fake_g) 159 | non_isomorphic = True 160 | for train_g, train_h in zip(train_graphs, train_hash): 161 | if nx.faster_could_be_isomorphic(fake_g, train_g): 162 | if train_h in set(gen_hash): 163 | count_isomorphic += 1 164 | non_isomorphic = False 165 | break 166 | if non_isomorphic: 167 | if validity_func(fake_g): 168 | count_valid += 1 169 | 170 | frac_unique = (float(len(fake_graphs)) - count_non_unique) / float(len(fake_graphs)) # Fraction of distinct isomorphism classes in the fake graphs 171 | frac_unique_non_isomorphic = (float(len(fake_graphs)) - count_non_unique - count_isomorphic) / 
float(len(fake_graphs)) # Fraction of distinct isomorphism classes in the fake graphs that are not in the training set 172 | frac_unique_non_isomorphic_valid = count_valid / float(len(fake_graphs)) # Fraction of distinct isomorphism classes in the fake graphs that are not in the training set and are valid 173 | return frac_unique, frac_unique_non_isomorphic, frac_unique_non_isomorphic_valid 174 | 175 | def eval_fraction_unique_non_isomorphic_valid(fake_graphs, train_graphs, validity_func = (lambda x: True)): 176 | count_valid = 0 177 | count_isomorphic = 0 178 | count_non_unique = 0 179 | fake_evaluated = [] 180 | for fake_g in fake_graphs: 181 | unique = True 182 | 183 | for fake_old in fake_evaluated: 184 | if nx.faster_could_be_isomorphic(fake_g, fake_old): 185 | if nx.is_isomorphic(fake_g, fake_old): 186 | count_non_unique += 1 187 | unique = False 188 | break 189 | if unique: 190 | fake_evaluated.append(fake_g) 191 | non_isomorphic = True 192 | for train_g in train_graphs: 193 | if nx.faster_could_be_isomorphic(fake_g, train_g): 194 | if nx.is_isomorphic(fake_g, train_g): 195 | count_isomorphic += 1 196 | non_isomorphic = False 197 | break 198 | if non_isomorphic: 199 | if validity_func(fake_g): 200 | count_valid += 1 201 | 202 | frac_unique = (float(len(fake_graphs)) - count_non_unique) / float(len(fake_graphs)) # Fraction of distinct isomorphism classes in the fake graphs 203 | frac_unique_non_isomorphic = (float(len(fake_graphs)) - count_non_unique - count_isomorphic) / float(len(fake_graphs)) # Fraction of distinct isomorphism classes in the fake graphs that are not in the training set 204 | frac_unique_non_isomorphic_valid = count_valid / float(len(fake_graphs)) # Fraction of distinct isomorphism classes in the fake graphs that are not in the training set and are valid 205 | return frac_unique, frac_unique_non_isomorphic, frac_unique_non_isomorphic_valid 206 | 207 | def eval_acc_grid_graph(G_list, grid_start=10, grid_end=20): 208 | count = 0 209 | for gg in G_list: 210 | if is_grid_graph(gg): 211 | count += 1 212 | return count / float(len(G_list)) 213 | 214 | def eval_acc_sbm_graph(G_list, p_intra=0.3, p_inter=0.005, strict=True, refinement_steps=1000, is_parallel=True): 215 | count = 0.0 216 | if is_parallel: 217 | with concurrent.futures.ThreadPoolExecutor() as executor: 218 | for prob in executor.map(is_sbm_graph, 219 | [gg for gg in G_list], [p_intra for i in range(len(G_list))], [p_inter for i in range(len(G_list))], 220 | [strict for i in range(len(G_list))], [refinement_steps for i in range(len(G_list))]): 221 | count += prob 222 | else: 223 | for gg in G_list: 224 | count += is_sbm_graph(gg, p_intra=p_intra, p_inter=p_inter, strict=strict, refinement_steps=refinement_steps) 225 | return count / float(len(G_list)) 226 | 227 | def eval_acc_planar_graph(G_list, grid_start=10, grid_end=20): 228 | count = 0 229 | for gg in G_list: 230 | if is_planar_graph(gg): 231 | count += 1 232 | return count / float(len(G_list)) 233 | 234 | def is_lobster_graph(G): 235 | """ 236 | Check a given graph is a lobster graph or not 237 | 238 | Removing leaf nodes twice: 239 | 240 | lobster -> caterpillar -> path 241 | 242 | """ 243 | ### Check if G is a tree 244 | if nx.is_tree(G): 245 | G = G.copy() 246 | ### Check if G is a path after removing leaves twice 247 | leaves = [n for n, d in G.degree() if d == 1] 248 | G.remove_nodes_from(leaves) 249 | 250 | leaves = [n for n, d in G.degree() if d == 1] 251 | G.remove_nodes_from(leaves) 252 | 253 | num_nodes = len(G.nodes()) 254 | num_degree_one = 
[d for n, d in G.degree() if d == 1] 255 | num_degree_two = [d for n, d in G.degree() if d == 2] 256 | 257 | if sum(num_degree_one) == 2 and sum(num_degree_two) == 2 * (num_nodes - 2): 258 | return True 259 | elif sum(num_degree_one) == 0 and sum(num_degree_two) == 0: 260 | return True 261 | else: 262 | return False 263 | else: 264 | return False 265 | 266 | def eval_acc_lobster_graph(G_list): 267 | G_list = [copy.deepcopy(gg) for gg in G_list] 268 | count = 0 269 | for gg in G_list: 270 | if is_lobster_graph(gg): 271 | count += 1 272 | return count / float(len(G_list)) -------------------------------------------------------------------------------- /trainer/train_generator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.loggers import WandbLogger 7 | import wandb 8 | from time import gmtime, strftime 9 | import time 10 | from moses.metrics.metrics import get_all_metrics 11 | 12 | from data.dataset import ComDataset, EnzDataset, GridDataset, QM9Dataset, ZINCDataset, PlanarDataset 13 | from data.data_utils import tree_to_adj, adj_to_graph, check_tree_validity, generate_final_tree_red, fix_symmetry, generate_initial_tree_red 14 | from data.mol_utils import adj_to_graph_mol, mols_to_smiles, check_adj_validity_mol, mols_to_nx, fix_symmetry_mol, canonicalize_smiles 15 | from evaluation.evaluation import compute_sequence_accuracy, compute_sequence_cross_entropy, save_graph_list, load_eval_settings, eval_graph_list 16 | from plot import plot_graphs_list 17 | from data.tokens import untokenize 18 | from model.trans_generator import TransGenerator 19 | from data.load_data import generate_string 20 | 21 | 22 | DATA_DIR = "resource" 23 | 24 | class BaseGeneratorLightningModule(pl.LightningModule): 25 | def __init__(self, hparams): 26 | super(BaseGeneratorLightningModule, self).__init__() 27 | hparams = argparse.Namespace(**hparams) if isinstance(hparams, dict) else hparams 28 | self.save_hyperparameters(hparams) 29 | self.setup_datasets(hparams) 30 | self.setup_model(hparams) 31 | self.ts = strftime('%b%d-%H:%M:%S', gmtime()) 32 | wandb.config['ts'] = self.ts 33 | 34 | def setup_datasets(self, hparams): 35 | self.string_type = hparams.string_type 36 | self.order = hparams.order 37 | dataset_cls = { 38 | "GDSS_grid": GridDataset, 39 | "GDSS_com": ComDataset, 40 | "GDSS_enz": EnzDataset, 41 | 'qm9': QM9Dataset, 42 | 'zinc': ZINCDataset, 43 | 'planar': PlanarDataset, 44 | }.get(hparams.dataset_name) 45 | self.train_graphs, _ , self.test_graphs = generate_string(hparams.dataset_name, hparams.order, hparams.k) 46 | self.train_dataset, self.val_dataset, self.test_dataset = [dataset_cls(split, self.string_type, self.order) 47 | for split in ['train', 'val', 'test']] 48 | if hparams.dataset_name in ['qm9', 'zinc']: 49 | with open(f'{DATA_DIR}/{hparams.dataset_name}/{hparams.order}/{hparams.dataset_name}' + f'_smiles_train.txt', 'r') as f: 50 | self.train_smiles = f.readlines()[:100] 51 | self.train_smiles = canonicalize_smiles(self.train_smiles) 52 | with open(f'{DATA_DIR}/{hparams.dataset_name}/{hparams.order}/{hparams.dataset_name}' + f'_smiles_test.txt', 'r') as f: 53 | self.test_smiles = f.readlines()[:100] 54 | self.test_smiles = canonicalize_smiles(self.test_smiles) 55 | self.max_depth = hparams.max_depth 56 | 57 | def setup_model(self, hparams): 58 | self.model = TransGenerator( 59 | num_layers=hparams.num_layers, 60 | 
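            # The keyword arguments below are taken directly from the parsed
            # hyperparameters, i.e. the command-line flags passed to the trainer.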
emb_size=hparams.emb_size, 61 | nhead=hparams.nhead, 62 | dim_feedforward=hparams.dim_feedforward, 63 | input_dropout=hparams.input_dropout, 64 | dropout=hparams.dropout, 65 | max_len=hparams.max_len, 66 | string_type=hparams.string_type, 67 | tree_pos=hparams.tree_pos, 68 | pos_type=hparams.pos_type, 69 | learn_pos=hparams.learn_pos, 70 | abs_pos=hparams.abs_pos, 71 | k=hparams.k 72 | ) 73 | 74 | def configure_optimizers(self): 75 | optimizer = torch.optim.AdamW( 76 | self.parameters(), 77 | lr=self.hparams.lr, 78 | ) 79 | 80 | return [optimizer] 81 | 82 | ### Main steps 83 | def shared_step(self, batched_data): 84 | loss, statistics = 0.0, dict() 85 | # decoding 86 | logits = self.model(batched_data) 87 | loss = compute_sequence_cross_entropy(logits, batched_data, ignore_index=0) 88 | statistics["loss/total"] = loss 89 | statistics["acc/total"] = compute_sequence_accuracy(logits, batched_data, ignore_index=0)[0] 90 | 91 | return loss, statistics 92 | 93 | def training_step(self, batched_data, batch_idx): 94 | loss, statistics = self.shared_step(batched_data) 95 | for key, val in statistics.items(): 96 | # self.log(f"train/{key}", val, on_step=True, logger=True) 97 | wandb.log({f"train/{key}": val}) 98 | return loss 99 | 100 | def validation_step(self, batched_data, batch_idx): 101 | loss, statistics = self.shared_step(batched_data) 102 | for key, val in statistics.items(): 103 | wandb.log({f"val/{key}": val}) 104 | self.log(f"val/{key}", val, on_step=False, on_epoch=True, logger=True) 105 | pass 106 | 107 | def validation_epoch_end(self, outputs): 108 | if (self.current_epoch + 1) % self.hparams.check_sample_every_n_epoch == 0: 109 | self.check_samples() 110 | 111 | def check_samples(self): 112 | num_samples = self.hparams.num_samples if not self.trainer.sanity_checking else 2 113 | string_list, org_string_list, generation_time = self.sample(num_samples) 114 | wandb.log({"time": round(generation_time, 3)}) 115 | 116 | 117 | if not self.trainer.sanity_checking: 118 | valid_string_list = [string for string in string_list if len(string)>0] 119 | sampled_trees = [generate_initial_tree_red(string, self.hparams.k) for string in valid_string_list] 120 | valid_sampled_trees = [tree for tree in sampled_trees if check_tree_validity(tree)] 121 | valid_sampled_trees = [generate_final_tree_red(tree, self.hparams.k) for tree in tqdm(valid_sampled_trees, "Sampling: converting string to tree")] 122 | 123 | wandb.log({"validity": len(valid_string_list)/len(string_list)}) 124 | # write down string 125 | 126 | # for molecular dataset 127 | if self.hparams.string_type in ['zinc-red', 'qm9-red']: 128 | # valid_sampled_trees = sampled_trees[:len(self.test_graphs)] 129 | adjs = [fix_symmetry_mol(tree_to_adj(tree)).numpy() for tree in tqdm(valid_sampled_trees, "Sampling: converting tree into adj")] 130 | valid_adjs = [valid_adj for valid_adj in [check_adj_validity_mol(adj) for adj in adjs] if valid_adj is not None] 131 | mols_no_correct = [adj_to_graph_mol(adj) for adj in valid_adjs] 132 | mols_no_correct = [elem for elem in mols_no_correct if elem[0] is not None] 133 | mols = [elem[0] for elem in mols_no_correct] 134 | no_corrects = [elem[1] for elem in mols_no_correct] 135 | num_mols = len(mols) 136 | gen_smiles = mols_to_smiles(mols) 137 | gen_smiles = [smi for smi in gen_smiles if len(smi)] 138 | table = wandb.Table(columns=['SMILES']) 139 | for s in gen_smiles: 140 | table.add_data(s) 141 | wandb.log({'SMILES': table}) 142 | save_dir = f'{self.hparams.dataset_name}/{self.ts}' 143 | scores_nspdk = 
eval_graph_list(self.test_graphs, mols_to_nx(mols), methods=['nspdk'])['nspdk'] 144 | with open(f'samples/smiles/{save_dir}.txt', 'w') as f: 145 | for smiles in gen_smiles: 146 | f.write(f'{smiles}\n') 147 | scores = get_all_metrics(gen=gen_smiles, device=self.device, n_jobs=8, test=self.test_smiles, train=self.train_smiles, k=len(gen_smiles)) 148 | 149 | metrics_dict = scores 150 | metrics_dict['unique'] = scores[f'unique@{len(gen_smiles)}'] 151 | del metrics_dict[f'unique@{len(gen_smiles)}'] 152 | metrics_dict['NSPDK'] = scores_nspdk 153 | metrics_dict['validity_wo_cor'] = sum(no_corrects) / num_mols 154 | wandb.log(metrics_dict) 155 | # for generic graph dataset 156 | else: 157 | table = wandb.Table(columns=['Orginal', 'String', 'Validity']) 158 | if 'red' in self.hparams.string_type: 159 | org_string_list = [''.join(org) for org in org_string_list] 160 | string_list = [''.join(s) for s in string_list] 161 | for org_string, string in zip(org_string_list, string_list): 162 | table.add_data(org_string, string, (len(string)>0 and len(string)%4 == 0)) 163 | wandb.log({'strings': table}) 164 | if len(sampled_trees) > 0: 165 | tree_validity = len(valid_sampled_trees) / len(sampled_trees) 166 | else: 167 | tree_validity = 0 168 | wandb.log({"tree-validity": tree_validity}) 169 | 170 | adjs = [fix_symmetry(tree_to_adj(tree, self.hparams.k)).numpy() for tree in tqdm(valid_sampled_trees[:2*len(self.test_graphs)], "Sampling: converting tree into graph")] 171 | adjs = [adj for adj in adjs if adj is not None] 172 | sampled_graphs = [adj_to_graph(adj) for adj in adjs[:len(self.test_graphs)]] 173 | save_graph_list(self.hparams.dataset_name, self.ts, sampled_graphs, valid_string_list, string_list, org_string_list) 174 | plot_dir = f'{self.hparams.dataset_name}/{self.ts}' 175 | plot_graphs_list(sampled_graphs, save_dir=plot_dir) 176 | wandb.log({"samples": wandb.Image(f'./samples/fig/{plot_dir}/title.png')}) 177 | 178 | # GDSS evaluation 179 | gen_graphs = sampled_graphs[:len(self.test_graphs)] 180 | methods, kernels = load_eval_settings('') 181 | if len(sampled_graphs) == 0: 182 | mmd_results = {'degree': np.nan, 'orbit': np.nan, 'cluster': np.nan, 'spectral': np.nan} 183 | else: 184 | mmd_results = eval_graph_list(self.test_graphs, gen_graphs, methods=methods, kernels=kernels) 185 | wandb.log(mmd_results) 186 | 187 | # SPECTRE evaluation 188 | 189 | 190 | 191 | def sample(self, num_samples): 192 | offset = 0 193 | string_list = [] 194 | org_string_list = [] 195 | while offset < num_samples: 196 | cur_num_samples = min(num_samples - offset, self.hparams.sample_batch_size) 197 | offset += cur_num_samples 198 | 199 | self.model.eval() 200 | with torch.no_grad(): 201 | t0 = time.perf_counter() 202 | sequences = self.model.decode(cur_num_samples, max_len=self.hparams.max_len, device=self.device) 203 | generation_time = time.perf_counter() - t0 204 | print(round(generation_time, 3)) 205 | 206 | strings = [untokenize(sequence, self.hparams.string_type, self.hparams.k)[0] for sequence in sequences.tolist()] 207 | org_strings = [untokenize(sequence, self.hparams.string_type, self.hparams.k)[1] for sequence in sequences.tolist()] 208 | string_list.extend(strings) 209 | org_string_list.extend(org_strings) 210 | org_string_list = [] 211 | 212 | return string_list, org_string_list, generation_time 213 | 214 | @staticmethod 215 | def add_args(parser): 216 | 217 | 218 | return parser 219 | 220 | 221 | if __name__ == "__main__": 222 | 223 | parser = argparse.ArgumentParser() 224 | 
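    # add_args is currently a stub that returns the parser unchanged, so the flags
    # consumed below (dataset_name, group, wandb_on, max_epochs, ...) are presumably
    # registered by the concrete entry point, e.g. trainer/train_trans_generator.py.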
BaseGeneratorLightningModule.add_args(parser) 225 | 226 | 227 | hparams = parser.parse_args() 228 | wandb_logger = WandbLogger(name=f'{hparams.dataset_name}', project='k2g', 229 | group=f'{hparams.group}', mode=f'{hparams.wandb_on}') 230 | wandb.config.update(hparams) 231 | 232 | model = BaseGeneratorLightningModule(hparams) 233 | wandb.watch(model) 234 | 235 | trainer = pl.Trainer( 236 | gpus=1, 237 | default_root_dir="../resource/log/", 238 | max_epochs=hparams.max_epochs, 239 | logger=wandb_logger 240 | ) 241 | trainer.fit(model) -------------------------------------------------------------------------------- /model/trans_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Categorical 4 | import math 5 | from tqdm import tqdm 6 | import numpy as np 7 | from collections import deque 8 | from torch.nn.functional import pad 9 | import math 10 | import re 11 | import time 12 | 13 | from data.tokens import PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, TOKENS_DICT, token_to_id, id_to_token, TOKENS_GROUP_THREE 14 | 15 | 16 | # helper Module to convert tensor of input indices into corresponding tensor of token embeddings 17 | class TokenEmbedding(nn.Module): 18 | def __init__(self, vocab_size, emb_size, learn_pos, max_len): 19 | super(TokenEmbedding, self).__init__() 20 | self.embedding = nn.Embedding(vocab_size, emb_size) 21 | self.emb_size = emb_size 22 | self.learn_pos = learn_pos 23 | # max_len+2: for eos / bos token 24 | self.positional_embedding = nn.Parameter(torch.randn([1, max_len+2, emb_size])) 25 | 26 | def forward(self, tokens): 27 | x = self.embedding(tokens.long()) * math.sqrt(self.emb_size) 28 | x_batch_size = x.shape[0] 29 | x_seq_len = x.shape[1] 30 | if self.learn_pos: 31 | pe = self.positional_embedding[:,:x_seq_len] 32 | pe_stack = torch.tile(pe, (x_batch_size, 1, 1)) 33 | return x+pe_stack 34 | return x 35 | 36 | class AbsolutePositionalEncoding(nn.Module): 37 | def __init__(self, d_model, max_len=8000): 38 | super().__init__() 39 | position = torch.arange(max_len).unsqueeze(1) 40 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 41 | pe = torch.zeros(max_len, 1, d_model) 42 | pe[:, 0, 0::2] = torch.sin(position * div_term) 43 | pe[:, 0, 1::2] = torch.cos(position * div_term) 44 | self.register_buffer('pe', pe) 45 | 46 | def forward(self, x): 47 | # x: shape [batch_size, seq_len, emb_size (pad_vocab_size)] 48 | x = x + self.pe[:x.size(1), :].transpose(0,1) 49 | return x 50 | 51 | class TreePositionalEncoding(nn.Module): 52 | def __init__(self, d_model, token2id, pos_type, max_len, k=2): 53 | super().__init__() 54 | self.d_model = d_model 55 | self.pos_type = pos_type 56 | l = range(k**2+1) 57 | position_dict = {str(key): np.zeros(len(l)-1, dtype=int) for key in l} 58 | for key, value in position_dict.items(): 59 | position_dict[key][int(key)-1] = 1 60 | position_dict['0'][-1] = 0 61 | position_dict = {key: tuple(value) for key, value in position_dict.items()} 62 | 63 | self.k = k 64 | self.k_square = k**2 65 | # self.pos_dict = {'0': (0,0,0,0), '1': (1,0,0,0), '2': (0,1,0,0), '3': (0,0,1,0), '4': (0,0,0,1)} 66 | self.pos_dict = position_dict 67 | self.token2id = token2id 68 | self.bos = self.token2id[BOS_TOKEN] 69 | self.eos = self.token2id[EOS_TOKEN] 70 | self.pad = self.token2id[PAD_TOKEN] 71 | if 'group' in self.pos_type: 72 | self.max_pe_length = int(math.log(max_len, self.k_square)+5) 73 | else: 74 | self.max_pe_length = 
int(math.log(max_len, self.k_square)+5)*(k**2) 75 | self.positional_embedding = nn.Linear(self.max_pe_length, self.d_model) 76 | self.padder = torch.nn.ReplicationPad2d((0,self.d_model-self.max_pe_length,0,0)) 77 | is_group_dict = {'emb': False, 'group-emb': True, 'pad': False, 'group-pad': True} 78 | self.is_group = is_group_dict[self.pos_type] 79 | 80 | def filter_row_string(self, row_string): 81 | # filter bos to eos 82 | l = row_string.tolist() 83 | try: 84 | return l[1:l.index(self.eos)] 85 | except ValueError: 86 | return l[1:] 87 | 88 | def get_pe(self, row_string): 89 | l = self.filter_row_string(row_string) 90 | if len(l) == 0: 91 | return torch.zeros((1,1)) 92 | elif len(l) > 3: 93 | string = l[0:self.k_square] 94 | pos_list = list(range(1,self.k_square+1)) 95 | else: 96 | string = l[0:len(l)] 97 | pos_list = list(range(1,len(l)+1)) 98 | 99 | str_pos_queue = deque([(s, p) for s, p in zip(string, pos_list)]) 100 | for i in np.arange(self.k_square,len(l),self.k_square): 101 | cur_string = l[i:i+self.k_square] 102 | cur_parent, cur_parent_pos = str_pos_queue.popleft() 103 | # if value is 0, it cannot be parent node -> skip 104 | while((cur_parent == self.token2id['0']) and (len(str_pos_queue) > 0)): 105 | cur_parent, cur_parent_pos = str_pos_queue.popleft() 106 | # i: order of the child node in the same parent 107 | cur_pos = [cur_parent_pos*10+i for i in range(1,1+len(cur_string))] 108 | # pos_list: final position of each node 109 | pos_list.extend(cur_pos) 110 | str_pos_queue.extend([(s, c) for s, c in zip(cur_string, cur_pos)]) 111 | # map position vector to each position 112 | reverse_pos_list = [str(pos)[::-1] for pos in pos_list] 113 | tensor_pos_list = [self.map_pos_to_tensor(pos) for pos in reverse_pos_list] 114 | max_size = len(tensor_pos_list[-1]) 115 | final_pos_list = [pad(pos, (0,max_size-len(pos))) for pos in tensor_pos_list] 116 | 117 | # return shape: sequence len * pe size 118 | return torch.stack(final_pos_list) 119 | 120 | def pe(self, pe_tensor, emb_size): 121 | return self.positional_embedding(pe_tensor) * math.sqrt(emb_size) 122 | 123 | def map_pos_to_tensor(self, pos): 124 | result = [] 125 | for char in pos: 126 | result.extend(self.pos_dict[char]) 127 | 128 | return torch.tensor(result) 129 | 130 | def map_pos_to_tensor_group(self, pos): 131 | return torch.tensor([eval(p) for p in pos]) 132 | 133 | def finalize_pe(self, pos_list, is_group): 134 | if is_group: 135 | int_pos_list = [list(str(pos)) for pos in pos_list] 136 | tensor_pos_list = [self.map_pos_to_tensor_group(pos) for pos in int_pos_list] 137 | else: 138 | tensor_pos_list = [self.map_pos_to_tensor(str(pos)) for pos in pos_list] 139 | max_size = len(tensor_pos_list[-1]) 140 | final_pos_list = [pad(pos, (0,max_size-len(pos))) for pos in tensor_pos_list] 141 | return torch.stack(final_pos_list) 142 | 143 | def forward(self, x, org_x): 144 | # x shape: batch size * sequence len * emb size 145 | if x.size(1) == 1: 146 | return x 147 | # pe shape: sequence len * pe size 148 | pos_lists = [self.get_pe(x_string) for x_string in org_x] 149 | pe_list = [self.finalize_pe(pos_list, is_group=self.is_group) for pos_list in pos_lists] 150 | max_str_length = x.shape[1] 151 | pe_tensor = torch.stack([pad(pe, (0,self.max_pe_length-pe.shape[1],1,max_str_length-pe.shape[0]-1)) 152 | for pe in pe_list]).to(x.device).float() 153 | if 'pad' in self.pos_type: 154 | pe = self.padder(pe_tensor).to(x.device) 155 | elif 'emb' in self.pos_type: 156 | pe = self.pe(pe_tensor, self.d_model) 157 | 158 | x += pe 159 | return x 160 
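# Illustration only (not part of the repository): the tree positional encodings above
# assign every node a base-10 digit string computed as parent_position * 10 + child_index,
# with child_index running over 1..k**2 (see `cur_pos = cur_parent_pos*10+i` in
# TreePositionalEncoding.get_pe). The hypothetical helper below reproduces that
# arithmetic for k = 2.
def _toy_tree_positions(k=2):
    root_children = list(range(1, k**2 + 1))             # positions of the root's children: [1, 2, 3, 4]
    first_grandchildren = [root_children[0] * 10 + i
                           for i in range(1, k**2 + 1)]  # children of node at position 1: [11, 12, 13, 14]
    return root_children, first_grandchildren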
| 161 | 162 | class GroupTreePositionalEncoding(TreePositionalEncoding): 163 | def __init__(self, d_model, token2id, id2token, pos_type, max_len, k): 164 | super().__init__(d_model, token2id, pos_type, max_len, k) 165 | self.id2token = id2token 166 | 167 | def map_string_to_sum(self, raw_string): 168 | # group position (1,1,0,1,1,1 parent -> 11, 12, 14, 15, 16) 169 | string = self.id2token[raw_string] 170 | result = [char.start()+1 for char in re.finditer(r'(?!0).', string)] 171 | result.append(0) 172 | return result 173 | 174 | def get_pe(self, row_string): 175 | l = self.filter_row_string(row_string) 176 | if (len(l) == 0) or (self.pad in l) or (self.bos in l): 177 | return [0] 178 | else: 179 | str_queue = deque(self.map_string_to_sum(l[0])) 180 | pos_list = [1] 181 | 182 | pos_queue = deque(pos_list) 183 | cur_parent_pos = pos_queue.popleft() 184 | 185 | for i in range(1,len(l)): 186 | tree_index = str_queue.popleft() 187 | cur_string = l[i] 188 | if tree_index == 0: 189 | if len(str_queue) == 0: 190 | break 191 | tree_index = str_queue.popleft() 192 | cur_parent_pos = pos_queue.popleft() 193 | 194 | cur_pos = cur_parent_pos*10+tree_index 195 | pos_list.append(cur_pos) 196 | pos_queue.append(cur_pos) 197 | str_queue.extend(self.map_string_to_sum(cur_string)) 198 | 199 | return pos_list 200 | 201 | 202 | class TransGenerator(nn.Module): 203 | ''' 204 | without tree information (only string) 205 | ''' 206 | 207 | def __init__( 208 | self, num_layers, emb_size, nhead, dim_feedforward, 209 | input_dropout, dropout, max_len, string_type, tree_pos, pos_type, learn_pos, abs_pos, k 210 | ): 211 | super(TransGenerator, self).__init__() 212 | self.nhead = nhead 213 | self.tokens = TOKENS_DICT[string_type] 214 | self.ID2TOKEN = id_to_token(self.tokens) 215 | self.string_type = string_type 216 | self.k = k 217 | self.TOKEN2ID = token_to_id(self.string_type, self.k) 218 | self.tree_pos = tree_pos 219 | self.pos_type = pos_type 220 | self.learn_pos = learn_pos 221 | self.abs_pos = abs_pos 222 | self.max_len = max_len 223 | 224 | 225 | if self.abs_pos: 226 | self.positional_encoding = AbsolutePositionalEncoding(emb_size) 227 | 228 | if string_type in ['group', 'bfs-deg-group', 'qm9', 'zinc', 'group-red', 'qm9-red', 'zinc-red', 'group-red-3']: 229 | if self.tree_pos: 230 | self.positional_encoding = GroupTreePositionalEncoding(emb_size, self.TOKEN2ID, self.ID2TOKEN, self.pos_type, self.max_len, self.k) 231 | else: 232 | if self.tree_pos: 233 | self.positional_encoding = TreePositionalEncoding(emb_size, self.TOKEN2ID, self.pos_type, self.max_len, self.k) 234 | # 235 | self.token_embedding_layer = TokenEmbedding(len(self.tokens), emb_size, self.learn_pos, self.max_len) 236 | self.input_dropout = nn.Dropout(input_dropout) 237 | 238 | # 239 | self.distance_embedding_layer = nn.Embedding(max_len + 1, nhead) 240 | 241 | # 242 | encoder_layer = nn.TransformerEncoderLayer(emb_size, nhead, dim_feedforward, dropout, "gelu") 243 | encoder_norm = nn.LayerNorm(emb_size) 244 | self.transformer = nn.TransformerEncoder(encoder_layer, num_layers, encoder_norm) 245 | 246 | # 247 | self.generator = nn.Linear(emb_size, len(self.tokens)) 248 | # self.node_type_generator = nn.Linear(emb_size, len()) 249 | 250 | def forward(self, sequences): 251 | batch_size = sequences.size(0) 252 | sequence_len = sequences.size(1) 253 | TOKEN2ID = token_to_id(self.string_type, self.k) 254 | # 255 | out = self.token_embedding_layer(sequences) 256 | if self.tree_pos: 257 | out = self.positional_encoding(out, sequences) 258 | if 
self.abs_pos: 259 | out = self.positional_encoding(out) 260 | out = self.input_dropout(out) 261 | 262 | 263 | if self.tree_pos or self.abs_pos: 264 | mask = torch.zeros(batch_size, sequence_len, sequence_len, self.nhead, device=out.device) 265 | else: 266 | # relational positional encoding 267 | distance_squares = torch.abs(torch.arange(sequence_len).unsqueeze(0) - torch.arange(sequence_len).unsqueeze(1)) 268 | distance_squares[distance_squares > self.max_len] = self.max_len 269 | distance_squares = distance_squares.unsqueeze(0).repeat(batch_size, 1, 1) 270 | distance_squares = distance_squares.to(out.device) 271 | mask = self.distance_embedding_layer(distance_squares) 272 | 273 | mask = mask.permute(0, 3, 1, 2) 274 | 275 | # 276 | bool_mask = (torch.triu(torch.ones((sequence_len, sequence_len))) == 1).transpose(0, 1) 277 | bool_mask = bool_mask.view(1, 1, sequence_len, sequence_len).repeat(batch_size, self.nhead, 1, 1).to(out.device) 278 | mask = mask.masked_fill(bool_mask == 0, float("-inf")) 279 | mask = mask.reshape(-1, sequence_len, sequence_len) 280 | 281 | # 282 | 283 | key_padding_mask = sequences == TOKEN2ID[PAD_TOKEN] 284 | 285 | out = out.transpose(0, 1) 286 | out = self.transformer(out, mask, key_padding_mask) 287 | out = out.transpose(0, 1) 288 | 289 | # 290 | logits = self.generator(out) 291 | return logits 292 | 293 | def decode(self, num_samples, max_len, device): 294 | TOKEN2ID = token_to_id(self.string_type, self.k) 295 | sequences = torch.LongTensor([[TOKEN2ID[BOS_TOKEN]] for _ in range(num_samples)]).to(device) 296 | ended = torch.tensor([False for _ in range(num_samples)], dtype=torch.bool).to(device) 297 | for _ in tqdm(range(max_len), "generation"): 298 | if ended.all(): 299 | break 300 | logits = self(sequences) 301 | preds = Categorical(logits=logits[:, -1]).sample() 302 | preds[ended] = TOKEN2ID[PAD_TOKEN] 303 | sequences = torch.cat([sequences, preds.unsqueeze(1)], dim=1) 304 | 305 | ended = torch.logical_or(ended, preds == TOKEN2ID[EOS_TOKEN]) 306 | 307 | return sequences -------------------------------------------------------------------------------- /evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import pickle 5 | import numpy as np 6 | from scipy.linalg import toeplitz 7 | import pyemd 8 | import concurrent.futures 9 | from datetime import datetime 10 | from scipy.linalg import eigvalsh 11 | import networkx as nx 12 | from functools import partial 13 | import random 14 | import subprocess as sp 15 | from eden.graph import vectorize 16 | from sklearn.metrics.pairwise import pairwise_kernels 17 | # import graph_tool.all as gt 18 | # from scipy.stats import chi2 19 | 20 | from data.tokens import TOKENS_DICT 21 | 22 | 23 | def save_graph_list(log_folder_name, exp_name, gen_graph_list, valid_string_list, string_list, org_string_list): 24 | if not(os.path.isdir(f'./samples/graphs/{log_folder_name}')): 25 | os.makedirs(os.path.join(f'./samples/graphs/{log_folder_name}')) 26 | if not(os.path.isdir(f'./samples/string/{log_folder_name}')): 27 | os.makedirs(os.path.join(f'./samples/string/{log_folder_name}')) 28 | with open(f'./samples/graphs/{log_folder_name}/{exp_name}.pkl', 'wb') as f: 29 | pickle.dump(obj=gen_graph_list, file=f, protocol=pickle.HIGHEST_PROTOCOL) 30 | with open(f'./samples/string//{log_folder_name}/{exp_name}.txt', 'w') as f : 31 | for smiles in string_list: 32 | f.write("%s\n" %smiles) 33 | with 
open(f'./samples/string//{log_folder_name}/{exp_name}_val.txt', 'w') as f : 34 | for smiles in valid_string_list: 35 | f.write("%s\n" %smiles) 36 | with open(f'./samples/string//{log_folder_name}/{exp_name}_org.txt', 'w') as f : 37 | for smiles in org_string_list: 38 | f.write("%s\n" %smiles) 39 | save_dir = f'./samples/graphs/{log_folder_name}/{exp_name}.pkl' 40 | return save_dir 41 | 42 | def compute_sequence_accuracy(logits, batched_sequence_data, ignore_index=0): 43 | batch_size = batched_sequence_data.size(0) 44 | targets = batched_sequence_data.squeeze() 45 | 46 | preds = torch.argmax(logits, dim=-1) 47 | 48 | correct = preds == targets 49 | correct[targets == ignore_index] = True 50 | elem_acc = correct[targets != 0].float().mean() 51 | sequence_acc = correct.view(batch_size, -1).all(dim=1).float().mean() 52 | 53 | return elem_acc, sequence_acc 54 | 55 | def compute_sequence_cross_entropy(logits, batched_sequence_data, string_type): 56 | logits = logits[:,:-1] 57 | targets = batched_sequence_data[:,1:] 58 | weight_vector = [0,0] 59 | tokens = TOKENS_DICT[string_type] 60 | weight_vector.extend([1/(len(tokens)-2) for _ in range(len(tokens)-2)]) 61 | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), 62 | weight=torch.FloatTensor(weight_vector).to(logits.device)) 63 | return loss 64 | 65 | def process_tensor(x, y): 66 | support_size = max(len(x), len(y)) 67 | if len(x) < len(y): 68 | x = np.hstack((x, [0.0] * (support_size - len(x)))) 69 | elif len(y) < len(x): 70 | y = np.hstack((y, [0.0] * (support_size - len(y)))) 71 | return x, y 72 | 73 | def emd(x, y, distance_scaling=1.0): 74 | # -------- convert histogram values x and y to float, and make them equal len -------- 75 | x = x.astype(np.float) 76 | y = y.astype(np.float) 77 | support_size = max(len(x), len(y)) 78 | # -------- diagonal-constant matrix -------- 79 | d_mat = toeplitz(range(support_size)).astype(np.float) 80 | distance_mat = d_mat / distance_scaling 81 | x, y = process_tensor(x, y) 82 | 83 | emd_value = pyemd.emd(x, y, distance_mat) 84 | return np.abs(emd_value) 85 | 86 | def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0): 87 | """ Gaussian kernel with squared distance in exponential term replaced by EMD 88 | Args: 89 | x, y: 1D pmf of two distributions with the same support 90 | sigma: standard deviation 91 | """ 92 | emd_value = emd(x, y, distance_scaling) 93 | return np.exp(-emd_value * emd_value / (2 * sigma * sigma)) 94 | 95 | def gaussian(x, y, sigma=1.0): 96 | x = x.astype(np.float) 97 | y = y.astype(np.float) 98 | x, y = process_tensor(x, y) 99 | dist = np.linalg.norm(x - y, 2) 100 | return np.exp(-dist * dist / (2 * sigma * sigma)) 101 | 102 | 103 | def load_eval_settings(data, orbit_on=True): 104 | # Settings for generic graph generation 105 | methods = ['degree', 'cluster', 'orbit', 'spectral'] 106 | kernels = {'degree':gaussian_emd, 107 | 'cluster':gaussian_emd, 108 | 'orbit':gaussian, 109 | 'spectral':gaussian_emd} 110 | return methods, kernels 111 | 112 | def kernel_parallel_worker(t): 113 | return kernel_parallel_unpacked(*t) 114 | 115 | def kernel_parallel_unpacked(x, samples2, kernel): 116 | d = 0 117 | for s2 in samples2: 118 | d += kernel(x, s2) 119 | return d 120 | 121 | def disc(samples1, samples2, kernel, is_parallel=True, *args, **kwargs): 122 | """ Discrepancy between 2 samples 123 | """ 124 | d = 0 125 | if not is_parallel: 126 | for s1 in samples1: 127 | for s2 in samples2: 128 | d += kernel(s1, s2, *args, **kwargs) 129 | else: 130 | with 
concurrent.futures.ProcessPoolExecutor() as executor: 131 | for dist in executor.map(kernel_parallel_worker, 132 | [(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1]): 133 | d += dist 134 | d /= len(samples1) * len(samples2) 135 | return d 136 | 137 | def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs): 138 | """ MMD between two samples 139 | """ 140 | # -------- normalize histograms into pmf -------- 141 | if is_hist: 142 | samples1 = [s1 / np.sum(s1) for s1 in samples1] 143 | samples2 = [s2 / np.sum(s2) for s2 in samples2] 144 | return disc(samples1, samples1, kernel, *args, **kwargs) + \ 145 | disc(samples2, samples2, kernel, *args, **kwargs) - \ 146 | 2 * disc(samples1, samples2, kernel, *args, **kwargs) 147 | 148 | def degree_worker(G): 149 | return np.array(nx.degree_histogram(G)) 150 | 151 | PRINT_TIME = False 152 | 153 | # -------- Compute degree MMD -------- 154 | def degree_stats(graph_ref_list, graph_pred_list, KERNEL=gaussian_emd, is_parallel=True): 155 | ''' Compute the distance between the degree distributions of two unordered sets of graphs. 156 | Args: 157 | graph_ref_list, graph_target_list: two lists of networkx graphs to be evaluated 158 | ''' 159 | sample_ref = [] 160 | sample_pred = [] 161 | # -------- in case an empty graph is generated -------- 162 | graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] 163 | 164 | prev = datetime.now() 165 | if is_parallel: 166 | with concurrent.futures.ThreadPoolExecutor() as executor: 167 | for deg_hist in executor.map(degree_worker, graph_ref_list): 168 | sample_ref.append(deg_hist) 169 | with concurrent.futures.ThreadPoolExecutor() as executor: 170 | for deg_hist in executor.map(degree_worker, graph_pred_list_remove_empty): 171 | sample_pred.append(deg_hist) 172 | 173 | else: 174 | for i in range(len(graph_ref_list)): 175 | degree_temp = np.array(nx.degree_histogram(graph_ref_list[i])) 176 | sample_ref.append(degree_temp) 177 | for i in range(len(graph_pred_list_remove_empty)): 178 | degree_temp = np.array(nx.degree_histogram(graph_pred_list_remove_empty[i])) 179 | sample_pred.append(degree_temp) 180 | mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=KERNEL) 181 | elapsed = datetime.now() - prev 182 | if PRINT_TIME: 183 | print('Time computing degree mmd: ', elapsed) 184 | return mmd_dist 185 | 186 | 187 | def spectral_worker(G): 188 | eigs = eigvalsh(nx.normalized_laplacian_matrix(G).todense()) 189 | spectral_pmf, _ = np.histogram(eigs, bins=200, range=(-1e-5, 2), density=False) 190 | spectral_pmf = spectral_pmf / spectral_pmf.sum() 191 | return spectral_pmf 192 | 193 | 194 | # -------- Compute spectral MMD -------- 195 | def spectral_stats(graph_ref_list, graph_pred_list, KERNEL=gaussian_emd, is_parallel=True): 196 | ''' Compute the distance between the degree distributions of two unordered sets of graphs. 
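    (The summary line above appears to be copied from degree_stats; what is actually
    compared here is the eigenvalue spectrum of each graph's normalized Laplacian,
    histogrammed into 200 bins over (-1e-5, 2) by spectral_worker.)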
197 | Args: 198 | graph_ref_list, graph_target_list: two lists of networkx graphs to be evaluated 199 | ''' 200 | sample_ref = [] 201 | sample_pred = [] 202 | graph_pred_list_remove_empty = [ 203 | G for G in graph_pred_list if not G.number_of_nodes() == 0 204 | ] 205 | 206 | prev = datetime.now() 207 | if is_parallel: 208 | with concurrent.futures.ThreadPoolExecutor() as executor: 209 | for spectral_density in executor.map(spectral_worker, graph_ref_list): 210 | sample_ref.append(spectral_density) 211 | with concurrent.futures.ThreadPoolExecutor() as executor: 212 | for spectral_density in executor.map(spectral_worker, graph_pred_list_remove_empty): 213 | sample_pred.append(spectral_density) 214 | else: 215 | for i in range(len(graph_ref_list)): 216 | spectral_temp = spectral_worker(graph_ref_list[i]) 217 | sample_ref.append(spectral_temp) 218 | for i in range(len(graph_pred_list_remove_empty)): 219 | spectral_temp = spectral_worker(graph_pred_list_remove_empty[i]) 220 | sample_pred.append(spectral_temp) 221 | 222 | mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=KERNEL) 223 | 224 | elapsed = datetime.now() - prev 225 | if PRINT_TIME: 226 | print('Time computing degree mmd: ', elapsed) 227 | return mmd_dist 228 | 229 | 230 | def clustering_worker(param): 231 | G, bins = param 232 | clustering_coeffs_list = list(nx.clustering(G).values()) 233 | hist, _ = np.histogram( 234 | clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False) 235 | return hist 236 | 237 | 238 | # -------- Compute clustering coefficients MMD -------- 239 | def clustering_stats(graph_ref_list, graph_pred_list, KERNEL=gaussian, bins=100, is_parallel=True): 240 | sample_ref = [] 241 | sample_pred = [] 242 | graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] 243 | 244 | prev = datetime.now() 245 | if is_parallel: 246 | with concurrent.futures.ThreadPoolExecutor() as executor: 247 | for clustering_hist in executor.map(clustering_worker, 248 | [(G, bins) for G in graph_ref_list]): 249 | sample_ref.append(clustering_hist) 250 | with concurrent.futures.ThreadPoolExecutor() as executor: 251 | for clustering_hist in executor.map(clustering_worker, 252 | [(G, bins) for G in graph_pred_list_remove_empty]): 253 | sample_pred.append(clustering_hist) 254 | else: 255 | for i in range(len(graph_ref_list)): 256 | clustering_coeffs_list = list(nx.clustering(graph_ref_list[i]).values()) 257 | hist, _ = np.histogram( 258 | clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False) 259 | sample_ref.append(hist) 260 | 261 | for i in range(len(graph_pred_list_remove_empty)): 262 | clustering_coeffs_list = list(nx.clustering(graph_pred_list_remove_empty[i]).values()) 263 | hist, _ = np.histogram( 264 | clustering_coeffs_list, bins=bins, range=(0.0, 1.0), density=False) 265 | sample_pred.append(hist) 266 | try: 267 | mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=KERNEL, 268 | sigma=1.0 / 10, distance_scaling=bins) 269 | except: 270 | mmd_dist = compute_mmd(sample_ref, sample_pred, kernel=KERNEL, sigma=1.0 / 10) 271 | elapsed = datetime.now() - prev 272 | if PRINT_TIME: 273 | print('Time computing clustering mmd: ', elapsed) 274 | return mmd_dist 275 | 276 | ORCA_DIR = 'evaluation/orca' 277 | COUNT_START_STR = 'orbit counts: \n' 278 | 279 | 280 | def edge_list_reindexed(G): 281 | idx = 0 282 | id2idx = dict() 283 | for u in G.nodes(): 284 | id2idx[str(u)] = idx 285 | idx += 1 286 | 287 | edges = [] 288 | for (u, v) in G.edges(): 289 | edges.append((id2idx[str(u)], 
id2idx[str(v)])) 290 | return edges 291 | 292 | 293 | def orca(graph): 294 | tmp_file_path = os.path.join(ORCA_DIR, f'tmp-{random.random():.4f}.txt') 295 | f = open(tmp_file_path, 'w') 296 | f.write(str(graph.number_of_nodes()) + ' ' + str(graph.number_of_edges()) + '\n') 297 | for (u, v) in edge_list_reindexed(graph): 298 | f.write(str(u) + ' ' + str(v) + '\n') 299 | f.close() 300 | 301 | output = sp.check_output([os.path.join(ORCA_DIR, 'orca'), 'node', '4', tmp_file_path, 'std']) 302 | output = output.decode('utf8').strip() 303 | 304 | idx = output.find(COUNT_START_STR) + len(COUNT_START_STR) 305 | output = output[idx:] 306 | node_orbit_counts = np.array([list(map(int, node_cnts.strip().split(' '))) 307 | for node_cnts in output.strip('\n').split('\n')]) 308 | 309 | try: 310 | os.remove(tmp_file_path) 311 | except OSError: 312 | pass 313 | 314 | return node_orbit_counts 315 | 316 | def orbit_stats_all(graph_ref_list, graph_pred_list, KERNEL=gaussian): 317 | total_counts_ref = [] 318 | total_counts_pred = [] 319 | 320 | prev = datetime.now() 321 | 322 | for G in graph_ref_list: 323 | try: 324 | orbit_counts = orca(G) 325 | except Exception as e: 326 | print(e) 327 | continue 328 | orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes() 329 | total_counts_ref.append(orbit_counts_graph) 330 | 331 | for G in graph_pred_list: 332 | try: 333 | orbit_counts = orca(G) 334 | except: 335 | print('orca failed') 336 | continue 337 | orbit_counts_graph = np.sum(orbit_counts, axis=0) / G.number_of_nodes() 338 | total_counts_pred.append(orbit_counts_graph) 339 | 340 | total_counts_ref = np.array(total_counts_ref) 341 | total_counts_pred = np.array(total_counts_pred) 342 | mmd_dist = compute_mmd(total_counts_ref, total_counts_pred, kernel=KERNEL, 343 | is_hist=False, sigma=30.0) 344 | 345 | elapsed = datetime.now() - prev 346 | if PRINT_TIME: 347 | print('Time computing orbit mmd: ', elapsed) 348 | return mmd_dist 349 | 350 | ### code adapted from https://github.com/idea-iitd/graphgen/blob/master/metrics/mmd.py 351 | def compute_nspdk_mmd(samples1, samples2, metric, is_hist=True, n_jobs=None): 352 | def kernel_compute(X, Y=None, is_hist=True, metric='linear', n_jobs=None): 353 | X = vectorize(X, complexity=4, discrete=True) 354 | if Y is not None: 355 | Y = vectorize(Y, complexity=4, discrete=True) 356 | return pairwise_kernels(X, Y, metric='linear', n_jobs=n_jobs) 357 | 358 | X = kernel_compute(samples1, is_hist=is_hist, metric=metric, n_jobs=n_jobs) 359 | Y = kernel_compute(samples2, is_hist=is_hist, metric=metric, n_jobs=n_jobs) 360 | Z = kernel_compute(samples1, Y=samples2, is_hist=is_hist, metric=metric, n_jobs=n_jobs) 361 | 362 | return np.average(X) + np.average(Y) - 2 * np.average(Z) 363 | 364 | ##### code adapted from https://github.com/idea-iitd/graphgen/blob/master/metrics/stats.py 365 | def nspdk_stats(graph_ref_list, graph_pred_list): 366 | graph_pred_list_remove_empty = [G for G in graph_pred_list if not G.number_of_nodes() == 0] 367 | 368 | prev = datetime.now() 369 | mmd_dist = compute_nspdk_mmd(graph_ref_list, graph_pred_list_remove_empty, metric='nspdk', is_hist=False, n_jobs=20) 370 | elapsed = datetime.now() - prev 371 | if PRINT_TIME: 372 | print('Time computing degree mmd: ', elapsed) 373 | return mmd_dist 374 | 375 | METHOD_NAME_TO_FUNC = { 376 | 'degree': degree_stats, 377 | 'cluster': clustering_stats, 378 | 'orbit': orbit_stats_all, 379 | 'spectral': spectral_stats, 380 | 'nspdk': nspdk_stats 381 | } 382 | 383 | 384 | # -------- Evaluate generated generic graphs 
-------- 385 | def eval_graph_list(graph_ref_list, graph_pred_list, methods=None, kernels=None): 386 | if methods is None: 387 | methods = ['degree', 'cluster', 'orbit'] 388 | results = {} 389 | for method in methods: 390 | if method == 'nspdk': 391 | results[method] = METHOD_NAME_TO_FUNC[method](graph_ref_list, graph_pred_list) 392 | else: 393 | results[method] = round(METHOD_NAME_TO_FUNC[method](graph_ref_list, graph_pred_list, kernels[method]), 6) 394 | print('\033[91m' + f'{method:9s}' + '\033[0m' + ' : ' + '\033[94m' + f'{results[method]:.6f}' + '\033[0m') 395 | return results 396 | 397 | # # codes adapted from https://github.com/KarolisMart/SPECTRE 398 | # def is_planar_graph(G): 399 | # return nx.is_connected(G) and nx.check_planarity(G)[0] 400 | 401 | # def is_grid_graph(G): 402 | # """ 403 | # Check if the graph is grid, by comparing with all the real grids with the same node count 404 | # """ 405 | # all_grid_file = f"data/all_grids.pt" 406 | # if os.path.isfile(all_grid_file): 407 | # all_grids = torch.load(all_grid_file) 408 | # else: 409 | # all_grids = {} 410 | # for i in range(2, 20): 411 | # for j in range(2, 20): 412 | # G_grid = nx.grid_2d_graph(i, j) 413 | # n_nodes = f"{len(G_grid.nodes())}" 414 | # all_grids[n_nodes] = all_grids.get(n_nodes, []) + [G_grid] 415 | # torch.save(all_grids, all_grid_file) 416 | 417 | # n_nodes = f"{len(G.nodes())}" 418 | # if n_nodes in all_grids: 419 | # for G_grid in all_grids[n_nodes]: 420 | # if nx.faster_could_be_isomorphic(G, G_grid): 421 | # if nx.is_isomorphic(G, G_grid): 422 | # return True 423 | # return False 424 | # else: 425 | # return False 426 | 427 | # def is_sbm_graph(G, p_intra=0.3, p_inter=0.005, strict=True, refinement_steps=1000): 428 | # """ 429 | # Check if how closely given graph matches a SBM with given probabilites by computing mean probability of Wald test statistic for each recovered parameter 430 | # """ 431 | 432 | # adj = nx.adjacency_matrix(G).toarray() 433 | # idx = adj.nonzero() 434 | # g = gt.Graph() 435 | # g.add_edge_list(np.transpose(idx)) 436 | # try: 437 | # state = gt.minimize_blockmodel_dl(g) 438 | # except ValueError: 439 | # if strict: 440 | # return False 441 | # else: 442 | # return 0.0 443 | 444 | # # Refine using merge-split MCMC 445 | # for i in range(refinement_steps): 446 | # state.multiflip_mcmc_sweep(beta=np.inf, niter=10) 447 | 448 | # b = state.get_blocks() 449 | # b = gt.contiguous_map(state.get_blocks()) 450 | # state = state.copy(b=b) 451 | # e = state.get_matrix() 452 | # n_blocks = state.get_nonempty_B() 453 | # node_counts = state.get_nr().get_array()[:n_blocks] 454 | # edge_counts = e.todense()[:n_blocks, :n_blocks] 455 | # if strict: 456 | # if (node_counts > 40).sum() > 0 or (node_counts < 20).sum() > 0 or n_blocks > 5 or n_blocks < 2: 457 | # return False 458 | 459 | # max_intra_edges = node_counts * (node_counts - 1) 460 | # est_p_intra = np.diagonal(edge_counts) / (max_intra_edges + 1e-6) 461 | 462 | # max_inter_edges = node_counts.reshape((-1, 1)) @ node_counts.reshape((1, -1)) 463 | # np.fill_diagonal(edge_counts, 0) 464 | # est_p_inter = edge_counts / (max_inter_edges + 1e-6) 465 | 466 | # W_p_intra = (est_p_intra - p_intra)**2 / (est_p_intra * (1-est_p_intra) + 1e-6) 467 | # W_p_inter = (est_p_inter - p_inter)**2 / (est_p_inter * (1-est_p_inter) + 1e-6) 468 | 469 | # W = W_p_inter.copy() 470 | # np.fill_diagonal(W, W_p_intra) 471 | # p = 1 - chi2.cdf(abs(W), 1) 472 | # p = p.mean() 473 | # if strict: 474 | # return p > 0.9 # p value < 10 % 475 | # else: 476 | # return p 
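# ---------------------------------------------------------------------------------
# Minimal usage sketch (illustration only, with toy stand-in graphs): load the
# default metric names and kernels with load_eval_settings and compare a reference
# list against a generated list with eval_graph_list. 'orbit' is omitted here
# because it shells out to the compiled binary under evaluation/orca.
# ---------------------------------------------------------------------------------
if __name__ == "__main__":
    import networkx as nx

    ref_graphs = [nx.grid_2d_graph(3, 3) for _ in range(4)]
    gen_graphs = [nx.gnp_random_graph(9, 0.3, seed=s) for s in range(4)]
    methods, kernels = load_eval_settings('')
    scores = eval_graph_list(ref_graphs, gen_graphs,
                             methods=['degree', 'cluster', 'spectral'],
                             kernels=kernels)
    # scores maps each metric name to its MMD value, rounded to six decimals.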
-------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import ZeroPad2d 3 | from torch import LongTensor 4 | from torch.utils.data import random_split 5 | from torch import count_nonzero 6 | import math 7 | from collections import deque 8 | from treelib import Tree, Node 9 | from sklearn.model_selection import train_test_split 10 | from itertools import zip_longest 11 | import networkx as nx 12 | from treelib import Tree, Node 13 | import numpy as np 14 | from itertools import compress, islice 15 | import os 16 | from pathlib import Path 17 | import json 18 | import re 19 | 20 | 21 | from data.tokens import grouper_mol 22 | 23 | 24 | DATA_DIR = "resource" 25 | NODE_TYPE_DICT = {'F': 9, 'O': 10, 'N': 11, 'C': 12, 'P': 13, 'I': 14, 'Cl': 15, 'Br': 16, 'S': 17} 26 | TYPE_NODE_DICT = {str(key): value for value, key in NODE_TYPE_DICT.items()} 27 | BOND_TYPE_DICT = {1: 5, 2: 6, 3: 7, 1.5: 8} 28 | TYPE_BOND_DICT = {key: value for value, key in NODE_TYPE_DICT.items()} 29 | 30 | def get_level(node): 31 | return int(node.identifier.split('-')[0]) 32 | 33 | def get_location(node): 34 | return int(node.identifier.split('-')[1]) 35 | 36 | def get_k(tree): 37 | return math.sqrt(len(tree[tree.root].successors(tree.identifier))) 38 | 39 | def get_parent(node, tree): 40 | return tree[node.predecessor(tree.identifier)] 41 | 42 | def get_children_identifier(node, tree): 43 | return sorted(node.successors(tree.identifier), key=lambda x: get_sort_key(x)) 44 | 45 | def get_sort_key(node_id): 46 | if len(node_id.split('-')) > 2: 47 | try: 48 | return (int(node_id.split('-')[0]), int(node_id.split('-')[1]), int(node_id.split('-')[2])) 49 | except: 50 | assert node_id.split('-') 51 | else: 52 | return (int(node_id.split('-')[0]), int(node_id.split('-')[1])) 53 | 54 | def nearest_power(N, base=2): 55 | a = int(math.log(N, base)) 56 | if base**a == N: 57 | return N 58 | 59 | return base**(a + 1) 60 | 61 | def adj_to_k2_tree(adj, return_tree=False, is_wholetree=False, k=4, is_mol=False): 62 | if not is_mol: 63 | adj[adj > 0] = 1 64 | n_org_nodes = adj.shape[0] 65 | # add padding (proper size for k) 66 | n_nodes = nearest_power(n_org_nodes, k) 67 | k_square = k**2 68 | padder = ZeroPad2d((0, n_nodes-n_org_nodes, 0, n_nodes-n_org_nodes)) 69 | padded_adj = padder(adj) 70 | total_level = int(math.log(n_nodes, k)) 71 | tree_list = [] 72 | leaf_list = [] 73 | tree = Tree() 74 | # add root node 75 | tree.create_node("root", "0") 76 | tree_key_list = deque([]) 77 | slice_size = int(n_nodes / k) 78 | # slice matrices 79 | start_index = range(0,n_nodes,slice_size) 80 | end_index = range(slice_size,n_nodes+1,slice_size) 81 | slices = [] 82 | for row_start, row_end in zip(start_index, end_index): 83 | for col_start, col_end in zip(start_index, end_index): 84 | slices.append(padded_adj[row_start:row_end, col_start:col_end]) 85 | 86 | sliced_adjs = deque(slices) 87 | sliced_adjs_is_zero = LongTensor([int(count_nonzero(adj)>0) for adj in sliced_adjs]) 88 | tree_list.append(sliced_adjs_is_zero) 89 | # molecule + only leaf 90 | if is_mol and adj.shape[0] == k: 91 | tree_element_list = deque(list(map(int, torch.flatten(adj).tolist()))) 92 | else: 93 | tree_element_list = deque(sliced_adjs_is_zero) 94 | 95 | for i, elem in enumerate(tree_element_list, 1): 96 | tree.create_node(elem, f"1-{i}", parent="0") 97 | tree_key_list.append(f"1-{i}") 98 | 99 | while (slice_size != 1): 
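        # BFS over the implicit k^2-ary tree: each iteration pops one block from the
        # queue; as soon as the blocks reach size k x k, the remaining (non-zero,
        # unless is_wholetree) blocks are flattened into leaf_list and the loop stops.
        # Otherwise a non-zero block is split into k**2 sub-blocks whose
        # "contains an edge" flags are appended to tree_list (and, when return_tree
        # is set, attached as children of the popped node).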
100 | n_nodes = sliced_adjs[0].shape[0] 101 | if n_nodes == k: 102 | if is_wholetree: 103 | leaf_list = [adj.reshape(k_square,) for adj in sliced_adjs] 104 | else: 105 | leaf_list = [adj.reshape(k_square,) for adj in sliced_adjs if count_nonzero(adj)>0] 106 | break 107 | slice_size = int(n_nodes / k) 108 | target_adj = sliced_adjs.popleft() 109 | target_adj_size = target_adj.shape[0] 110 | if return_tree: 111 | parent_node_key = tree_key_list.popleft() 112 | # remove adding leaves to 0 113 | if not is_wholetree: 114 | if count_nonzero(target_adj) == 0: 115 | continue 116 | # generate tree_list and leaf_list 117 | new_sliced_adjs = [] 118 | start_index = range(0,n_nodes,slice_size) 119 | end_index = range(slice_size,n_nodes+1,slice_size) 120 | for row_start, row_end in zip(start_index, end_index): 121 | for col_start, col_end in zip(start_index, end_index): 122 | new_sliced_adjs.append(target_adj[row_start:row_end, col_start:col_end]) 123 | new_sliced_adjs_is_zero = LongTensor([int(count_nonzero(adj)>0) for adj in new_sliced_adjs]) 124 | sliced_adjs.extend(new_sliced_adjs) 125 | tree_list.append(new_sliced_adjs_is_zero) 126 | 127 | if return_tree: 128 | # generate tree 129 | tree_element_list.extend(new_sliced_adjs_is_zero) 130 | cur_level = int(total_level - math.log(target_adj_size, k) + 1) 131 | cur_level_key_list = [int(key.split('-')[1]) for key in tree_key_list if int(key.split('-')[0]) == cur_level] 132 | if len(cur_level_key_list) > 0: 133 | key_starting_point = max(cur_level_key_list) 134 | else: 135 | key_starting_point = 0 136 | for i, elem in enumerate(new_sliced_adjs_is_zero, key_starting_point+1): 137 | tree.create_node(elem, f"{cur_level}-{i}", parent=parent_node_key) 138 | tree_key_list.append(f"{cur_level}-{i}") 139 | 140 | if return_tree: 141 | # add leaves to tree 142 | leaves = [node for node in tree.leaves() if node.tag == 1] 143 | index = 1 144 | for leaf, leaf_values in zip(leaves, leaf_list): 145 | for value in leaf_values: 146 | tree.create_node(int(value), f"{total_level}-{index}", parent=leaf) 147 | index += 1 148 | return tree 149 | else: 150 | return tree_list, leaf_list 151 | 152 | def check_tree_validity(tree): 153 | depth = tree.depth() 154 | if depth == 1: 155 | return False 156 | leaves = [leaf for leaf in tree.leaves() if leaf.tag != '0'] 157 | invalid_leaves = [leaf for leaf in leaves if tree.depth(leaf)!=depth] 158 | if len(invalid_leaves) == 0: 159 | return True 160 | else: 161 | return False 162 | 163 | def tree_to_adj(tree, k=2): 164 | ''' 165 | convert k2 tree to adjacency matrix 166 | ''' 167 | try: 168 | tree = map_starting_point(tree, k) 169 | except TypeError: 170 | return None 171 | depth = tree.depth() 172 | leaves = [leaf for leaf in tree.leaves() if leaf.tag != '0'] 173 | one_data_points = [leaf.data for leaf in leaves] 174 | x_list = [data[0] for data in one_data_points] 175 | y_list = [data[1] for data in one_data_points] 176 | label_list = [NODE_TYPE_DICT[leaf.tag] if leaf.tag in NODE_TYPE_DICT.keys() else int(leaf.tag) for leaf in leaves] 177 | matrix_size = int(k**depth) 178 | adj = torch.zeros((matrix_size, matrix_size)) 179 | for x, y, label in zip(x_list, y_list, label_list): 180 | # if (x > len(adj)) or (y > len(adj)): 181 | # return None 182 | adj[x, y] = label 183 | 184 | return adj 185 | 186 | def map_starting_point(tree, k): 187 | ''' 188 | map starting points for each elements in tree (to convert adjacency matrix) 189 | ''' 190 | try: 191 | bfs_list = [tree[node] for node in tree.expand_tree(mode=Tree.WIDTH, 192 | key=lambda x: 
(int(x.identifier.split('-')[0]), int(x.identifier.split('-')[1]), int(x.identifier.split('-')[2])))] 193 | except: 194 | bfs_list = [tree[node] for node in tree.expand_tree(mode=Tree.WIDTH, 195 | key=lambda x: (int(x.identifier.split('-')[0]), int(x.identifier.split('-')[1])))] 196 | bfs_list[0].data = (0,0) 197 | 198 | for node in bfs_list[1:]: 199 | parent = get_parent(node, tree) 200 | siblings = get_children_identifier(parent, tree) 201 | index = siblings.index(node.identifier) 202 | level = get_level(node) 203 | tree_depth = tree.depth() 204 | matrix_size = k**tree_depth 205 | adding_value = int(matrix_size/(k**level)) 206 | parent_starting_point = parent.data 207 | node.data = (parent_starting_point[0]+adding_value*int(index/k), parent_starting_point[1]+adding_value*int(index%k)) 208 | 209 | return tree 210 | 211 | def map_child_deg(node, tree): 212 | ''' 213 | return sum of direct children nodes' degree (tag) 214 | ''' 215 | if node.is_leaf(): 216 | return str(int(node.tag)) 217 | 218 | children = get_children_identifier(node, tree) 219 | child_deg = sum([int(tree[child].tag > 0) for child in children]) 220 | 221 | return str(child_deg) 222 | 223 | def map_all_child_deg(node, tree): 224 | ''' 225 | return sum of all children nodes' degree (tag) 226 | ''' 227 | if node.is_leaf(): 228 | return str(int(node.tag)) 229 | 230 | children = get_children_identifier(node, tree) 231 | child_deg = sum([int(tree[child].tag) for child in children]) 232 | 233 | return str(child_deg) 234 | 235 | def tree_to_bfs_string(tree, string_type='group-red'): 236 | bfs_node_list = [tree[node] for node in tree.expand_tree(mode=tree.WIDTH, 237 | key=lambda x: (int(x.identifier.split('-')[0]), int(x.identifier.split('-')[1])))][1:] 238 | bfs_value_list = [str(int(node.tag)) for node in bfs_node_list] 239 | 240 | return ''.join(bfs_value_list) 241 | 242 | def grouper(n, iterable, fillvalue=None): 243 | "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx" 244 | args = [iter(iterable)] * n 245 | return zip_longest(fillvalue=fillvalue, *args) 246 | 247 | def bfs_string_to_tree(string, is_zinc=False, k=2): 248 | k_square = k**2 249 | tree = Tree() 250 | tree.create_node("root", "0") 251 | parent_node = tree["0"] 252 | node_deque = deque([]) 253 | if is_zinc: 254 | node_groups = grouper_mol(string) 255 | else: 256 | node_groups = grouper(k_square, string) 257 | for node_1, node_2, node_3, node_4 in node_groups: 258 | parent_level = get_level(parent_node) 259 | cur_level_max = max([get_location(node) for node in tree.nodes.values() if get_level(node) == parent_level+1], default=0) 260 | for i, node_tag in enumerate([node_1, node_2, node_3, node_4], 1): 261 | if node_tag == None: 262 | break 263 | new_node = Node(tag=node_tag, identifier=f"{parent_level+1}-{cur_level_max+i}") 264 | tree.add_node(new_node, parent=parent_node) 265 | node_deque.append(new_node) 266 | parent_node = node_deque.popleft() 267 | while(parent_node.tag == '0'): 268 | if len(node_deque) == 0: 269 | return tree 270 | parent_node = node_deque.popleft() 271 | return tree 272 | 273 | def clean_string(string): 274 | 275 | if "[pad]" in string: 276 | string = string[:string.index("[pad]")] 277 | 278 | return string 279 | 280 | 281 | def train_val_test_split( 282 | data: list, 283 | data_name='GDSS_com', 284 | train_size: float = 0.7, val_size: float = 0.1, test_size: float = 0.2, 285 | seed: int = 42, 286 | ): 287 | if data_name in ['qm9', 'zinc']: 288 | # code adpated from https://github.com/harryjo97/GDSS 289 | with open(os.path.join(DATA_DIR, 
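# ----------------------------------------------------------------------
# Illustration only: the offset arithmetic used by map_starting_point above,
# spelled out on its own. Descending one level at child index i (0..k^2-1)
# shifts a block's (row, col) origin by (matrix_size / k**level) * (i // k, i % k).
# `block_origin` is a hypothetical helper introduced just for this example.
def block_origin(child_path, depth, k=2):
    """Return the (row, col) origin of the block reached by following
    `child_path` (child indices from the root) in a k**depth square matrix."""
    matrix_size = k ** depth
    row = col = 0
    for level, idx in enumerate(child_path, start=1):
        step = matrix_size // (k ** level)
        row += step * (idx // k)
        col += step * (idx % k)
    return row, col

# In an 8x8 matrix (k=2, depth=3): child 3 (bottom-right), then child 1
# (top-right), then child 2 (bottom-left) addresses the single entry (5, 6).
print(block_origin([3, 1, 2], depth=3, k=2))   # (5, 6)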
f'{data_name}/valid_idx_{data_name}.json')) as f: 290 | test_idx = json.load(f) 291 | if data_name == 'qm9': 292 | test_idx = test_idx['valid_idxs'] 293 | test_idx = [int(i) for i in test_idx] 294 | train_idx = [i for i in range(len(data)) if i not in test_idx] 295 | test = [data[i] for i in test_idx] 296 | train_val = [data[i] for i in train_idx] 297 | train, val = train_test_split(train_val, train_size=train_size / (train_size + val_size), random_state=seed, shuffle=True) 298 | elif data_name in ['planar', 'sbm', 'proteins']: 299 | # code adapted from https://github.com/KarolisMart/SPECTRE 300 | test_len = int(round(len(data)*0.2)) 301 | train_len = int(round((len(data) - test_len)*0.8)) 302 | val_len = len(data) - train_len - test_len 303 | train, val, test = random_split(data, [train_len, val_len, test_len], generator=torch.Generator().manual_seed(1234)) 304 | elif data_name in ['profold', 'collab']: 305 | train, test = train_test_split(data, train_size=train_size, shuffle=False) 306 | val = test 307 | else: 308 | train_val, test = train_test_split(data, train_size=train_size + val_size, shuffle=False) 309 | train, val = train_test_split(train_val, train_size=train_size / (train_size + val_size), random_state=seed, shuffle=True) 310 | return train, val, test 311 | 312 | def adj_to_graph(adj, is_cuda=False): 313 | if is_cuda: 314 | adj = adj.detach().cpu().numpy() 315 | G = nx.from_numpy_matrix(adj) 316 | G.remove_edges_from(nx.selfloop_edges(G)) 317 | G.remove_nodes_from(list(nx.isolates(G))) 318 | if G.number_of_nodes() < 1: 319 | G.add_node(1) 320 | return G 321 | 322 | def map_tree_pe(tree): 323 | depth = tree.depth() 324 | k = get_k(tree) 325 | size = int(depth*(k**2)) 326 | pe = torch.zeros((size)) 327 | for node in tree.nodes.values(): 328 | node.data = pe 329 | if not node.is_root(): 330 | parent = get_parent(node, tree) 331 | branch = get_children_identifier(parent, tree).index(node.identifier) 332 | current_pe = torch.zeros(int(k**2)) 333 | current_pe[branch] = 1 334 | pe = torch.cat((current_pe, parent.data[:int(size-k**2)])) 335 | node.data = pe 336 | return tree 337 | 338 | def map_new_ordered_graph(ordered_graph): 339 | ''' 340 | Map ordered_graph object to ordered networkx graph 341 | ''' 342 | org_graph = ordered_graph.graph 343 | ordering = ordered_graph.ordering 344 | new_graph = nx.from_numpy_array(nx.adjacency_matrix(org_graph, nodelist=ordering)) 345 | return new_graph 346 | 347 | # for redundant removed strings 348 | 349 | def generate_final_tree_red(tree, k=2): 350 | tree_with_iden = add_zero_to_identifier(tree) 351 | final_tree = add_symmetry_to_tree(tree_with_iden, k) 352 | 353 | return final_tree 354 | 355 | def generate_initial_tree_red(string_token_list, k=2): 356 | node_groups = [tuple(grouper_mol(string, k)[0]) for string in string_token_list] 357 | tree = Tree() 358 | tree.create_node("root", "0-0-0") 359 | parent_node = tree["0-0-0"] 360 | node_deque = deque([]) 361 | for nodes in node_groups: 362 | parent_level = get_level(parent_node) 363 | cur_level_max = max([get_location(node) for node in tree.nodes.values() if get_level(node) == parent_level+1], default=0) 364 | for i, node_tag in enumerate(nodes, 1): 365 | if node_tag == None: 366 | break 367 | new_node = Node(tag=node_tag, identifier=f"{parent_level+1}-{cur_level_max+i}") 368 | tree.add_node(new_node, parent=parent_node) 369 | node_deque.append(new_node) 370 | parent_node = node_deque.popleft() 371 | while(parent_node.tag == '0'): 372 | if len(node_deque) == 0: 373 | return tree 374 | parent_node 
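# ----------------------------------------------------------------------
# Illustration only: the tree positional encoding that map_tree_pe above
# appears to build for each node, shown in isolation. The encoding stacks the
# one-hot branch index of the node itself, then of each ancestor, and
# zero-pads the result to a fixed depth * k^2 length. `tree_positional_encoding`
# and `branch_path` are hypothetical names used only in this sketch.
import torch

def tree_positional_encoding(branch_path, depth, k=2):
    """branch_path: 0-based branch indices from the root's child down to the node."""
    size = depth * k ** 2
    pe = torch.zeros(size)
    # walk from the node back up towards the root, one k^2-wide slot per level
    for slot, branch in enumerate(reversed(branch_path)):
        pe[slot * k ** 2 + branch] = 1.0
    return pe

# A node reached via branches [2, 1] in a depth-3 tree (k=2):
print(tree_positional_encoding([2, 1], depth=3, k=2))
# tensor([0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.])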
= node_deque.popleft() 375 | 376 | return tree 377 | 378 | def find_new_identifier(node_id, index, is_dup=1): 379 | split = node_id.split('-') 380 | # k = 2 381 | if index == 1: 382 | num = 1 383 | num *= is_dup 384 | cur_pre_identifier = split[0] + '-' + split[1] 385 | # k = 3 386 | elif index == 2: 387 | num = 2 388 | num *= is_dup 389 | cur_pre_identifier = split[0] + '-' + split[1] 390 | elif index == 4: 391 | num = 1 392 | num *= is_dup 393 | cur_pre_identifier = split[0] + '-' + str(int(split[1])-2) 394 | elif index == 5: 395 | num = 1 396 | num *= is_dup 397 | cur_pre_identifier = split[0] + '-' + str(int(split[1])-1) 398 | else: 399 | num = 2 400 | num *= is_dup 401 | cur_pre_identifier = split[0] + '-' + split[1] 402 | 403 | new_last_identifier = int(split[2])-num 404 | return cur_pre_identifier + '-' + str(new_last_identifier) 405 | 406 | def get_child_index(k): 407 | indices = [] 408 | c = 0 409 | for i in range(k+1): 410 | for j in range(1,i+1): 411 | c=c+1 412 | if j != i: 413 | indices.append(c) 414 | return indices 415 | 416 | 417 | def add_symmetry_to_tree(tree, k): 418 | k_square = k**2 419 | bfs_node_list = [tree[node] for node in tree.expand_tree(mode=tree.WIDTH, key=lambda x: (int(x.identifier.split('-')[0]), int(x.identifier.split('-')[1])))] 420 | node_list = [node for node in bfs_node_list[::-1] if not node.is_leaf()] 421 | for node in node_list: 422 | child_nodes = get_children_identifier(node, tree) 423 | if len(child_nodes) < k_square: 424 | postfixes = get_child_index(k) 425 | 426 | for index in postfixes: 427 | copy_node = tree.get_node(child_nodes[int(index)-1]) 428 | new_node = Node(tag=copy_node.tag, identifier=find_new_identifier(copy_node.identifier, index)) 429 | subtree = Tree(tree.subtree(child_nodes[int(index)-1]), deep=True) 430 | new_tree = Tree(subtree, deep=True) 431 | if len(subtree) > 1: 432 | for nid, n in sorted(subtree.nodes.items(), key=lambda x: (int(x[0].split('-')[0]), int(x[0].split('-')[1]), int(x[0].split('-')[2]))): 433 | count_dup = len([key for key in subtree.nodes.keys() 434 | if (key.split('-')[0] == nid.split('-')[0]) and (key.split('-')[1] == nid.split('-')[1])]) 435 | org_dup = len([key for key in tree.nodes.keys() 436 | if (key.split('-')[0] == nid.split('-')[0]) and (key.split('-')[1] == nid.split('-')[1])]) 437 | new_iden = find_new_identifier(nid, index, (count_dup+org_dup)*100) 438 | while (new_iden in tree): 439 | new = int(new_iden.split('-')[2]) - 1 440 | new_iden = new_iden.split('-')[0] + '-' + new_iden.split('-')[1] + '-' + str(new) 441 | new_tree.update_node(nid, identifier=new_iden) 442 | tree.paste(node.identifier, new_tree) 443 | 444 | else: 445 | tree.add_node(new_node, parent=node) 446 | 447 | return tree 448 | 449 | def add_zero_to_identifier(tree): 450 | new_tree = Tree(tree, deep=True) 451 | for node in tree.nodes: 452 | new_identifier = node 453 | while (len(new_identifier.split('-'))<3): 454 | new_identifier = new_identifier + '-10000' 455 | new_tree.update_node(node, identifier=new_identifier) 456 | return new_tree 457 | 458 | def fix_symmetry(adj): 459 | try: 460 | sym_adj = torch.tril(adj) + torch.tril(adj).T 461 | except TypeError: 462 | return None 463 | return torch.where(sym_adj>0, 1, 0) 464 | 465 | 466 | def map_deg_string(string): 467 | new_string = [] 468 | group_queue = deque(grouper(4, string)) 469 | group_queue.popleft() 470 | for index, char in enumerate(string): 471 | if len(group_queue) == 0: 472 | left = string[index:] 473 | break 474 | if char == '0': 475 | new_string.append(char) 476 | 
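# ----------------------------------------------------------------------
# Illustration only: what fix_symmetry above does to a sampled 0/1 matrix.
# Because the graphs are undirected, only the lower triangle is trusted; it is
# mirrored and binarized. This also mirrors how the redundancy-reduced strings
# drop upper-diagonal blocks, which add_symmetry_to_tree above restores.
# `mirror_lower_triangle` is a hypothetical name used only in this sketch.
import torch

def mirror_lower_triangle(adj):
    sym = torch.tril(adj) + torch.tril(adj).T
    return torch.where(sym > 0, 1, 0)

a = torch.tensor([[0, 1, 0],
                  [0, 0, 0],
                  [1, 1, 0]])
print(mirror_lower_triangle(a))
# tensor([[0, 0, 1],
#         [0, 0, 1],
#         [1, 1, 0]])
# Note that the upper-triangle entry (0, 1) of the input is discarded.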
else: 477 | new_string.append(str(sum([int(char) for char in group_queue.popleft()]))) 478 | 479 | return ''.join(new_string) + left 480 | 481 | def remove_redundant(input_string, is_mol=False, k=2): 482 | k_square = k**2 483 | string = input_string[0:k_square] 484 | pos_list = list(range(1, k_square+1)) 485 | str_pos_queue = deque([(s, p) for s, p in zip(string, pos_list)]) 486 | if is_mol: 487 | group_list = list(grouper_mol(input_string)) 488 | else: 489 | group_list = list(grouper(k_square, input_string)) 490 | for cur_string in [''.join(token) for token in group_list][1:]: 491 | cur_parent, cur_parent_pos = str_pos_queue.popleft() 492 | # if the value is 0, it cannot be a parent node -> skip 493 | while((cur_parent == '0') and (len(str_pos_queue) > 0)): 494 | cur_parent, cur_parent_pos = str_pos_queue.popleft() 495 | # i: order of the child node under the same parent 496 | cur_pos = [cur_parent_pos*10+i for i in range(1,k_square+1)] 497 | # pos_list: final position of each node 498 | pos_list.extend(cur_pos) 499 | if is_mol: 500 | str_pos_queue.extend([(s, c) for s, c in zip(grouper_mol(cur_string)[0], cur_pos)]) 501 | else: 502 | str_pos_queue.extend([(s, c) for s, c in zip(cur_string, cur_pos)]) 503 | 504 | pos_list = [str(pos) for pos in pos_list] 505 | # prefixes: diagonal block indices 506 | prefixes = [str((i-1)*k + i) for i in range(1,k+1)] 507 | # postfixes: upper-diagonal block indices 508 | postfixes = [] 509 | for i in range(1,k+1): 510 | for j in range(i+1, k+1): 511 | postfixes.append(str((i-1)*k+j)) 512 | 513 | # find positions that end with an upper-diagonal index and whose ancestors are all diagonal (for k=2: end with 2, preceded only by 1 and 4) 514 | remove_pos_prefix_list = [pos for i, pos in enumerate(pos_list) 515 | if (pos[-1] in postfixes) and len((set(pos[:-1]))-set(prefixes))==0] 516 | remain_pos_index = [not pos.startswith(tuple(remove_pos_prefix_list)) for pos in pos_list] 517 | remain_pos_list = [pos for pos in pos_list if not pos.startswith(tuple(remove_pos_prefix_list))] 518 | 519 | # find cutting points (one cut per sibling block) 520 | cut_list = [i for i, pos in enumerate(remain_pos_list) if pos[-1] == str(k**2)] 521 | cut_list_2 = [0] 522 | cut_list_2.extend(cut_list[:-1]) 523 | cut_size_list = [i - j for i, j in zip(cut_list, cut_list_2)] 524 | cut_size_list[0] += 1 525 | if is_mol: 526 | final_string_list = list(compress([item for sublist in grouper_mol(input_string) for item in sublist], remain_pos_index)) 527 | else: 528 | final_string_list = list(compress([*input_string], remain_pos_index)) 529 | 530 | pos_list_iter = iter(final_string_list) 531 | final_string_cut_list = [list(islice(pos_list_iter, i)) for i in cut_size_list] 532 | 533 | return [''.join(l) for l in final_string_cut_list] 534 | 535 | def get_max_len(data_name, order='C-M', k=2): 536 | total_strings = [] 537 | k_square = k**2 538 | for split in ['train', 'test', 'val']: 539 | if k > 2: 540 | string_path = os.path.join(DATA_DIR, f"{data_name}/{order}/{data_name}_str_{split}_{k}.txt") 541 | else: 542 | string_path = os.path.join(DATA_DIR, f"{data_name}/{order}/{data_name}_str_{split}.txt") 543 | 544 | # string_path = os.path.join(DATA_DIR, f"{data_name}/{order}/{data_name}_str_{split}_{k}.txt") 545 | strings = Path(string_path).read_text(encoding="utf-8").splitlines() 546 | 547 | total_strings.extend(strings) 548 | 549 | red_list = [remove_redundant(string, is_mol=False, k=k) for string in total_strings] 550 | # red_strings = [''.join(red) for red in red_list] 551 | 552 | max_len = max([len(string) for string in total_strings]) 553 | group_max_len = max_len / k_square 554 | red_max_len = max([len(string) for string in
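# ----------------------------------------------------------------------
# Illustration only: the position bookkeeping used by remove_redundant above,
# spelled out for k=2. Each node's position string is built from 1-based child
# indices (1=top-left, 2=top-right, 3=bottom-left, 4=bottom-right). A node is
# dropped when its last index is an upper-diagonal block (2) and every earlier
# index is a diagonal block (1 or 4), since that block mirrors an already-kept
# lower block in an undirected graph. `dropped` is a hypothetical helper.
def dropped(pos, k=2):
    diagonal = {str((i - 1) * k + i) for i in range(1, k + 1)}            # {'1', '4'}
    upper = {str((i - 1) * k + j) for i in range(1, k + 1)
             for j in range(i + 1, k + 1)}                                # {'2'}
    return pos[-1] in upper and set(pos[:-1]) <= diagonal

for pos in ['1', '2', '3', '4', '12', '32', '42', '13']:
    print(pos, 'dropped' if dropped(pos) else 'kept')
# '2', '12' and '42' are dropped; '32' is kept because its parent block (3) is
# off-diagonal, so its transpose was already removed one level higher up.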
red_list]) 555 | 556 | # max_len: the length of original string 557 | # group_max_len: the length of group string 558 | # red_max_len: the length of group-red string 559 | return max_len, group_max_len, red_max_len 560 | 561 | def clean_high_feature(string): 562 | new_string = [] 563 | for token in string: 564 | if re.search('[a-zA-Z]', token): 565 | new_token = "" 566 | for char in token.split(" "): 567 | if char.isnumeric(): 568 | new_token += char 569 | else: 570 | new_token += char[:-1] 571 | new_string.append(new_token) 572 | else: 573 | new_string.append(token.replace(" ", "")) 574 | return new_string -------------------------------------------------------------------------------- /evaluation/orca/orca.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | using namespace std; 14 | 15 | 16 | typedef long long int64; 17 | typedef pair PII; 18 | typedef struct { int first, second, third; } TIII; 19 | 20 | struct PAIR { 21 | int a, b; 22 | PAIR(int a0, int b0) { a=min(a0,b0); b=max(a0,b0); } 23 | }; 24 | bool operator<(const PAIR &x, const PAIR &y) { 25 | if (x.a==y.a) return x.bb) swap(a,b); 42 | if (b>c) swap(b,c); 43 | if (a>b) swap(a,b); 44 | } 45 | }; 46 | bool operator<(const TRIPLE &x, const TRIPLE &y) { 47 | if (x.a==y.a) { 48 | if (x.b==y.b) return x.c common2; 62 | unordered_map common3; 63 | unordered_map::iterator common2_it; 64 | unordered_map::iterator common3_it; 65 | 66 | #define common3_get(x) (((common3_it=common3.find(x))!=common3.end())?(common3_it->second):0) 67 | #define common2_get(x) (((common2_it=common2.find(x))!=common2.end())?(common2_it->second):0) 68 | 69 | int n,m; // n = number of nodes, m = number of edges 70 | int *deg; // degrees of individual nodes 71 | PAIR *edges; // list of edges 72 | 73 | int **adj; // adj[x] - adjacency list of node x 74 | PII **inc; // inc[x] - incidence list of node x: (y, edge id) 75 | bool adjacent_list(int x, int y) { return binary_search(adj[x],adj[x]+deg[x],y); } 76 | int *adj_matrix; // compressed adjacency matrix 77 | const int adj_chunk = 8*sizeof(int); 78 | bool adjacent_matrix(int x, int y) { return adj_matrix[(x*n+y)/adj_chunk]&(1<<((x*n+y)%adj_chunk)); } 79 | bool (*adjacent)(int,int); 80 | int getEdgeId(int x, int y) { return inc[x][lower_bound(adj[x],adj[x]+deg[x],y)-adj[x]].second; } 81 | 82 | int64 **orbit; // orbit[x][o] - how many times does node x participate in orbit o 83 | int64 **eorbit; // eorbit[x][o] - how many times does node x participate in edge orbit o 84 | 85 | /** count graphlets on max 4 nodes */ 86 | void count4() { 87 | clock_t startTime, endTime; 88 | startTime = clock(); 89 | clock_t startTime_all, endTime_all; 90 | startTime_all = startTime; 91 | int frac,frac_prev; 92 | 93 | // precompute triangles that span over edges 94 | printf("stage 1 - precomputing common nodes\n"); 95 | int *tri = (int*)calloc(m,sizeof(int)); 96 | frac_prev=-1; 97 | for (int i=0;i= x) break; 128 | nn=0; 129 | for (int ny=0;ny= y) break; 132 | if (adjacent(x,z)==0) continue; 133 | neigh[nn++]=z; 134 | } 135 | for (int i=0;i= x) break; 289 | nn=0; 290 | for (int ny=0;ny= y) break; 293 | if (neighx[z]==-1) continue; 294 | int xz=neighx[z]; 295 | neigh[nn]=z; 296 | neigh_edges[nn]={xz, yz}; 297 | nn++; 298 | } 299 | for (int i=0;i=0;nx--) { 330 | int y=inc[x][nx].first, xy=inc[x][nx].second; 331 | if (y <= x) break; 332 | nn=0; 333 | for 
(int ny=deg[y]-1;ny>=0;ny--) { 334 | int z=adj[y][ny]; 335 | if (z <= y) break; 336 | if (adjacent(x,z)==0) continue; 337 | neigh[nn++]=z; 338 | } 339 | for (int i=0;i= x) break; 494 | nn=0; 495 | for (int ny=0;ny= y) break; 498 | if (adjacent(x,z)) { 499 | neigh[nn++]=z; 500 | } 501 | } 502 | for (int i=0;i2 && tri[xb]>2)?(common3_get(TRIPLE(x,a,b))-1):0; 601 | f_71 += (tri[xa]>2 && tri[xc]>2)?(common3_get(TRIPLE(x,a,c))-1):0; 602 | f_71 += (tri[xb]>2 && tri[xc]>2)?(common3_get(TRIPLE(x,b,c))-1):0; 603 | f_67 += tri[xa]-2+tri[xb]-2+tri[xc]-2; 604 | f_66 += common2_get(PAIR(a,b))-2; 605 | f_66 += common2_get(PAIR(a,c))-2; 606 | f_66 += common2_get(PAIR(b,c))-2; 607 | f_58 += deg[x]-3; 608 | f_57 += deg[a]-3+deg[b]-3+deg[c]-3; 609 | } 610 | } 611 | 612 | // x = orbit-13 (diamond) 613 | for (int nx2=0;nx21 && tri[xc]>1)?(common3_get(TRIPLE(x,b,c))-1):0; 621 | f_68 += common3_get(TRIPLE(a,b,c))-1; 622 | f_64 += common2_get(PAIR(b,c))-2; 623 | f_61 += tri[xb]-1+tri[xc]-1; 624 | f_60 += common2_get(PAIR(a,b))-1; 625 | f_60 += common2_get(PAIR(a,c))-1; 626 | f_55 += tri[xa]-2; 627 | f_48 += deg[b]-2+deg[c]-2; 628 | f_42 += deg[x]-3; 629 | f_41 += deg[a]-3; 630 | } 631 | } 632 | 633 | // x = orbit-12 (diamond) 634 | for (int nx2=nx1+1;nx21)?common3_get(TRIPLE(a,b,c)):0; 642 | f_63 += common_x[c]-2; 643 | f_59 += tri[ac]-1+common2_get(PAIR(b,c))-1; 644 | f_54 += common2_get(PAIR(a,b))-2; 645 | f_47 += deg[x]-2; 646 | f_46 += deg[c]-2; 647 | f_40 += deg[a]-3+deg[b]-3; 648 | } 649 | } 650 | 651 | // x = orbit-8 (cycle) 652 | for (int nx2=nx1+1;nx20)?common3_get(TRIPLE(a,b,c)):0; 660 | f_53 += tri[xa]+tri[xb]; 661 | f_51 += tri[ac]+common2_get(PAIR(c,b)); 662 | f_50 += common_x[c]-2; 663 | f_49 += common_a[b]-2; 664 | f_38 += deg[x]-2; 665 | f_37 += deg[a]-2+deg[b]-2; 666 | f_36 += deg[c]-2; 667 | } 668 | } 669 | 670 | // x = orbit-11 (paw) 671 | for (int nx2=nx1+1;nx21 && tri[ac]>1)?common3_get(TRIPLE(a,b,c)):0; 710 | f_45 += common2_get(PAIR(b,c))-1; 711 | f_39 += tri[ab]-1+tri[ac]-1; 712 | f_31 += deg[a]-3; 713 | f_28 += deg[x]-1; 714 | f_24 += deg[b]-2+deg[c]-2; 715 | } 716 | } 717 | 718 | // x = orbit-4 (path) 719 | for (int na=0;na= x) break; 914 | nn=0; 915 | for (int ny=0;ny= y) break; 918 | if (neighx[z]==-1) continue; 919 | int xz=neighx[z]; 920 | neigh[nn]=z; 921 | neigh_edges[nn]={xz, yz}; 922 | nn++; 923 | } 924 | for (int i=0;i=x) break; 994 | 995 | // common nodes of y and some other node 996 | for (int i=0;i> n >> m; 1367 | int d_max=0; 1368 | edges = (PAIR*)malloc(m*sizeof(PAIR)); 1369 | deg = (int*)calloc(n,sizeof(int)); 1370 | for (int i=0;i> a >> b; 1373 | if (!(0<=a && a(edges,edges+m).size())!=m) { 1390 | cerr << "Input file contains duplicate undirected edges." << endl; 1391 | return 0; 1392 | } 1393 | // set up adjacency matrix if it's smaller than 100MB 1394 | if ((int64)n*n < 100LL*1024*1024*8) { 1395 | adjacent = adjacent_matrix; 1396 | adj_matrix = (int*)calloc((n*n)/adj_chunk+1,sizeof(int)); 1397 | for (int i=0;i