├── cam_util
│   ├── utils.py
│   ├── aug.py
│   ├── embedding_evaluation.py
│   ├── chem_gnn.py
│   └── dataset.py
├── README.md
├── early_stop.py
├── environment.yml
├── trasnfer_pretrain.py
├── unsupervised.py
├── datasets
│   ├── tu_dataset.py
│   └── transfer_mol_dataset.py
├── model
│   ├── GCS_transfer.py
│   └── GCS.py
└── transfer_finetune.py

/cam_util/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | 
3 | import numpy as np
4 | import torch
5 | from torch_geometric.data import Data
6 | 
7 | 
8 | def initialize_edge_weight(data):
9 |     data.edge_weight = torch.ones(data.edge_index.shape[1], dtype=torch.float)
10 |     return data
11 | 
12 | def initialize_node_features(data):
13 |     num_nodes = int(data.edge_index.max()) + 1
14 |     data.x = torch.ones((num_nodes, 1))
15 |     return data
16 | 
17 | def set_tu_dataset_y_shape(data):
18 |     num_tasks = 1
19 |     data.y = data.y.unsqueeze(num_tasks)
20 |     return data
21 | 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Boosting Graph Contrastive Learning via Graph Contrastive Saliency
2 | 
3 | This is the code for Boosting Graph Contrastive Learning via Graph Contrastive Saliency (GCS).
4 | GCS adaptively screens the semantically related substructures in a graph by capitalizing on the proposed gradient-based Graph Contrastive Saliency. The goal is to identify the most semantically discriminative structures of a graph via contrastive learning, so that semantically meaningful augmentations can be generated by leveraging saliency.
5 | 
6 | ## Requirements
7 | 
8 | To install requirements:
9 | 
10 | ```setup
11 | conda env create -f environment.yml
12 | ```
13 | 
14 | ## Unsupervised Learning
15 | 
16 | To train the model for unsupervised graph-level tasks:
17 | 
18 | ```setup
19 | python unsupervised.py
20 | ```
21 | 
22 | ## Transfer Learning
23 | Please refer to https://github.com/snap-stanford/pretrain-gnns#installation for environment setup and https://github.com/snap-stanford/pretrain-gnns#dataset-download to download the datasets.
24 | 
25 | 
26 | To pretrain the model(s) in the paper for transfer learning:
27 | 
28 | ```setup
29 | python transfer_pretrain.py
30 | ```
31 | > Output: the file "latest.tar"
32 | 
33 | To finetune the model(s) for downstream tasks:
34 | ```setup
35 | python transfer_finetune.py
36 | ```
37 | 
38 | 
--------------------------------------------------------------------------------
/early_stop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | 
5 | 
6 | class EarlyStopping:
7 |     """Early stops the training if validation loss doesn't improve after a given patience."""
8 |     def __init__(self, directory, patience=7, verbose=False, delta=0):
9 |         """
10 |         Args:
11 |             patience (int): How long to wait after last time validation loss improved.
12 |                             Default: 7
13 |             verbose (bool): If True, prints a message for each validation loss improvement.
14 |                             Default: False
15 |             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
16 |                            Default: 0
17 |             path (str): Path for the checkpoint to be saved to.
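                        (Note: in this implementation the argument is named `directory`; checkpoints are written under `<directory>/buffer/` as `<score>_checkpoint.tar`.)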
18 | Default: 'checkpoint.pt' 19 | 20 | """ 21 | self.patience = patience 22 | self.verbose = verbose 23 | self.counter = 0 24 | self.best_score = None 25 | self.early_stop = False 26 | self.delta = delta 27 | self.directory = directory 28 | def __call__(self, score, save_dic): 29 | 30 | if self.best_score is None: 31 | self.best_score = score 32 | if not os.path.exists(self.directory + '/buffer'): 33 | os.makedirs(self.directory + '/buffer') 34 | torch.save(save_dic, os.path.join(self.directory + '/buffer', '{}_{}.tar'.format(score, 'checkpoint'))) 35 | 36 | elif score < self.best_score + self.delta: 37 | self.counter += 1 38 | if self.counter >= self.patience: 39 | self.early_stop = True 40 | else: 41 | self.best_score = score 42 | self.counter = 0 43 | if not os.path.exists(self.directory + '/buffer'): 44 | os.makedirs(self.directory + '/buffer') 45 | torch.save(save_dic, os.path.join(self.directory + '/buffer', '{}_{}.tar'.format(score, 'checkpoint'))) 46 | 47 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: good 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_kmp_llvm 9 | - blas=1.0=mkl 10 | - boost=1.74.0=py38h2b96118_5 11 | - boost-cpp=1.74.0=h359cf19_5 12 | - bottleneck=1.3.5=py38h26c90d9_1 13 | - bzip2=1.0.8=h7f98852_4 14 | - ca-certificates=2022.10.11=h06a4308_0 15 | - cairo=1.16.0=h19f5f5c_2 16 | - certifi=2022.12.7=py38h06a4308_0 17 | - cudatoolkit=11.3.1=h9edb442_10 18 | - cycler=0.11.0=pyhd8ed1ab_0 19 | - expat=2.4.9=h6a678d5_0 20 | - ffmpeg=4.3=hf484d3e_0 21 | - fontconfig=2.14.1=hc2a2eb6_0 22 | - freetype=2.12.1=hca18f0e_0 23 | - glib=2.69.1=he621ea3_2 24 | - gmp=6.2.1=h58526e2_0 25 | - gnutls=3.6.13=h85f3911_1 26 | - icu=69.1=h9c3ff4c_0 27 | - intel-openmp=2021.4.0=h06a4308_3561 28 | - jbig=2.1=h7f98852_2003 29 | - jpeg=9e=h166bdaf_1 30 | - kiwisolver=1.4.2=py38h295c915_0 31 | - lame=3.100=h7f98852_1001 32 | - lcms2=2.12=hddcbb42_0 33 | - ld_impl_linux-64=2.38=h1181459_1 34 | - lerc=2.2.1=h9c3ff4c_0 35 | - libdeflate=1.7=h7f98852_5 36 | - libffi=3.4.2=h6a678d5_6 37 | - libgcc-ng=12.2.0=h65d4601_19 38 | - libiconv=1.17=h166bdaf_0 39 | - libpng=1.6.37=h21135ba_2 40 | - libstdcxx-ng=11.2.0=h1234567_1 41 | - libtiff=4.3.0=hf544144_1 42 | - libuuid=2.32.1=h7f98852_1000 43 | - libuv=1.43.0=h7f98852_0 44 | - libwebp-base=1.2.2=h7f98852_1 45 | - libxcb=1.15=h7f8727e_0 46 | - libzlib=1.2.13=h166bdaf_4 47 | - llvm-openmp=15.0.5=he0ac6c6_0 48 | - lz4-c=1.9.3=h9c3ff4c_1 49 | - matplotlib-base=3.4.3=py38hf4fb855_1 50 | - mkl=2021.4.0=h06a4308_640 51 | - mkl-service=2.4.0=py38h95df7f1_0 52 | - mkl_fft=1.3.1=py38h8666266_1 53 | - mkl_random=1.2.2=py38h1abd341_0 54 | - ncurses=6.3=h5eee18b_3 55 | - nettle=3.6=he412f7d_0 56 | - numexpr=2.8.4=py38he184ba9_0 57 | - numpy=1.23.4=py38h14f4228_0 58 | - numpy-base=1.23.4=py38h31eccc5_0 59 | - olefile=0.46=pyh9f0ad1d_1 60 | - openh264=2.1.1=h780b84a_0 61 | - openjpeg=2.4.0=hb52868f_1 62 | - openssl=1.1.1s=h7f8727e_0 63 | - packaging=21.3=pyhd8ed1ab_0 64 | - pandas=1.4.4=py38h6a678d5_0 65 | - pcre=8.45=h9c3ff4c_0 66 | - pillow=8.3.2=py38h8e6f84c_0 67 | - pip=22.2.2=py38h06a4308_0 68 | - pixman=0.40.0=h36c2ea0_0 69 | - pycairo=1.22.0=py38h190342e_0 70 | - pyparsing=3.0.9=pyhd8ed1ab_0 71 | - python=3.8.15=h7a1cb2a_2 72 | - python-dateutil=2.8.2=pyhd8ed1ab_0 73 | - python_abi=3.8=2_cp38 74 | - 
pytorch=1.10.1=py3.8_cuda11.3_cudnn8.2.0_0 75 | - pytorch-mutex=1.0=cuda 76 | - pytz=2022.6=pyhd8ed1ab_0 77 | - rdkit=2022.03.2=py38ha829ea6_0 78 | - readline=8.2=h5eee18b_0 79 | - reportlab=3.5.68=py38hadf75a6_1 80 | - seaborn=0.12.2=py38h06a4308_0 81 | - setuptools=65.5.0=py38h06a4308_0 82 | - six=1.16.0=pyh6c4a22f_0 83 | - sqlalchemy=1.3.24=py38h0a891b7_1 84 | - sqlite=3.40.0=h5082296_0 85 | - tk=8.6.12=h1ccaba5_0 86 | - torchaudio=0.10.1=py38_cu113 87 | - torchvision=0.11.2=py38_cu113 88 | - tornado=6.2=py38h0a891b7_1 89 | - typing_extensions=4.4.0=pyha770c72_0 90 | - wheel=0.37.1=pyhd3eb1b0_0 91 | - xz=5.2.6=h5eee18b_0 92 | - zlib=1.2.13=h166bdaf_4 93 | - zstd=1.5.0=ha95c52a_0 94 | - pip: 95 | - beautifulsoup4==4.11.1 96 | - charset-normalizer==2.1.1 97 | - filelock==3.8.0 98 | - gdown==4.5.4 99 | - idna==3.4 100 | - jinja2==3.1.2 101 | - joblib==1.2.0 102 | - littleutils==0.2.2 103 | - markupsafe==2.1.1 104 | - munch==2.5.0 105 | - networkx==2.8 106 | - ogb==1.3.5 107 | - outdated==0.2.2 108 | - pysocks==1.7.1 109 | - requests==2.28.1 110 | - scikit-learn==1.1.3 111 | - scipy==1.9.3 112 | - soupsieve==2.3.2.post1 113 | - threadpoolctl==3.1.0 114 | - torch-geometric==2.1.0.post1 115 | - torch-scatter==2.0.9 116 | - torch-sparse==0.6.13 117 | - tqdm==4.64.1 118 | - urllib3==1.26.13 -------------------------------------------------------------------------------- /cam_util/aug.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | import torch 5 | 6 | def drop_nodes(data): 7 | 8 | node_num, _ = data.x.size() 9 | _, edge_num = data.edge_index.size() 10 | drop_num = int(node_num / 10) 11 | 12 | idx_drop = np.random.choice(node_num, drop_num, replace=False) 13 | idx_nondrop = [n for n in range(node_num) if not n in idx_drop] 14 | idx_dict = {idx_nondrop[n]:n for n in list(range(node_num - drop_num))} 15 | 16 | # data.x = data.x[idx_nondrop] 17 | edge_index = data.edge_index.numpy() 18 | 19 | adj = torch.zeros((node_num, node_num)) 20 | adj[edge_index[0], edge_index[1]] = 1 21 | adj[idx_drop, :] = 0 22 | adj[:, idx_drop] = 0 23 | edge_index = adj.nonzero().t() 24 | 25 | data.edge_index = edge_index 26 | 27 | # edge_index = [[idx_dict[edge_index[0, n]], idx_dict[edge_index[1, n]]] for n in range(edge_num) if (not edge_index[0, n] in idx_drop) and (not edge_index[1, n] in idx_drop)] 28 | # edge_index = [[edge_index[0, n], edge_index[1, n]] for n in range(edge_num) if (not edge_index[0, n] in idx_drop) and (not edge_index[1, n] in idx_drop)] + [[n, n] for n in idx_nondrop] 29 | # data.edge_index = torch.tensor(edge_index).transpose_(0, 1) 30 | 31 | return data 32 | 33 | 34 | def permute_edges(data): 35 | 36 | node_num, _ = data.x.size() 37 | _, edge_num = data.edge_index.size() 38 | permute_num = int(edge_num / 10) 39 | 40 | edge_index = data.edge_index.transpose(0, 1).numpy() 41 | mask = np.random.choice(edge_num, edge_num-permute_num, replace=False) 42 | edge_index_aug = edge_index[mask] 43 | data.edge_index = torch.tensor(edge_index_aug).transpose_(0, 1) 44 | 45 | if hasattr(data, "edge_attr") and data.edge_attr is not None: 46 | edge_attr = data.edge_attr.numpy() 47 | edge_attr_aug = edge_attr[mask] 48 | data.edge_attr = torch.tensor(edge_attr_aug) 49 | 50 | return data 51 | 52 | def subgraph(data): 53 | 54 | node_num, _ = data.x.size() 55 | _, edge_num = data.edge_index.size() 56 | sub_num = int(node_num * 0.2) 57 | 58 | edge_index = data.edge_index.numpy() 59 | 60 | idx_sub = [np.random.randint(node_num, size=1)[0]] 61 | 
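    # Grow the retained node set from the random seed: idx_neigh is a frontier of candidate
    # neighbours (seeded with the start node's 1-hop neighbourhood), and the loop below samples
    # from it until roughly 20% of the nodes are kept, the frontier is exhausted, or node_num
    # draws have been made. Dropped nodes are then disconnected by zeroing their rows and
    # columns of the dense adjacency.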
idx_neigh = set([n for n in edge_index[1][edge_index[0]==idx_sub[0]]]) 62 | 63 | count = 0 64 | while len(idx_sub) <= sub_num: 65 | count = count + 1 66 | if count > node_num: 67 | break 68 | if len(idx_neigh) == 0: 69 | break 70 | sample_node = np.random.choice(list(idx_neigh)) 71 | if sample_node in idx_sub: 72 | continue 73 | idx_sub.append(sample_node) 74 | idx_neigh.union(set([n for n in edge_index[1][edge_index[0]==idx_sub[-1]]])) 75 | 76 | idx_drop = [n for n in range(node_num) if not n in idx_sub] 77 | idx_nondrop = idx_sub 78 | idx_dict = {idx_nondrop[n]:n for n in list(range(len(idx_nondrop)))} 79 | 80 | # data.x = data.x[idx_nondrop] 81 | edge_index = data.edge_index.numpy() 82 | 83 | adj = torch.zeros((node_num, node_num)) 84 | adj[edge_index[0], edge_index[1]] = 1 85 | adj[idx_drop, :] = 0 86 | adj[:, idx_drop] = 0 87 | edge_index = adj.nonzero().t() 88 | 89 | data.edge_index = edge_index 90 | 91 | 92 | 93 | # edge_index = [[idx_dict[edge_index[0, n]], idx_dict[edge_index[1, n]]] for n in range(edge_num) if (not edge_index[0, n] in idx_drop) and (not edge_index[1, n] in idx_drop)] 94 | # edge_index = [[edge_index[0, n], edge_index[1, n]] for n in range(edge_num) if (not edge_index[0, n] in idx_drop) and (not edge_index[1, n] in idx_drop)] + [[n, n] for n in idx_nondrop] 95 | # data.edge_index = torch.tensor(edge_index).transpose_(0, 1) 96 | 97 | return data 98 | 99 | 100 | def mask_nodes(data): 101 | 102 | node_num, feat_dim = data.x.size() 103 | mask_num = int(node_num / 10) 104 | 105 | idx_mask = np.random.choice(node_num, mask_num, replace=False) 106 | data.x[idx_mask] = torch.tensor(np.random.normal(loc=0.5, scale=0.5, size=(mask_num, feat_dim)), dtype=torch.float32) 107 | 108 | return data -------------------------------------------------------------------------------- /cam_util/embedding_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.model_selection import GridSearchCV, KFold 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.multioutput import MultiOutputClassifier 6 | from sklearn.pipeline import make_pipeline 7 | from sklearn.preprocessing import StandardScaler 8 | from torch_geometric.loader import DataLoader 9 | 10 | 11 | def get_emb_y(loader, encoder, device, dtype='numpy', is_rand_label=False): 12 | x, y = encoder.get_embeddings(loader, device, is_rand_label) 13 | if dtype == 'numpy': 14 | return x,y 15 | elif dtype == 'torch': 16 | return torch.from_numpy(x).to(device), torch.from_numpy(y).to(device) 17 | else: 18 | raise NotImplementedError 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | class EmbeddingEvaluation(): 28 | def __init__(self, base_classifier, evaluator, task_type, num_tasks, device, params_dict=None, param_search=True,is_rand_label=False): 29 | self.is_rand_label = is_rand_label 30 | self.base_classifier = base_classifier 31 | self.evaluator = evaluator 32 | self.eval_metric = evaluator.eval_metric 33 | self.task_type = task_type 34 | self.num_tasks = num_tasks 35 | self.device = device 36 | self.param_search = param_search 37 | self.params_dict = params_dict 38 | if self.eval_metric == 'rmse': 39 | self.gscv_scoring_name = 'neg_root_mean_squared_error' 40 | elif self.eval_metric == 'mae': 41 | self.gscv_scoring_name = 'neg_mean_absolute_error' 42 | elif self.eval_metric == 'rocauc': 43 | self.gscv_scoring_name = 'roc_auc' 44 | elif self.eval_metric == 'accuracy': 45 | self.gscv_scoring_name = 'accuracy' 46 | else: 47 | 
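            # metrics other than rmse / mae / rocauc / accuracy have no GridSearchCV scoring name mapped here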
raise ValueError('Undefined grid search scoring for metric %s ' % self.eval_metric) 48 | 49 | self.classifier = None 50 | def scorer(self, y_true, y_raw): 51 | input_dict = {"y_true": y_true, "y_pred": y_raw} 52 | score = self.evaluator.eval(input_dict)[self.eval_metric] 53 | return score 54 | 55 | def ee_binary_classification(self, train_emb, train_y, val_emb, val_y, test_emb, test_y): 56 | if self.param_search: 57 | params_dict = {'C': [0.001, 0.01,0.1,1,10,100,1000]} 58 | self.classifier = make_pipeline(StandardScaler(), 59 | GridSearchCV(self.base_classifier, params_dict, cv=5, scoring=self.gscv_scoring_name, n_jobs=16, verbose=0) 60 | ) 61 | else: 62 | self.classifier = make_pipeline(StandardScaler(), self.base_classifier) 63 | 64 | 65 | self.classifier.fit(train_emb, np.squeeze(train_y)) 66 | 67 | if self.eval_metric == 'accuracy': 68 | train_raw = self.classifier.predict(train_emb) 69 | val_raw = self.classifier.predict(val_emb) 70 | test_raw = self.classifier.predict(test_emb) 71 | else: 72 | train_raw = self.classifier.predict_proba(train_emb)[:, 1] 73 | val_raw = self.classifier.predict_proba(val_emb)[:, 1] 74 | test_raw = self.classifier.predict_proba(test_emb)[:, 1] 75 | 76 | return np.expand_dims(train_raw, axis=1), np.expand_dims(val_raw, axis=1), np.expand_dims(test_raw, axis=1) 77 | 78 | def ee_multioutput_binary_classification(self, train_emb, train_y, val_emb, val_y, test_emb, test_y): 79 | 80 | params_dict = { 81 | 'multioutputclassifier__estimator__C': [1e-1, 1e0, 1e1, 1e2]} 82 | self.classifier = make_pipeline(StandardScaler(), MultiOutputClassifier( 83 | self.base_classifier, n_jobs=-1)) 84 | 85 | if np.isnan(train_y).any(): 86 | print("Has NaNs ... ignoring them") 87 | train_y = np.nan_to_num(train_y) 88 | self.classifier.fit(train_emb, train_y) 89 | 90 | train_raw = np.transpose([y_pred[:, 1] for y_pred in self.classifier.predict_proba(train_emb)]) 91 | val_raw = np.transpose([y_pred[:, 1] for y_pred in self.classifier.predict_proba(val_emb)]) 92 | test_raw = np.transpose([y_pred[:, 1] for y_pred in self.classifier.predict_proba(test_emb)]) 93 | 94 | return train_raw, val_raw, test_raw 95 | 96 | def ee_regression(self, train_emb, train_y, val_emb, val_y, test_emb, test_y): 97 | if self.param_search: 98 | params_dict = {'alpha': [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4, 1e5]} 99 | # params_dict = {'alpha': [500, 50, 5, 0.5, 0.05, 0.005, 0.0005]} 100 | self.classifier = GridSearchCV(self.base_classifier, params_dict, cv=5, 101 | scoring=self.gscv_scoring_name, n_jobs=16, verbose=0) 102 | else: 103 | self.classifier = self.base_classifier 104 | 105 | self.classifier.fit(train_emb, np.squeeze(train_y)) 106 | 107 | train_raw = self.classifier.predict(train_emb) 108 | val_raw = self.classifier.predict(val_emb) 109 | test_raw = self.classifier.predict(test_emb) 110 | 111 | return np.expand_dims(train_raw, axis=1), np.expand_dims(val_raw, axis=1), np.expand_dims(test_raw, axis=1) 112 | 113 | def embedding_evaluation(self, encoder, train_loader, valid_loader, test_loader): 114 | encoder.eval() 115 | train_emb, train_y = get_emb_y(train_loader, encoder, self.device, is_rand_label=self.is_rand_label) 116 | val_emb, val_y = get_emb_y(valid_loader, encoder, self.device, is_rand_label=self.is_rand_label) 117 | test_emb, test_y = get_emb_y(test_loader, encoder, self.device, is_rand_label=self.is_rand_label) 118 | if 'classification' in self.task_type: 119 | 120 | if self.num_tasks == 1: 121 | train_raw, val_raw, test_raw = self.ee_binary_classification(train_emb, 
train_y, val_emb, val_y, test_emb, 122 | test_y) 123 | elif self.num_tasks > 1: 124 | train_raw, val_raw, test_raw = self.ee_multioutput_binary_classification(train_emb, train_y, val_emb, val_y, 125 | test_emb, test_y) 126 | else: 127 | raise NotImplementedError 128 | else: 129 | if self.num_tasks == 1: 130 | train_raw, val_raw, test_raw = self.ee_regression(train_emb, train_y, val_emb, val_y, test_emb, test_y) 131 | else: 132 | raise NotImplementedError 133 | 134 | train_score = self.scorer(train_y, train_raw) 135 | 136 | val_score = self.scorer(val_y, val_raw) 137 | 138 | test_score = self.scorer(test_y, test_raw) 139 | 140 | return train_score, val_score, test_score 141 | 142 | def kf_embedding_evaluation(self, encoder, dataset, folds=10, batch_size=128): 143 | kf_train = [] 144 | kf_val = [] 145 | kf_test = [] 146 | 147 | kf = KFold(n_splits=folds, shuffle=True, random_state=None) 148 | for k_id, (train_val_index, test_index) in enumerate(kf.split(dataset)): 149 | test_dataset = [dataset[int(i)] for i in list(test_index)] 150 | train_index, val_index = train_test_split(train_val_index, test_size=0.2, random_state=None) 151 | 152 | train_dataset = [dataset[int(i)] for i in list(train_index)] 153 | val_dataset = [dataset[int(i)] for i in list(val_index)] 154 | 155 | train_loader = DataLoader(train_dataset, batch_size=batch_size) 156 | valid_loader = DataLoader(val_dataset, batch_size=batch_size) 157 | test_loader = DataLoader(test_dataset, batch_size=batch_size) 158 | 159 | train_score, val_score, test_score = self.embedding_evaluation(encoder, train_loader, valid_loader, test_loader) 160 | 161 | kf_train.append(train_score) 162 | kf_val.append(val_score) 163 | kf_test.append(test_score) 164 | 165 | return np.array(kf_train).mean(), np.array(kf_val).mean(), np.array(kf_test).mean() -------------------------------------------------------------------------------- /trasnfer_pretrain.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import argparse 6 | from datasets.transfer_mol_dataset import MoleculeDataset 7 | from munch import Munch 8 | 9 | from torch_geometric.loader import DataLoader 10 | from torch_geometric.transforms import Compose 11 | 12 | from early_stop import EarlyStopping 13 | from datetime import datetime 14 | import os 15 | import shutil 16 | from tqdm import tqdm 17 | import json 18 | 19 | import torch 20 | from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR 21 | from torch import optim 22 | 23 | from model.GCS_transfer import Model 24 | from cam_util.utils import initialize_edge_weight 25 | from util.utils import scaffold_split 26 | 27 | 28 | def arg_parse(): 29 | str2bool = lambda x: x.lower() == "true" 30 | parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics') 31 | 32 | parser.add_argument('--dataset_name', type=str, default='chembl_filtered', help='dataset name') 33 | parser.add_argument('--dataset_root', type=str, default='storage/datasets', help='dataset dir') 34 | 35 | parser.add_argument('--cuda_device', type=str, default='3') 36 | parser.add_argument('--num_workers', type=int, default=8) 37 | 38 | parser.add_argument('--batch_size', type=int, default=256) 39 | parser.add_argument('--epochs', type=int, default=100) 40 | 41 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 42 | 43 | parser.add_argument('--lr_decay', type=int, default=30) 44 | 
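    # The saliency options a few lines below (--warm_ratio, --inner_iter, --thres) drive GCS
    # pretraining: plain edge-permutation contrast is used for the first warm_ratio fraction
    # of epochs, after which inner_iter saliency rounds keep only nodes/edges whose normalized
    # saliency exceeds thres when building the contrastive views.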
parser.add_argument('--lr_gamma', type=float, default=0.1) 45 | parser.add_argument('--lr_scheduler', type=str, default='none', help='cos, step') 46 | parser.add_argument('--milestones', nargs='+', type=int, default=[40,60,80]) 47 | 48 | parser.add_argument('--warm_ratio', type=float, default=0.1, help='Number epochs to start cam contrast') 49 | parser.add_argument('--inner_iter', type=int, default=1, help='Number epochs to start cam contrast') 50 | parser.add_argument('--num_gc_layers', type=int, default=5, help='Number of GNN layers before pooling') 51 | parser.add_argument('--emb_dim', type=int, default=300) 52 | parser.add_argument('--drop_ratio', type=float, default=0.0, help='Dropout Ratio / Probability') 53 | parser.add_argument('--thres', type=float, default=0.1, help='0 to 1 for controlling the node to drop') 54 | parser.add_argument('--JK', type=str, default="last", help='how the node features across layers are combined. last, sum, max or concat') 55 | parser.add_argument('--gnn_type', type=str, default="gin") 56 | # parser.add_argument('--graph_pooling', type=str, default="mean", help='graph level pooling (sum, mean, max, set2set, attention)') 57 | 58 | parser.add_argument('--note', type=str, default='pretrain', help='note to record') 59 | 60 | parser.add_argument('--trails', type=int, default=1, help='number of runs (default: 0)') 61 | parser.add_argument('--seed', type=int, default=618) 62 | 63 | parser.add_argument("--loadFilename", type=str, default=None) 64 | 65 | 66 | args = parser.parse_args() 67 | return args 68 | 69 | 70 | def set_seed(seed): 71 | # Fix Random seed 72 | random.seed(seed) 73 | np.random.seed(seed) 74 | torch.manual_seed(seed) 75 | torch.cuda.manual_seed(seed) 76 | torch.cuda.manual_seed_all(seed) 77 | torch.backends.cudnn.deterministic = True 78 | torch.backends.cudnn.CEX = False 79 | 80 | 81 | def directory_name_generate(model, note): 82 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 83 | directory = "data/{}".format(model.name) 84 | directory = os.path.join(directory, current_time) 85 | directory = directory + '__' + note 86 | return directory 87 | 88 | 89 | 90 | def load_data(opt): 91 | my_transforms = Compose([initialize_edge_weight]) 92 | dataset = MoleculeDataset(opt.dataset_root + "/transfer_dataset/"+opt.dataset_name, dataset=opt.dataset_name, 93 | transform=my_transforms) 94 | print(dataset) 95 | train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) 96 | 97 | meta_info = Munch() 98 | meta_info.dataset_type = 'molecule' 99 | meta_info.model_level = 'graph' 100 | 101 | 102 | return train_loader, dataset, meta_info 103 | 104 | 105 | def train(opt): 106 | train_loader, dataset, meta_info = load_data(opt) 107 | 108 | device = torch.device("cuda:{0}".format(opt.cuda_device)) 109 | model = Model(meta_info, opt, device) 110 | model = model.to(device) 111 | 112 | optimizer = optim.Adam(model.parameters(), lr=opt.lr) 113 | 114 | if opt.lr_scheduler == 'step': 115 | scheduler = StepLR(optimizer, step_size=opt.lr_decay, gamma=opt.lr_gamma) 116 | elif opt.lr_scheduler == 'multi': 117 | scheduler = MultiStepLR(optimizer, milestones=opt.milestones, gamma=opt.lr_gamma) 118 | elif opt.lr_scheduler == 'cos': 119 | scheduler = CosineAnnealingLR(optimizer, T_max=opt.epochs) 120 | else: 121 | scheduler = MultiStepLR(optimizer, milestones=[99999999999999], gamma=opt.lr_gamma) 122 | 123 | start_epoch = 0 124 | 125 | if opt.loadFilename != None: 126 | checkpoint = torch.load(opt.loadFilename) 127 | sd = 
checkpoint['sd'] 128 | 129 | opt_sd = checkpoint['opt'] 130 | 131 | start_epoch = checkpoint['epoch'] + 1 132 | scheduler_sd = checkpoint['sche'] 133 | 134 | model.load_state_dict(sd) 135 | optimizer.load_state_dict(opt_sd) 136 | 137 | scheduler.load_state_dict(scheduler_sd) 138 | 139 | directory = directory_name_generate(model, opt.note) 140 | 141 | stop_manager = EarlyStopping(directory, patience=100) 142 | for epoch in range(start_epoch, opt.epochs): 143 | model.train() 144 | 145 | show = int(float(len(train_loader)) / 2.0) 146 | with tqdm(total=len(train_loader), desc="epoch"+str(epoch)) as pbar: 147 | for index, batch in enumerate(train_loader): 148 | 149 | batch = batch.to(device) 150 | model.train() 151 | 152 | loss = model(batch, epoch/opt.epochs) 153 | 154 | optimizer.zero_grad() 155 | loss.backward() 156 | optimizer.step() 157 | 158 | pbar.update(1) 159 | if index % show == 0: 160 | print("Train Iter:[{:<3}/{}], Model_Loss:[{:.4f}]".format(index, len(train_loader), loss)) 161 | 162 | scheduler.step() 163 | 164 | save_dic = { 165 | 'epoch': epoch, 166 | 'sd': model.state_dict(), 167 | 'opt': optimizer.state_dict(), 168 | 'sche': scheduler.state_dict(), 169 | } 170 | 171 | stop_manager(epoch, save_dic) 172 | 173 | 174 | 175 | torch.save({ 176 | 'sd': model.state_dict(), 177 | }, os.path.join(directory, 'latest.tar')) 178 | 179 | with open(os.path.join(directory, 'model_arg.json'), 'wt') as f: 180 | json.dump(vars(opt), f, indent=4) 181 | 182 | shutil.rmtree(directory + '/buffer') 183 | 184 | 185 | if __name__ == "__main__": 186 | opt = arg_parse() 187 | set_seed(opt.seed) 188 | total = [] 189 | for _ in range(opt.trails): 190 | train(opt) 191 | 192 | print(total) 193 | print(opt) -------------------------------------------------------------------------------- /unsupervised.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import argparse 6 | from cam_util.dataset import TUDataset, TUEvaluator 7 | from munch import Munch 8 | 9 | from torch_geometric.loader import DataLoader 10 | from torch_geometric.transforms import Compose 11 | 12 | from early_stop import EarlyStopping 13 | from datetime import datetime 14 | import os 15 | import shutil 16 | from tqdm import tqdm 17 | import json 18 | 19 | import torch 20 | from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR 21 | from torch import optim 22 | 23 | from model.GCS import Model 24 | from cam_util.embedding_evaluation import EmbeddingEvaluation 25 | from sklearn.svm import LinearSVC, SVC 26 | from cam_util.utils import initialize_edge_weight, initialize_node_features, set_tu_dataset_y_shape 27 | 28 | 29 | 30 | def arg_parse(): 31 | str2bool = lambda x: x.lower() == "true" 32 | parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics') 33 | 34 | parser.add_argument('--dataset_name', type=str, default='PROTEINS', help='dataset name') 35 | parser.add_argument('--dataset_root', type=str, default='storage/datasets', help='dataset dir') 36 | 37 | parser.add_argument('--cuda_device', type=str, default='0') 38 | 39 | parser.add_argument('--batch_size', type=int, default=128) 40 | parser.add_argument('--epochs', type=int, default=50) 41 | 42 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 43 | 44 | parser.add_argument('--lr_decay', type=int, default=30) 45 | parser.add_argument('--lr_gamma', type=float, default=0.1) 46 | 
parser.add_argument('--lr_scheduler', type=str, default='none', help='cos, step') 47 | parser.add_argument('--milestones', nargs='+', type=int, default=[40,60,80]) 48 | 49 | parser.add_argument('--warm_ratio', type=float, default=0, help='Number epochs to start cam contrast') 50 | parser.add_argument('--inner_iter', type=int, default=3, help='Number epochs to start cam contrast') 51 | parser.add_argument('--num_gc_layers', type=int, default=5, help='Number of GNN layers before pooling') 52 | parser.add_argument('--pooling_type', type=str, default='layerwise', help='GNN Pooling Type Standard/Layerwise') 53 | parser.add_argument('--emb_dim', type=int, default=32) 54 | parser.add_argument('--drop_ratio', type=float, default=0.5, help='Dropout Ratio / Probability') 55 | parser.add_argument('--thres', type=float, default=0.3, help='0 to 1 for controlling the node to drop') 56 | parser.add_argument('--downstream_classifier', type=str, default="non-linear", help="Downstream classifier is linear or non-linear") 57 | 58 | parser.add_argument('--note', type=str, default='', help='note to record') 59 | 60 | parser.add_argument('--trails', type=int, default=10, help='number of runs (default: 0)') 61 | parser.add_argument('--seed', type=int, default=618) 62 | 63 | parser.add_argument("--loadFilename", type=str, default=None) 64 | 65 | 66 | args = parser.parse_args() 67 | return args 68 | 69 | 70 | def set_seed(seed): 71 | # Fix Random seed 72 | random.seed(seed) 73 | np.random.seed(seed) 74 | torch.manual_seed(seed) 75 | torch.cuda.manual_seed(seed) 76 | torch.cuda.manual_seed_all(seed) 77 | torch.backends.cudnn.deterministic = True 78 | torch.backends.cudnn.CEX = False 79 | 80 | 81 | def directory_name_generate(model, note): 82 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 83 | directory = "data/{}".format(model.name) 84 | directory = os.path.join(directory, current_time) 85 | directory = directory + '__' + note 86 | return directory 87 | 88 | 89 | 90 | def load_data(opt): 91 | my_transforms = Compose([initialize_node_features, initialize_edge_weight, set_tu_dataset_y_shape]) 92 | dataset = TUDataset("storage/datasets", opt.dataset_name, transform=my_transforms) 93 | 94 | train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True) 95 | test_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False) 96 | 97 | meta_info = Munch() 98 | meta_info.dataset_type = 'real' 99 | meta_info.model_level = 'graph' 100 | 101 | meta_info.dim_node = dataset.num_node_attributes 102 | meta_info.dim_edge = dataset.num_edge_attributes 103 | 104 | return train_loader, dataset, meta_info 105 | 106 | 107 | def model_eval(model, dataset, device, eval_type): 108 | model.eval() 109 | evaluator = TUEvaluator() 110 | if eval_type == "linear": 111 | ee = EmbeddingEvaluation(LinearSVC(dual=False, fit_intercept=True), evaluator, dataset.task_type, dataset.num_tasks, 112 | device, param_search=True) 113 | else: 114 | ee = EmbeddingEvaluation(SVC(), evaluator, dataset.task_type, 115 | dataset.num_tasks, 116 | device, param_search=True) 117 | 118 | train_score, val_score, test_score = ee.kf_embedding_evaluation(model.encoder, dataset) 119 | return train_score, val_score, test_score 120 | 121 | 122 | def train(opt): 123 | train_loader, dataset, meta_info = load_data(opt) 124 | 125 | device = torch.device("cuda:{0}".format(opt.cuda_device)) 126 | model = Model(meta_info, opt, device) 127 | model = model.to(device) 128 | 129 | optimizer = optim.Adam(model.parameters(), lr=opt.lr) 130 | 131 | if 
opt.lr_scheduler == 'step': 132 | scheduler = StepLR(optimizer, step_size=opt.lr_decay, gamma=opt.lr_gamma) 133 | elif opt.lr_scheduler == 'multi': 134 | scheduler = MultiStepLR(optimizer, milestones=opt.milestones, gamma=opt.lr_gamma) 135 | elif opt.lr_scheduler == 'cos': 136 | scheduler = CosineAnnealingLR(optimizer, T_max=opt.epochs) 137 | else: 138 | scheduler = MultiStepLR(optimizer, milestones=[99999999999999], gamma=opt.lr_gamma) 139 | 140 | start_epoch = 0 141 | 142 | if opt.loadFilename != None: 143 | checkpoint = torch.load(opt.loadFilename) 144 | sd = checkpoint['sd'] 145 | 146 | opt_sd = checkpoint['opt'] 147 | 148 | start_epoch = checkpoint['epoch'] + 1 149 | scheduler_sd = checkpoint['sche'] 150 | 151 | model.load_state_dict(sd) 152 | optimizer.load_state_dict(opt_sd) 153 | 154 | scheduler.load_state_dict(scheduler_sd) 155 | 156 | directory = directory_name_generate(model, opt.note) 157 | 158 | stop_manager = EarlyStopping(directory, patience=100) 159 | for epoch in range(start_epoch, opt.epochs): 160 | model.train() 161 | 162 | show = int(float(len(train_loader)) / 2.0) 163 | with tqdm(total=len(train_loader), desc="epoch"+str(epoch)) as pbar: 164 | for index, batch in enumerate(train_loader): 165 | 166 | batch = batch.to(device) 167 | model.train() 168 | 169 | loss = model(batch, epoch/opt.epochs) 170 | 171 | optimizer.zero_grad() 172 | loss.backward() 173 | optimizer.step() 174 | 175 | pbar.update(1) 176 | if index % show == 0: 177 | print("Train Iter:[{:<3}/{}], Model_Loss:[{:.4f}]".format(index, len(train_loader), loss)) 178 | 179 | scheduler.step() 180 | 181 | train_score, val_score, test_score = model_eval(model, dataset, device, opt.downstream_classifier) 182 | print("Epoch:[{}/{}], valid:[{:.8f}]".format(epoch, opt.epochs, val_score)) 183 | 184 | save_dic = { 185 | 'epoch': epoch, 186 | 'sd': model.state_dict(), 187 | 'opt': optimizer.state_dict(), 188 | 'sche': scheduler.state_dict(), 189 | } 190 | stop_manager(val_score, save_dic) 191 | if stop_manager.early_stop: 192 | print("Early stopping") 193 | break 194 | 195 | 196 | ####### final test ########### 197 | train_score, val_score, test_score = model_eval(model, dataset, device, opt.downstream_classifier) 198 | print("Final_test {}".format(test_score)) 199 | 200 | best_checkpoint = torch.load(os.path.join( 201 | directory + '/buffer', '{}_{}.tar'.format(stop_manager.best_score, 'checkpoint'))) 202 | torch.save({ 203 | 'sd': best_checkpoint['sd'], 204 | }, os.path.join(directory, 'best_{}.tar'.format(stop_manager.best_score))) 205 | torch.save({ 206 | 'sd': model.state_dict(), 207 | }, os.path.join(directory, 'latest.tar')) 208 | 209 | with open(os.path.join(directory, 'model_arg.json'), 'wt') as f: 210 | json.dump(vars(opt), f, indent=4) 211 | 212 | shutil.rmtree(directory + '/buffer') 213 | 214 | 215 | ######## best test ############ 216 | model.load_state_dict(best_checkpoint['sd']) 217 | train_score, val_score, test_score = model_eval(model, dataset, device, opt.downstream_classifier) 218 | print("Best_test {}".format(test_score)) 219 | print(dataset.name) 220 | print(model.name) 221 | return test_score 222 | 223 | if __name__ == "__main__": 224 | opt = arg_parse() 225 | set_seed(opt.seed) 226 | total = [] 227 | for _ in range(opt.trails): 228 | test_score = train(opt) 229 | total.append(test_score) 230 | 231 | print(total) 232 | print(opt) -------------------------------------------------------------------------------- /datasets/tu_dataset.py: 
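A minimal usage sketch of the TUDataset wrapper and TUEvaluator defined below, mirroring load_data and model_eval in unsupervised.py (which imports the equivalent classes from cam_util/dataset.py); the dataset root and name here are only placeholders:

from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Compose

from cam_util.utils import (initialize_edge_weight, initialize_node_features,
                            set_tu_dataset_y_shape)
from datasets.tu_dataset import TUDataset, TUEvaluator

# Many TU graph-kernel datasets ship without node features, so constant node features,
# unit edge weights and a (num_graphs, 1) label shape are added on the fly by the transforms.
transforms = Compose([initialize_node_features, initialize_edge_weight, set_tu_dataset_y_shape])
dataset = TUDataset("storage/datasets", "PROTEINS", transform=transforms)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

evaluator = TUEvaluator()  # expects y_true / y_pred of shape (num_graphs, 1) and reports accuracy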
-------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import accuracy_score 8 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip 9 | from torch_geometric.io import read_tu_data 10 | 11 | 12 | class TUDataset(InMemoryDataset): 13 | r"""A variety of graph kernel benchmark datasets, *.e.g.* "IMDB-BINARY", 14 | "REDDIT-BINARY" or "PROTEINS", collected from the `TU Dortmund University 15 | `_. 16 | In addition, this dataset wrapper provides `cleaned dataset versions 17 | `_ as motivated by the 18 | `"Understanding Isomorphism Bias in Graph Data Sets" 19 | `_ paper, containing only non-isomorphic 20 | graphs. 21 | 22 | .. note:: 23 | Some datasets may not come with any node labels. 24 | You can then either make use of the argument :obj:`use_node_attr` 25 | to load additional continuous node attributes (if present) or provide 26 | synthetic node features using transforms such as 27 | like :class:`torch_geometric.transforms.Constant` or 28 | :class:`torch_geometric.transforms.OneHotDegree`. 29 | 30 | Args: 31 | root (string): Root directory where the dataset should be saved. 32 | name (string): The `name 33 | `_ of the 34 | dataset. 35 | transform (callable, optional): A function/transform that takes in an 36 | :obj:`torch_geometric.data.Data` object and returns a transformed 37 | version. The data object will be transformed before every access. 38 | (default: :obj:`None`) 39 | pre_transform (callable, optional): A function/transform that takes in 40 | an :obj:`torch_geometric.data.Data` object and returns a 41 | transformed version. The data object will be transformed before 42 | being saved to disk. (default: :obj:`None`) 43 | pre_filter (callable, optional): A function that takes in an 44 | :obj:`torch_geometric.data.Data` object and returns a boolean 45 | value, indicating whether the data object should be included in the 46 | final dataset. (default: :obj:`None`) 47 | use_node_attr (bool, optional): If :obj:`True`, the dataset will 48 | contain additional continuous node attributes (if present). 49 | (default: :obj:`False`) 50 | use_edge_attr (bool, optional): If :obj:`True`, the dataset will 51 | contain additional continuous edge attributes (if present). 52 | (default: :obj:`False`) 53 | cleaned: (bool, optional): If :obj:`True`, the dataset will 54 | contain only non-isomorphic graphs. 
(default: :obj:`False`) 55 | """ 56 | 57 | url = 'https://www.chrsmrrs.com/graphkerneldatasets' 58 | cleaned_url = ('https://raw.githubusercontent.com/nd7141/' 59 | 'graph_datasets/master/datasets') 60 | 61 | def __init__(self, root, name, transform=None, pre_transform=None, 62 | pre_filter=None, use_node_attr=False, use_edge_attr=False, 63 | cleaned=False): 64 | self.name = name 65 | self.cleaned = cleaned 66 | self.num_tasks = 1 67 | self.task_type = 'classification' 68 | self.eval_metric = 'accuracy' 69 | super(TUDataset, self).__init__(root, transform, pre_transform, 70 | pre_filter) 71 | self.data, self.slices = torch.load(self.processed_paths[0]) 72 | if self.data.x is not None and not use_node_attr: 73 | num_node_attributes = self.num_node_attributes 74 | self.data.x = self.data.x[:, num_node_attributes:] 75 | if self.data.edge_attr is not None and not use_edge_attr: 76 | num_edge_attributes = self.num_edge_attributes 77 | self.data.edge_attr = self.data.edge_attr[:, num_edge_attributes:] 78 | 79 | @property 80 | def raw_dir(self): 81 | name = 'raw{}'.format('_cleaned' if self.cleaned else '') 82 | return osp.join(self.root, self.name, name) 83 | 84 | @property 85 | def processed_dir(self): 86 | name = 'processed{}'.format('_cleaned' if self.cleaned else '') 87 | return osp.join(self.root, self.name, name) 88 | 89 | @property 90 | def num_node_labels(self): 91 | if self.data.x is None: 92 | return 0 93 | for i in range(self.data.x.size(1)): 94 | x = self.data.x[:, i:] 95 | if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all(): 96 | return self.data.x.size(1) - i 97 | return 0 98 | 99 | @property 100 | def num_node_attributes(self): 101 | if self.data.x is None: 102 | return 0 103 | return self.data.x.size(1) - self.num_node_labels 104 | 105 | @property 106 | def num_edge_labels(self): 107 | if self.data.edge_attr is None: 108 | return 0 109 | for i in range(self.data.edge_attr.size(1)): 110 | if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0): 111 | return self.data.edge_attr.size(1) - i 112 | return 0 113 | 114 | @property 115 | def num_edge_attributes(self): 116 | if self.data.edge_attr is None: 117 | return 0 118 | return self.data.edge_attr.size(1) - self.num_edge_labels 119 | 120 | @property 121 | def raw_file_names(self): 122 | names = ['A', 'graph_indicator'] 123 | return ['{}_{}.txt'.format(self.name, name) for name in names] 124 | 125 | @property 126 | def processed_file_names(self): 127 | return 'data.pt' 128 | 129 | def download(self): 130 | url = self.cleaned_url if self.cleaned else self.url 131 | folder = osp.join(self.root, self.name) 132 | path = download_url('{}/{}.zip'.format(url, self.name), folder) 133 | extract_zip(path, folder) 134 | os.unlink(path) 135 | shutil.rmtree(self.raw_dir) 136 | os.rename(osp.join(folder, self.name), self.raw_dir) 137 | 138 | def process(self): 139 | self.data, self.slices = read_tu_data(self.raw_dir, self.name) 140 | 141 | if self.pre_filter is not None: 142 | data_list = [self.get(idx) for idx in range(len(self))] 143 | data_list = [data for data in data_list if self.pre_filter(data)] 144 | self.data, self.slices = self.collate(data_list) 145 | 146 | if self.pre_transform is not None: 147 | data_list = [self.get(idx) for idx in range(len(self))] 148 | data_list = [self.pre_transform(data) for data in data_list] 149 | self.data, self.slices = self.collate(data_list) 150 | 151 | torch.save((self.data, self.slices), self.processed_paths[0]) 152 | 153 | def __repr__(self): 154 | return '{}({})'.format(self.name, 
len(self)) 155 | 156 | class TUEvaluator: 157 | def __init__(self): 158 | self.num_tasks = 1 159 | self.eval_metric = 'accuracy' 160 | 161 | def _parse_and_check_input(self, input_dict): 162 | if self.eval_metric == 'accuracy': 163 | if not 'y_true' in input_dict: 164 | raise RuntimeError('Missing key of y_true') 165 | if not 'y_pred' in input_dict: 166 | raise RuntimeError('Missing key of y_pred') 167 | 168 | y_true, y_pred = input_dict['y_true'], input_dict['y_pred'] 169 | 170 | ''' 171 | y_true: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 172 | y_pred: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 173 | ''' 174 | 175 | # converting to torch.Tensor to numpy on cpu 176 | if torch is not None and isinstance(y_true, torch.Tensor): 177 | y_true = y_true.detach().cpu().numpy() 178 | 179 | if torch is not None and isinstance(y_pred, torch.Tensor): 180 | y_pred = y_pred.detach().cpu().numpy() 181 | 182 | ## check type 183 | if not (isinstance(y_true, np.ndarray) and isinstance(y_true, np.ndarray)): 184 | raise RuntimeError('Arguments to Evaluator need to be either numpy ndarray or torch tensor') 185 | 186 | if not y_true.shape == y_pred.shape: 187 | raise RuntimeError('Shape of y_true and y_pred must be the same') 188 | 189 | if not y_true.ndim == 2: 190 | raise RuntimeError('y_true and y_pred mush to 2-dim arrray, {}-dim array given'.format(y_true.ndim)) 191 | 192 | if not y_true.shape[1] == self.num_tasks: 193 | raise RuntimeError('Number of tasks should be {} but {} given'.format(self.num_tasks, 194 | y_true.shape[1])) 195 | 196 | return y_true, y_pred 197 | else: 198 | raise ValueError('Undefined eval metric %s ' % self.eval_metric) 199 | 200 | def _eval_accuracy(self, y_true, y_pred): 201 | ''' 202 | compute Accuracy score averaged across tasks 203 | ''' 204 | acc_list = [] 205 | 206 | for i in range(y_true.shape[1]): 207 | # ignore nan values 208 | is_labeled = y_true[:, i] == y_true[:, i] 209 | acc = accuracy_score(y_true[is_labeled], y_pred[is_labeled]) 210 | acc_list.append(acc) 211 | 212 | return {'accuracy': sum(acc_list) / len(acc_list)} 213 | 214 | def eval(self, input_dict): 215 | y_true, y_pred = self._parse_and_check_input(input_dict) 216 | return self._eval_accuracy(y_true, y_pred) -------------------------------------------------------------------------------- /model/GCS_transfer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Sequential, Linear, ReLU 5 | from torch_geometric.nn import global_mean_pool 6 | 7 | from torch_scatter import scatter, scatter_max, scatter_min, scatter_add 8 | 9 | from cam_util.aug import permute_edges 10 | from cam_util.chem_gnn import GNN 11 | 12 | from copy import deepcopy 13 | from torch_geometric.utils import subgraph 14 | 15 | def normalize(cam, batch, eps=1e-20): 16 | cam = cam.clone() 17 | # batch_num 18 | batch_max, _ = scatter_max(cam.squeeze(), batch) 19 | batch_min, _ = scatter_min(cam.squeeze(), batch) 20 | batch_max_expand = [] 21 | batch_min_expand = [] 22 | for i in batch: 23 | batch_max_expand.append(batch_max[i]) 24 | batch_min_expand.append(batch_min[i]) 25 | 26 | batch_max_expand = torch.tensor(batch_max_expand).unsqueeze(1).to(cam.device) 27 | batch_min_expand = torch.tensor(batch_min_expand).unsqueeze(1).to(cam.device) 28 | normalized_cam = (cam - batch_min_expand) / (batch_max_expand + eps) 29 | normalized_cam = normalized_cam.clamp_min(0) 30 | normalized_cam 
= normalized_cam.clamp_max(1) 31 | 32 | return normalized_cam 33 | 34 | def reset(nn): 35 | def _reset(item): 36 | if hasattr(item, 'reset_parameters'): 37 | item.reset_parameters() 38 | 39 | if nn is not None: 40 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 41 | for item in nn.children(): 42 | _reset(item) 43 | else: 44 | _reset(nn) 45 | 46 | 47 | 48 | class Model(torch.nn.Module): 49 | def __init__(self, meta_info, opt, device): 50 | super(Model, self).__init__() 51 | self.name = 'GCS_for_transfer' 52 | self.device = device 53 | self.warm_ratio = opt.warm_ratio 54 | self.inner_iter = opt.inner_iter 55 | self.thres = opt.thres 56 | 57 | self.encoder = GNN(num_layer=opt.num_gc_layers, emb_dim=opt.emb_dim, JK=opt.JK, drop_ratio=opt.drop_ratio, gnn_type=opt.gnn_type) 58 | self.pool = global_mean_pool 59 | self.proj_head = Sequential(Linear(self.encoder.emb_dim, opt.emb_dim), ReLU(inplace=True), Linear(opt.emb_dim, opt.emb_dim)) 60 | 61 | self.init_emb() 62 | 63 | def init_emb(self): 64 | for m in self.modules(): 65 | if isinstance(m, Linear): 66 | torch.nn.init.xavier_uniform_(m.weight.data) 67 | if m.bias is not None: 68 | m.bias.data.fill_(0.0) 69 | 70 | def _contrastive_score(self, query, key, queues): 71 | eye = torch.eye(query.size(0)).type_as(query) 72 | pos = torch.einsum('nc,nc->n', [query, key]).unsqueeze(0) 73 | neg = torch.cat([torch.einsum('nc,kc->nk', [query, queue]) * (1 - eye) for queue in queues], dim=1) 74 | score = (pos.exp().sum(dim=1) / neg.exp().sum(dim=1)).log() 75 | return score 76 | 77 | 78 | def _compute_cam(self, feature, score, batch, clamp_negative_weights=True): 79 | # feature (batch_nodes, embedding_dim) 80 | grad = torch.autograd.grad(score.sum(), feature)[0] 81 | 82 | # (batch_nodes, 1) 83 | weight = torch.mean(grad, dim=-1, keepdim=True) 84 | 85 | if clamp_negative_weights: # positive weights only 86 | weight = weight.clamp_min(0) 87 | 88 | # (batch_nodes, 1) 89 | cam = torch.sum(weight * feature, dim=1, keepdim=True).detach() 90 | 91 | normalized_cam = normalize(cam, batch).squeeze().detach() 92 | return normalized_cam 93 | 94 | 95 | def get_features(self, batch, x, edge_index, edge_attr, edge_weight, keep_node): 96 | if keep_node is not None: 97 | edge_index, edge_attr, edge_mask = subgraph(keep_node, edge_index, edge_attr, return_edge_mask=True) 98 | edge_weight = torch.masked_select(edge_weight, edge_mask.to(edge_weight.device)) 99 | new_x = torch.zeros(x.shape).long().to(x.device) 100 | new_x[keep_node] = x[keep_node] 101 | x = new_x 102 | 103 | node_emb = self.encoder(x, edge_index, edge_attr, edge_weight) 104 | z = self.pool(node_emb, batch) 105 | return z, node_emb 106 | 107 | def get_projection(self, z): 108 | return self.proj_head(z) 109 | 110 | def get_contrastive_cam(self, batch, n_iters=1, return_intermediate=False): 111 | key, queues = None, [] 112 | _masks, _masked_images = [], [] 113 | 114 | mask_edge = torch.zeros(batch.edge_index.shape[1]) + 1e-10 115 | keep_indice = torch.arange(batch.x.shape[0]) 116 | mask_edge_list = [] 117 | keep_node_list = [] 118 | 119 | for it in range(n_iters): 120 | z, node_emb = self.get_features(batch.batch, batch.x, batch.edge_index, batch.edge_attr, (1 - mask_edge).to(self.device), keep_indice) 121 | output = self.get_projection(z) 122 | 123 | if it == 0: 124 | key = output # original graph 125 | # queues.append(output.detach()) # masked images 126 | 127 | # score = self._contrastive_score(output, key, queues) 128 | score = self.calc_loss(key, output) 129 | 130 | # (batch_nodes, 1) 131 | 
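            # Gradient-based saliency (the core of GCS): the contrastive score is differentiated
            # w.r.t. the node embeddings, the gradient is mean-pooled over the embedding dimension
            # into a per-node weight, and the weighted feature sum gives a CAM that is normalized
            # to [0, 1] per graph. Nodes above `thres` are kept, an edge's saliency is the mean of
            # its two endpoints', and masks accumulate across iterations via element-wise max.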
node_cam = self._compute_cam(node_emb, score, batch.batch, clamp_negative_weights=True) 132 | mask_node = torch.max(mask_node, node_cam) if it > 0 else node_cam 133 | mask_node = mask_node.detach() 134 | indicater = torch.where(mask_node > self.thres, 1, 0) 135 | keep_node_list.append(indicater) 136 | 137 | 138 | src, dst = batch.edge_index[0], batch.edge_index[1] 139 | # batch_edge, 1 140 | edge_cam = (node_cam[src] + node_cam[dst]) /2 141 | mask_edge = torch.max(mask_edge, edge_cam) if it > 0 else edge_cam 142 | mask_edge = mask_edge.detach() 143 | edge_indicater = torch.where(mask_edge > self.thres, 1, 0) 144 | mask_edge_list.append(edge_indicater) 145 | 146 | return keep_node_list, mask_edge_list 147 | 148 | 149 | 150 | def contrast_train(self, batch, keep_node=None, mask_edge=None): 151 | if keep_node is None and mask_edge is None: 152 | aug = permute_edges(deepcopy(batch).cpu()).to(self.device) 153 | z_aug, _ = self.get_features(aug.batch, aug.x, aug.edge_index, aug.edge_attr, None, None) 154 | x_aug = self.get_projection(z_aug) 155 | else: 156 | 157 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, batch.edge_attr, mask_edge, keep_node) 158 | x_aug = self.get_projection(z_aug) 159 | 160 | z, _ = self.get_features(batch.batch, batch.x, batch.edge_index, batch.edge_attr, None, None) 161 | x = self.get_projection(z) 162 | 163 | contrast_loss = self.calc_loss(x, x_aug) 164 | 165 | return contrast_loss 166 | 167 | 168 | def positive(self, batch, indicater, mask_edge): 169 | # view1 170 | env_indicator = indicater.new_ones(indicater.shape) * 0.5 171 | env_indicator = torch.bernoulli(env_indicator) 172 | keep_node = torch.nonzero(indicater + env_indicator, as_tuple=False).view(-1,) 173 | 174 | edge_env_indicator = mask_edge.new_ones(mask_edge.shape) * 0.5 175 | edge_env_indicator = torch.bernoulli(edge_env_indicator) 176 | new_mask_edge = mask_edge + edge_env_indicator 177 | new_mask_edge = new_mask_edge.clamp_max(1) 178 | 179 | new_mask_edge = new_mask_edge.to(self.device) 180 | keep_node = keep_node.to(self.device) 181 | 182 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, new_mask_edge, keep_node) 183 | x_aug_1 = self.get_projection(z_aug) 184 | 185 | 186 | # view2 187 | env_indicator = indicater.new_ones(indicater.shape) * 0.5 188 | env_indicator = torch.bernoulli(env_indicator) 189 | keep_node = torch.nonzero(indicater + env_indicator, as_tuple=False).view(-1,) 190 | 191 | edge_env_indicator = mask_edge.new_ones(mask_edge.shape) * 0.5 192 | edge_env_indicator = torch.bernoulli(edge_env_indicator) 193 | new_mask_edge = mask_edge + edge_env_indicator 194 | new_mask_edge = new_mask_edge.clamp_max(1) 195 | 196 | new_mask_edge = new_mask_edge.to(self.device) 197 | keep_node = keep_node.to(self.device) 198 | 199 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, new_mask_edge, keep_node) 200 | x_aug_2 = self.get_projection(z_aug) 201 | 202 | contrast_loss = self.calc_loss(x_aug_1, x_aug_2) 203 | 204 | return contrast_loss 205 | 206 | 207 | def negative(self, batch, indicater, mask_edge): 208 | # view1 209 | indicater = 1 - indicater 210 | env_indicator = indicater.new_ones(indicater.shape) * 0.1 211 | env_indicator = torch.bernoulli(env_indicator) 212 | keep_node = torch.nonzero(indicater + env_indicator, as_tuple=False).view(-1,) 213 | 214 | mask_edge = 1 - mask_edge 215 | edge_env_indicator = mask_edge.new_ones(mask_edge.shape) * 0.1 216 | edge_env_indicator = torch.bernoulli(edge_env_indicator) 217 | new_mask_edge 
= mask_edge + edge_env_indicator 218 | new_mask_edge = new_mask_edge.clamp_max(1) 219 | 220 | new_mask_edge = new_mask_edge.to(self.device) 221 | keep_node = keep_node.to(self.device) 222 | 223 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, new_mask_edge, keep_node) 224 | x_aug = self.get_projection(z_aug) 225 | 226 | z, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, None, None) 227 | x = self.get_projection(z) 228 | 229 | contrast_loss = self.calc_loss(x, x_aug) 230 | 231 | return contrast_loss 232 | 233 | 234 | @staticmethod 235 | def calc_loss(x, x_aug, temperature=0.2, sym=True): 236 | # x and x_aug shape -> Batch x proj_hidden_dim 237 | 238 | batch_size, _ = x.size() 239 | x_abs = x.norm(dim=1) 240 | x_aug_abs = x_aug.norm(dim=1) 241 | 242 | sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / torch.einsum('i,j->ij', x_abs, x_aug_abs) 243 | sim_matrix = torch.exp(sim_matrix / temperature) 244 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 245 | if sym: 246 | 247 | loss_0 = pos_sim / (sim_matrix.sum(dim=0) - pos_sim) 248 | loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 249 | 250 | loss_0 = - torch.log(loss_0).mean() 251 | loss_1 = - torch.log(loss_1).mean() 252 | loss = (loss_0 + loss_1)/2.0 253 | else: 254 | loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 255 | loss_1 = - torch.log(loss_1).mean() 256 | return loss_1 257 | 258 | return loss 259 | 260 | def forward(self, batch, progress): 261 | if progress < self.warm_ratio: 262 | return self.contrast_train(batch) 263 | else: 264 | keep_node_list, mask_edge_list = self.get_contrastive_cam(batch, self.inner_iter) 265 | indicater = keep_node_list[-1] 266 | mask_edge = mask_edge_list[-1] 267 | 268 | pos_score = self.positive(batch, indicater, mask_edge) 269 | neg_score = self.negative(batch, indicater, mask_edge) 270 | 271 | return pos_score - 0.1 * neg_score 272 | -------------------------------------------------------------------------------- /transfer_finetune.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import argparse 6 | from datasets.transfer_mol_dataset import MoleculeDataset 7 | from munch import Munch 8 | 9 | from torch_geometric.loader import DataLoader 10 | from torch_geometric.transforms import Compose 11 | 12 | from early_stop import EarlyStopping 13 | from datetime import datetime 14 | import os 15 | import shutil 16 | from tqdm import tqdm 17 | import json 18 | 19 | import torch 20 | from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR 21 | from torch import optim, nn 22 | 23 | from model.contra_cam_v7_for_transfer import Model 24 | from cam_util.chem_gnn import GNN_graphpred 25 | from cam_util.utils import initialize_edge_weight 26 | from util.utils import scaffold_split 27 | import pandas as pd 28 | from sklearn.metrics import roc_auc_score 29 | 30 | 31 | def arg_parse(): 32 | str2bool = lambda x: x.lower() == "true" 33 | parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics') 34 | 35 | parser.add_argument('--dataset_name', type=str, default='bbbp', help='dataset name') 36 | parser.add_argument('--dataset_root', type=str, default='storage/datasets', help='dataset dir') 37 | 38 | parser.add_argument('--cuda_device', type=str, default='4') 39 | parser.add_argument('--num_workers', type=int, default=8) 40 | 41 | parser.add_argument('--batch_size', type=int, default=32) 42 | 
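    # The pretraining-specific options below (--warm_ratio, --inner_iter, --thres) are only needed
    # to rebuild the pretraining Model so its checkpoint can be loaded; fine-tuning itself trains
    # GNN_graphpred on the downstream task.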
parser.add_argument('--epochs', type=int, default=100) 43 | 44 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 45 | parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') 46 | parser.add_argument('--lr_scale', type=float, default=1, help='relative learning rate for the feature extraction layer (default: 1)') 47 | 48 | parser.add_argument('--lr_decay', type=int, default=30) 49 | parser.add_argument('--lr_gamma', type=float, default=0.1) 50 | parser.add_argument('--lr_scheduler', type=str, default='none', help='cos, step') 51 | parser.add_argument('--milestones', nargs='+', type=int, default=[40,60,80]) 52 | 53 | parser.add_argument('--warm_ratio', type=float, default=0.1, help='Number epochs to start cam contrast') 54 | parser.add_argument('--inner_iter', type=int, default=3, help='Number epochs to start cam contrast') 55 | parser.add_argument('--num_gc_layers', type=int, default=5, help='Number of GNN layers before pooling') 56 | parser.add_argument('--emb_dim', type=int, default=300) 57 | parser.add_argument('--drop_ratio', type=float, default=0.5, help='Dropout Ratio / Probability') 58 | parser.add_argument('--thres', type=float, default=0.5, help='0 to 1 for controlling the node to drop') 59 | parser.add_argument('--JK', type=str, default="last", help='how the node features across layers are combined. last, sum, max or concat') 60 | parser.add_argument('--gnn_type', type=str, default="gin") 61 | parser.add_argument('--graph_pooling', type=str, default="mean", help='graph level pooling (sum, mean, max, set2set, attention)') 62 | parser.add_argument('--split', type=str, default="scaffold", help="random or scaffold or random_scaffold") 63 | 64 | parser.add_argument('--note', type=str, default='finetune', help='note to record') 65 | 66 | parser.add_argument('--trails', type=int, default=1, help='number of runs (default: 0)') 67 | parser.add_argument('--seed', type=int, default=618) 68 | 69 | parser.add_argument("--loadFilename", type=str, default=None) 70 | parser.add_argument('--input_model_file', type=str, default='data/contra_cam_v7_for_transfer/Dec11_21-30-24__pretrain/latest.tar', help='filename to read the pretrain model (if there is any)') 71 | 72 | args = parser.parse_args() 73 | return args 74 | 75 | 76 | def set_seed(seed): 77 | # Fix Random seed 78 | random.seed(seed) 79 | np.random.seed(seed) 80 | torch.manual_seed(seed) 81 | torch.cuda.manual_seed(seed) 82 | torch.cuda.manual_seed_all(seed) 83 | torch.backends.cudnn.deterministic = True 84 | torch.backends.cudnn.CEX = False 85 | 86 | 87 | def directory_name_generate(model, note): 88 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 89 | directory = "data/{}".format(model.name) 90 | directory = os.path.join(directory, current_time) 91 | directory = directory + '__' + note 92 | return directory 93 | 94 | 95 | 96 | def load_data(opt): 97 | dataset = MoleculeDataset(opt.dataset_root + "/transfer_dataset/"+opt.dataset_name, dataset=opt.dataset_name) 98 | 99 | if opt.split == "scaffold": 100 | smiles_list = pd.read_csv(opt.dataset_root + '/transfer_dataset/' + opt.dataset_name + '/processed/smiles.csv', header=None)[0].tolist() 101 | train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1) 102 | else: 103 | raise ValueError("Invalid split option.") 104 | 105 | train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) 106 | 
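    # Validation and test loaders preserve the deterministic 80/10/10 scaffold-split order (no shuffling).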
val_loader = DataLoader(valid_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers) 107 | test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers) 108 | 109 | meta_info = Munch() 110 | meta_info.dataset_type = 'molecule' 111 | meta_info.model_level = 'graph' 112 | 113 | 114 | return train_loader, val_loader, test_loader, dataset, meta_info 115 | 116 | 117 | 118 | def model_eval(model, loader, device): 119 | model.eval() 120 | y_true = [] 121 | y_scores = [] 122 | 123 | for step, batch in enumerate(tqdm(loader, desc="Iteration")): 124 | batch = batch.to(device) 125 | 126 | with torch.no_grad(): 127 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch) 128 | 129 | y_true.append(batch.y.view(pred.shape)) 130 | y_scores.append(pred) 131 | 132 | y_true = torch.cat(y_true, dim = 0).cpu().numpy() 133 | y_scores = torch.cat(y_scores, dim = 0).cpu().numpy() 134 | 135 | roc_list = [] 136 | for i in range(y_true.shape[1]): 137 | #AUC is only defined when there is at least one positive data. 138 | if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == -1) > 0: 139 | is_valid = y_true[:,i]**2 > 0 140 | roc_list.append(roc_auc_score((y_true[is_valid,i] + 1)/2, y_scores[is_valid,i])) 141 | 142 | if len(roc_list) < y_true.shape[1]: 143 | print("Some target is missing!") 144 | print("Missing ratio: %f" %(1 - float(len(roc_list))/y_true.shape[1])) 145 | 146 | return sum(roc_list)/len(roc_list) #y_true.shape[1] 147 | 148 | 149 | def train(opt): 150 | 151 | if opt.dataset_name == "tox21": 152 | num_tasks = 12 153 | elif opt.dataset_name == "hiv": 154 | num_tasks = 1 155 | elif opt.dataset_name == "pcba": 156 | num_tasks = 128 157 | elif opt.dataset_name == "muv": 158 | num_tasks = 17 159 | elif opt.dataset_name == "bace": 160 | num_tasks = 1 161 | elif opt.dataset_name == "bbbp": 162 | num_tasks = 1 163 | elif opt.dataset_name == "toxcast": 164 | num_tasks = 617 165 | elif opt.dataset_name == "sider": 166 | num_tasks = 27 167 | elif opt.dataset_name == "clintox": 168 | num_tasks = 2 169 | else: 170 | raise ValueError("Invalid dataset name.") 171 | 172 | ############ 173 | 174 | train_loader, val_loader, test_loader, dataset, meta_info = load_data(opt) 175 | 176 | device = torch.device("cuda:{0}".format(opt.cuda_device)) 177 | model = GNN_graphpred(opt.num_gc_layers, opt.emb_dim, num_tasks, JK=opt.JK, drop_ratio=opt.drop_ratio, graph_pooling=opt.graph_pooling, gnn_type=opt.gnn_type) 178 | model = model.to(device) 179 | 180 | 181 | if not opt.input_model_file == "": 182 | checkpoint = torch.load(opt.input_model_file) 183 | sd = checkpoint['sd'] 184 | full_model = Model(meta_info, opt, device) 185 | full_model.load_state_dict(sd) 186 | # model.from_pretrained(full_model.encoder.state_dict()) 187 | model.gnn.load_state_dict(full_model.encoder.state_dict()) 188 | model.name = full_model.name + '_finetune' 189 | 190 | model_param_group = [] 191 | model_param_group.append({"params": model.gnn.parameters()}) 192 | if opt.graph_pooling == "attention": 193 | model_param_group.append({"params": model.pool.parameters(), "lr": opt.lr * opt.lr_scale}) 194 | model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr": opt.lr * opt.lr_scale}) 195 | optimizer = optim.Adam(model_param_group, lr=opt.lr, weight_decay=opt.weight_decay) 196 | 197 | if opt.lr_scheduler == 'step': 198 | scheduler = StepLR(optimizer, step_size=opt.lr_decay, gamma=opt.lr_gamma) 199 | elif opt.lr_scheduler == 'multi': 200 | scheduler = 
MultiStepLR(optimizer, milestones=opt.milestones, gamma=opt.lr_gamma) 201 | elif opt.lr_scheduler == 'cos': 202 | scheduler = CosineAnnealingLR(optimizer, T_max=opt.epochs) 203 | else: 204 | scheduler = MultiStepLR(optimizer, milestones=[99999999999999], gamma=opt.lr_gamma) 205 | 206 | start_epoch = 0 207 | 208 | if opt.loadFilename != None: 209 | checkpoint = torch.load(opt.loadFilename) 210 | sd = checkpoint['sd'] 211 | opt_sd = checkpoint['opt'] 212 | 213 | start_epoch = checkpoint['epoch'] + 1 214 | scheduler_sd = checkpoint['sche'] 215 | 216 | model.load_state_dict(sd) 217 | optimizer.load_state_dict(opt_sd) 218 | 219 | scheduler.load_state_dict(scheduler_sd) 220 | 221 | directory = directory_name_generate(model, opt.note) 222 | 223 | criterion = nn.BCEWithLogitsLoss(reduction="none")########## 224 | 225 | stop_manager = EarlyStopping(directory, patience=100) 226 | for epoch in range(start_epoch, opt.epochs): 227 | model.train() 228 | 229 | show = int(float(len(train_loader)) / 2.0) 230 | with tqdm(total=len(train_loader), desc="epoch"+str(epoch)) as pbar: 231 | for index, batch in enumerate(train_loader): 232 | 233 | batch = batch.to(device) 234 | model.train() 235 | 236 | pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch) 237 | y = batch.y.view(pred.shape).to(torch.float64) 238 | is_valid = y ** 2 > 0 239 | # Loss matrix 240 | loss_mat = criterion(pred.double(), (y + 1) / 2) 241 | # loss matrix after removing null target 242 | loss_mat = torch.where(is_valid, loss_mat, torch.zeros(loss_mat.shape).to(loss_mat.device).to(loss_mat.dtype)) 243 | loss = torch.sum(loss_mat) / torch.sum(is_valid) 244 | 245 | 246 | optimizer.zero_grad() 247 | loss.backward() 248 | optimizer.step() 249 | 250 | pbar.update(1) 251 | if index % show == 0: 252 | print("Train Iter:[{:<3}/{}], Model_Loss:[{:.4f}]".format(index, len(train_loader), loss)) 253 | 254 | scheduler.step() 255 | 256 | 257 | val_score = model_eval(model, val_loader, device) 258 | test_score = model_eval(model, test_loader, device) 259 | print("Epoch:[{}/{}], valid:[{:.8f}]".format(epoch, opt.epochs, test_score)) 260 | 261 | save_dic = { 262 | 'epoch': epoch, 263 | 'sd': model.state_dict(), 264 | 'opt': optimizer.state_dict(), 265 | 'sche': scheduler.state_dict(), 266 | } 267 | stop_manager(val_score, save_dic) 268 | if stop_manager.early_stop: 269 | print("Early stopping") 270 | break 271 | 272 | 273 | ####### final test ########### 274 | val_score = model_eval(model, val_loader, device) 275 | test_score = model_eval(model, test_loader, device) 276 | print("Final_test {}".format(test_score)) 277 | 278 | best_checkpoint = torch.load(os.path.join( 279 | directory + '/buffer', '{}_{}.tar'.format(stop_manager.best_score, 'checkpoint'))) 280 | torch.save({ 281 | 'sd': best_checkpoint['sd'], 282 | }, os.path.join(directory, 'best_{}.tar'.format(stop_manager.best_score))) 283 | torch.save({ 284 | 'sd': model.state_dict(), 285 | }, os.path.join(directory, 'latest.tar')) 286 | 287 | with open(os.path.join(directory, 'model_arg.json'), 'wt') as f: 288 | json.dump(vars(opt), f, indent=4) 289 | 290 | shutil.rmtree(directory + '/buffer') 291 | 292 | 293 | ######## best test ############ 294 | model.load_state_dict(best_checkpoint['sd']) 295 | val_score = model_eval(model, val_loader, device) 296 | test_score = model_eval(model, test_loader, device) 297 | 298 | print("Best_test {}".format(test_score)) 299 | return stop_manager.best_score 300 | 301 | if __name__ == "__main__": 302 | opt = arg_parse() 303 | set_seed(opt.seed) 304 | # 
opt.input_model_file 305 | total = [] 306 | for _ in range(opt.trails): 307 | test_score = train(opt) 308 | total.append(test_score) 309 | 310 | print(total) 311 | print(opt) -------------------------------------------------------------------------------- /cam_util/chem_gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.nn import MessagePassing 4 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 5 | from torch_geometric.nn.inits import glorot, zeros 6 | from torch_geometric.utils import add_self_loops, softmax 7 | from torch_scatter import scatter_add 8 | 9 | num_atom_type = 120 # including the extra mask tokens 10 | num_chirality_tag = 3 11 | 12 | num_bond_type = 6 # including aromatic and self-loop edge, and extra masked tokens 13 | num_bond_direction = 3 14 | 15 | 16 | class GINConv(MessagePassing): 17 | """ 18 | Extension of GIN aggregation to incorporate edge information by concatenation. 19 | 20 | Args: 21 | emb_dim (int): dimensionality of embeddings for nodes and edges. 22 | embed_input (bool): whether to embed input or not. 23 | 24 | 25 | See https://arxiv.org/abs/1810.00826 26 | """ 27 | 28 | def __init__(self, emb_dim, aggr="add"): 29 | super(GINConv, self).__init__(aggr=aggr) 30 | # multi-layer perceptron 31 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.ReLU(), 32 | torch.nn.Linear(2 * emb_dim, emb_dim)) 33 | self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim) 34 | self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim) 35 | 36 | torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data) 37 | torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data) 38 | self.aggr = aggr 39 | 40 | def forward(self, x, edge_index, edge_attr, edge_weight= None): 41 | # add self loops in the edge space 42 | edge_index = add_self_loops(edge_index, num_nodes=x.size(0)) 43 | 44 | # add features corresponding to self-loop edges. 45 | self_loop_attr = torch.zeros(x.size(0), 2) 46 | self_loop_attr[:, 0] = 4 # bond type for self-loop edge 47 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) 48 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0) 49 | 50 | # add edge weight of 1.0 to all self loop edges. 
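# Note: in recent torch_geometric releases add_self_loops() returns an
# (edge_index, edge_attr) tuple, which is why edge_index[0] is passed to
# self.propagate() further down. Self-loop edges receive bond-type index 4
# ("self-loop") and, when an edge-weight mask is supplied, a weight of 1.0 so
# that a node's own message is never masked out by the saliency-derived weights.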
51 | if edge_weight is not None: 52 | self_loop_weights = torch.ones(self_loop_attr.shape[0], dtype=torch.float).to(edge_weight.device).to(edge_weight.dtype) 53 | edge_weight = torch.cat((edge_weight, self_loop_weights), dim=0) 54 | 55 | edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1]) 56 | 57 | # return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings) 58 | 59 | return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings, edge_weight=edge_weight) 60 | 61 | def message(self, x_j, edge_attr, edge_weight): 62 | return F.relu(x_j + edge_attr) if edge_weight is None else F.relu(x_j + edge_attr) * edge_weight.view(-1, 1) 63 | # return x_j + edge_attr 64 | 65 | def update(self, aggr_out): 66 | return self.mlp(aggr_out) 67 | 68 | 69 | class GCNConv(MessagePassing): 70 | 71 | def __init__(self, emb_dim, aggr="add"): 72 | super(GCNConv, self).__init__() 73 | 74 | self.emb_dim = emb_dim 75 | self.linear = torch.nn.Linear(emb_dim, emb_dim) 76 | self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim) 77 | self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim) 78 | 79 | torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data) 80 | torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data) 81 | 82 | self.aggr = aggr 83 | 84 | def norm(self, edge_index, num_nodes, dtype): 85 | ### assuming that self-loops have been already added in edge_index 86 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, 87 | device=edge_index.device) 88 | row, col = edge_index 89 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 90 | deg_inv_sqrt = deg.pow(-0.5) 91 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 92 | 93 | return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 94 | 95 | def forward(self, x, edge_index, edge_attr): 96 | # add self loops in the edge space 97 | edge_index = add_self_loops(edge_index, num_nodes=x.size(0)) 98 | 99 | # add features corresponding to self-loop edges. 
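# Note: norm() above computes the symmetric GCN normalization
# D^{-1/2}(A + I)D^{-1/2}, i.e. deg_inv_sqrt[row] * deg_inv_sqrt[col] per edge.
# Unlike GINConv above, this forward() does not index the tuple that newer
# add_self_loops() versions return (edge_index[0]), so it may need a small
# adjustment depending on the installed torch_geometric release.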
100 | self_loop_attr = torch.zeros(x.size(0), 2) 101 | self_loop_attr[:, 0] = 4 # bond type for self-loop edge 102 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) 103 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0) 104 | 105 | edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1]) 106 | 107 | norm = self.norm(edge_index, x.size(0), x.dtype) 108 | 109 | x = self.linear(x) 110 | 111 | return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings, norm=norm) 112 | 113 | def message(self, x_j, edge_attr, norm): 114 | return norm.view(-1, 1) * (x_j + edge_attr) 115 | 116 | 117 | class GATConv(MessagePassing): 118 | def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"): 119 | super(GATConv, self).__init__() 120 | 121 | self.aggr = aggr 122 | 123 | self.emb_dim = emb_dim 124 | self.heads = heads 125 | self.negative_slope = negative_slope 126 | 127 | self.weight_linear = torch.nn.Linear(emb_dim, heads * emb_dim) 128 | self.att = torch.nn.Parameter(torch.Tensor(1, heads, 2 * emb_dim)) 129 | 130 | self.bias = torch.nn.Parameter(torch.Tensor(emb_dim)) 131 | 132 | self.edge_embedding1 = torch.nn.Embedding(num_bond_type, heads * emb_dim) 133 | self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, heads * emb_dim) 134 | 135 | torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data) 136 | torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data) 137 | 138 | self.reset_parameters() 139 | 140 | def reset_parameters(self): 141 | glorot(self.att) 142 | zeros(self.bias) 143 | 144 | def forward(self, x, edge_index, edge_attr): 145 | # add self loops in the edge space 146 | edge_index = add_self_loops(edge_index, num_nodes=x.size(0)) 147 | 148 | # add features corresponding to self-loop edges. 
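# Note: message()/update() below implement GAT attention with edge features
# folded into the neighbour term,
#     alpha_ij = softmax(LeakyReLU(a^T [x_i || (x_j + e_ij)]), edge_index[0]),
# and the per-head outputs are averaged (rather than concatenated) in update()
# before the bias is added.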
149 | self_loop_attr = torch.zeros(x.size(0), 2) 150 | self_loop_attr[:, 0] = 4 # bond type for self-loop edge 151 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) 152 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0) 153 | 154 | edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1]) 155 | 156 | x = self.weight_linear(x).view(-1, self.heads, self.emb_dim) 157 | return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings) 158 | 159 | def message(self, edge_index, x_i, x_j, edge_attr): 160 | edge_attr = edge_attr.view(-1, self.heads, self.emb_dim) 161 | x_j += edge_attr 162 | 163 | alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1) 164 | 165 | alpha = F.leaky_relu(alpha, self.negative_slope) 166 | alpha = softmax(alpha, edge_index[0]) 167 | 168 | return x_j * alpha.view(-1, self.heads, 1) 169 | 170 | def update(self, aggr_out): 171 | aggr_out = aggr_out.mean(dim=1) 172 | aggr_out = aggr_out + self.bias 173 | 174 | return aggr_out 175 | 176 | 177 | class GraphSAGEConv(MessagePassing): 178 | def __init__(self, emb_dim, aggr="mean"): 179 | super(GraphSAGEConv, self).__init__() 180 | 181 | self.emb_dim = emb_dim 182 | self.linear = torch.nn.Linear(emb_dim, emb_dim) 183 | self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim) 184 | self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim) 185 | 186 | torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data) 187 | torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data) 188 | 189 | self.aggr = aggr 190 | 191 | def forward(self, x, edge_index, edge_attr): 192 | # add self loops in the edge space 193 | edge_index = add_self_loops(edge_index, num_nodes=x.size(0)) 194 | 195 | # add features corresponding to self-loop edges. 196 | self_loop_attr = torch.zeros(x.size(0), 2) 197 | self_loop_attr[:, 0] = 4 # bond type for self-loop edge 198 | self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype) 199 | edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0) 200 | 201 | edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1]) 202 | 203 | x = self.linear(x) 204 | 205 | return self.propagate(self.aggr, edge_index, x=x, edge_attr=edge_embeddings) 206 | 207 | def message(self, x_j, edge_attr): 208 | return x_j + edge_attr 209 | 210 | def update(self, aggr_out): 211 | return F.normalize(aggr_out, p=2, dim=-1) 212 | 213 | 214 | class GNN(torch.nn.Module): 215 | """ 216 | 217 | 218 | Args: 219 | num_layer (int): the number of GNN layers 220 | emb_dim (int): dimensionality of embeddings 221 | JK (str): last, concat, max or sum. 
222 | max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation 223 | drop_ratio (float): dropout rate 224 | gnn_type: gin, gcn, graphsage, gat 225 | 226 | Output: 227 | node representations 228 | 229 | """ 230 | 231 | def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0, gnn_type="gin"): 232 | super(GNN, self).__init__() 233 | self.emb_dim = emb_dim 234 | self.num_layer = num_layer 235 | self.drop_ratio = drop_ratio 236 | self.JK = JK 237 | 238 | if self.num_layer < 2: 239 | raise ValueError("Number of GNN layers must be greater than 1.") 240 | 241 | self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim) 242 | self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim) 243 | 244 | torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data) 245 | torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data) 246 | 247 | ###List of MLPs 248 | self.gnns = torch.nn.ModuleList() 249 | for layer in range(num_layer): 250 | if gnn_type == "gin": 251 | self.gnns.append(GINConv(emb_dim, aggr="add")) 252 | elif gnn_type == "gcn": 253 | self.gnns.append(GCNConv(emb_dim)) 254 | elif gnn_type == "gat": 255 | self.gnns.append(GATConv(emb_dim)) 256 | elif gnn_type == "graphsage": 257 | self.gnns.append(GraphSAGEConv(emb_dim)) 258 | 259 | ###List of batchnorms 260 | self.batch_norms = torch.nn.ModuleList() 261 | for layer in range(num_layer): 262 | self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) 263 | 264 | # def forward(self, x, edge_index, edge_attr,edge_weight): 265 | def forward(self, *argv): 266 | 267 | if len(argv) == 4: 268 | x, edge_index, edge_attr, edge_weight = argv[0], argv[1], argv[2], argv[3] 269 | elif len(argv) == 3: 270 | x, edge_index, edge_attr = argv[0], argv[1], argv[2] 271 | edge_weight = None 272 | elif len(argv) == 1: 273 | data = argv[0] 274 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 275 | edge_weight = None 276 | else: 277 | raise ValueError("unmatched number of arguments.") 278 | 279 | x = self.x_embedding1(x[:, 0]) + self.x_embedding2(x[:, 1]) 280 | 281 | h_list = [x] 282 | for layer in range(self.num_layer): 283 | h = self.gnns[layer](h_list[layer], edge_index, edge_attr, edge_weight) 284 | h = self.batch_norms[layer](h) 285 | # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) 286 | if layer == self.num_layer - 1: 287 | # remove relu for the last layer 288 | h = F.dropout(h, self.drop_ratio, training=self.training) 289 | else: 290 | h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) 291 | h_list.append(h) 292 | 293 | ### Different implementations of Jk-concat 294 | if self.JK == "concat": 295 | node_representation = torch.cat(h_list, dim=1) 296 | elif self.JK == "last": 297 | node_representation = h_list[-1] 298 | elif self.JK == "max": 299 | h_list = [h.unsqueeze_(0) for h in h_list] 300 | node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0] 301 | elif self.JK == "sum": 302 | h_list = [h.unsqueeze_(0) for h in h_list] 303 | node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0] 304 | 305 | return node_representation 306 | 307 | 308 | class GNN_graphpred(torch.nn.Module): 309 | """ 310 | Extension of GIN to incorporate edge information by concatenation. 
311 | 312 | Args: 313 | num_layer (int): the number of GNN layers 314 | emb_dim (int): dimensionality of embeddings 315 | num_tasks (int): number of tasks in multi-task learning scenario 316 | drop_ratio (float): dropout rate 317 | JK (str): last, concat, max or sum. 318 | graph_pooling (str): sum, mean, max, attention, set2set 319 | gnn_type: gin, gcn, graphsage, gat 320 | 321 | See https://arxiv.org/abs/1810.00826 322 | JK-net: https://arxiv.org/abs/1806.03536 323 | """ 324 | 325 | def __init__(self, num_layer, emb_dim, num_tasks, JK="last", drop_ratio=0, graph_pooling="mean", gnn_type="gin"): 326 | super(GNN_graphpred, self).__init__() 327 | self.num_layer = num_layer 328 | self.drop_ratio = drop_ratio 329 | self.JK = JK 330 | self.emb_dim = emb_dim 331 | self.num_tasks = num_tasks 332 | 333 | if self.num_layer < 2: 334 | raise ValueError("Number of GNN layers must be greater than 1.") 335 | 336 | self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type=gnn_type) 337 | 338 | # Different kind of graph pooling 339 | if graph_pooling == "sum": 340 | self.pool = global_add_pool 341 | elif graph_pooling == "mean": 342 | self.pool = global_mean_pool 343 | elif graph_pooling == "max": 344 | self.pool = global_max_pool 345 | elif graph_pooling == "attention": 346 | if self.JK == "concat": 347 | self.pool = GlobalAttention(gate_nn=torch.nn.Linear((self.num_layer + 1) * emb_dim, 1)) 348 | else: 349 | self.pool = GlobalAttention(gate_nn=torch.nn.Linear(emb_dim, 1)) 350 | elif graph_pooling[:-1] == "set2set": 351 | set2set_iter = int(graph_pooling[-1]) 352 | if self.JK == "concat": 353 | self.pool = Set2Set((self.num_layer + 1) * emb_dim, set2set_iter) 354 | else: 355 | self.pool = Set2Set(emb_dim, set2set_iter) 356 | else: 357 | raise ValueError("Invalid graph pooling type.") 358 | 359 | # For graph-level binary classification 360 | if graph_pooling[:-1] == "set2set": 361 | self.mult = 2 362 | else: 363 | self.mult = 1 364 | 365 | if self.JK == "concat": 366 | self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks) 367 | else: 368 | self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks) 369 | 370 | def from_pretrained(self, model_file): 371 | # self.gnn = GNN(self.num_layer, self.emb_dim, JK = self.JK, drop_ratio = self.drop_ratio) 372 | self.gnn.load_state_dict(torch.load(model_file)) 373 | 374 | def forward(self, *argv): 375 | if len(argv) == 4: 376 | x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3] 377 | elif len(argv) == 1: 378 | data = argv[0] 379 | x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch 380 | else: 381 | raise ValueError("unmatched number of arguments.") 382 | 383 | node_representation = self.gnn(x, edge_index, edge_attr) 384 | 385 | return self.graph_pred_linear(self.pool(node_representation, batch)) 386 | 387 | 388 | if __name__ == "__main__": 389 | pass 390 | -------------------------------------------------------------------------------- /cam_util/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import accuracy_score 8 | from torch_geometric.data import InMemoryDataset, download_url, extract_zip, Data 9 | from torch_geometric.io import read_tu_data 10 | import pickle 11 | from tqdm import tqdm 12 | 13 | class TUDataset(InMemoryDataset): 14 | r"""A variety of graph kernel 
benchmark datasets, *.e.g.* "IMDB-BINARY", 15 | "REDDIT-BINARY" or "PROTEINS", collected from the `TU Dortmund University 16 | `_. 17 | In addition, this dataset wrapper provides `cleaned dataset versions 18 | `_ as motivated by the 19 | `"Understanding Isomorphism Bias in Graph Data Sets" 20 | `_ paper, containing only non-isomorphic 21 | graphs. 22 | 23 | .. note:: 24 | Some datasets may not come with any node labels. 25 | You can then either make use of the argument :obj:`use_node_attr` 26 | to load additional continuous node attributes (if present) or provide 27 | synthetic node features using transforms such as 28 | like :class:`torch_geometric.transforms.Constant` or 29 | :class:`torch_geometric.transforms.OneHotDegree`. 30 | 31 | Args: 32 | root (string): Root directory where the dataset should be saved. 33 | name (string): The `name 34 | `_ of the 35 | dataset. 36 | transform (callable, optional): A function/transform that takes in an 37 | :obj:`torch_geometric.data.Data` object and returns a transformed 38 | version. The data object will be transformed before every access. 39 | (default: :obj:`None`) 40 | pre_transform (callable, optional): A function/transform that takes in 41 | an :obj:`torch_geometric.data.Data` object and returns a 42 | transformed version. The data object will be transformed before 43 | being saved to disk. (default: :obj:`None`) 44 | pre_filter (callable, optional): A function that takes in an 45 | :obj:`torch_geometric.data.Data` object and returns a boolean 46 | value, indicating whether the data object should be included in the 47 | final dataset. (default: :obj:`None`) 48 | use_node_attr (bool, optional): If :obj:`True`, the dataset will 49 | contain additional continuous node attributes (if present). 50 | (default: :obj:`False`) 51 | use_edge_attr (bool, optional): If :obj:`True`, the dataset will 52 | contain additional continuous edge attributes (if present). 53 | (default: :obj:`False`) 54 | cleaned: (bool, optional): If :obj:`True`, the dataset will 55 | contain only non-isomorphic graphs. 
(default: :obj:`False`) 56 | """ 57 | 58 | url = 'https://www.chrsmrrs.com/graphkerneldatasets' 59 | cleaned_url = ('https://raw.githubusercontent.com/nd7141/' 60 | 'graph_datasets/master/datasets') 61 | 62 | def __init__(self, root, name, transform=None, pre_transform=None, 63 | pre_filter=None, use_node_attr=False, use_edge_attr=False, 64 | cleaned=False): 65 | self.name = name 66 | self.cleaned = cleaned 67 | self.num_tasks = 1 68 | self.task_type = 'classification' 69 | self.eval_metric = 'accuracy' 70 | super(TUDataset, self).__init__(root, transform, pre_transform, 71 | pre_filter) 72 | self.data, self.slices = torch.load(self.processed_paths[0]) 73 | if self.data.x is not None and not use_node_attr: 74 | num_node_attributes = self.num_node_attributes 75 | self.data.x = self.data.x[:, num_node_attributes:] 76 | if self.data.edge_attr is not None and not use_edge_attr: 77 | num_edge_attributes = self.num_edge_attributes 78 | self.data.edge_attr = self.data.edge_attr[:, num_edge_attributes:] 79 | 80 | @property 81 | def raw_dir(self): 82 | name = 'raw{}'.format('_cleaned' if self.cleaned else '') 83 | return osp.join(self.root, self.name, name) 84 | 85 | @property 86 | def processed_dir(self): 87 | name = 'processed{}'.format('_cleaned' if self.cleaned else '') 88 | return osp.join(self.root, self.name, name) 89 | 90 | @property 91 | def num_node_labels(self): 92 | if self.data.x is None: 93 | return 0 94 | for i in range(self.data.x.size(1)): 95 | x = self.data.x[:, i:] 96 | if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all(): 97 | return self.data.x.size(1) - i 98 | return 0 99 | 100 | @property 101 | def num_node_attributes(self): 102 | if self.data.x is None: 103 | return 0 104 | return self.data.x.size(1) - self.num_node_labels 105 | 106 | @property 107 | def num_edge_labels(self): 108 | if self.data.edge_attr is None: 109 | return 0 110 | for i in range(self.data.edge_attr.size(1)): 111 | if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0): 112 | return self.data.edge_attr.size(1) - i 113 | return 0 114 | 115 | @property 116 | def num_edge_attributes(self): 117 | if self.data.edge_attr is None: 118 | return 0 119 | return self.data.edge_attr.size(1) - self.num_edge_labels 120 | 121 | @property 122 | def raw_file_names(self): 123 | names = ['A', 'graph_indicator'] 124 | return ['{}_{}.txt'.format(self.name, name) for name in names] 125 | 126 | @property 127 | def processed_file_names(self): 128 | return 'data.pt' 129 | 130 | def download(self): 131 | url = self.cleaned_url if self.cleaned else self.url 132 | folder = osp.join(self.root, self.name) 133 | path = download_url('{}/{}.zip'.format(url, self.name), folder) 134 | extract_zip(path, folder) 135 | os.unlink(path) 136 | shutil.rmtree(self.raw_dir) 137 | os.rename(osp.join(folder, self.name), self.raw_dir) 138 | 139 | def process(self): 140 | self.data, self.slices, _ = read_tu_data(self.raw_dir, self.name) 141 | 142 | if self.pre_filter is not None: 143 | data_list = [self.get(idx) for idx in range(len(self))] 144 | data_list = [data for data in data_list if self.pre_filter(data)] 145 | self.data, self.slices = self.collate(data_list) 146 | 147 | if self.pre_transform is not None: 148 | data_list = [self.get(idx) for idx in range(len(self))] 149 | data_list = [self.pre_transform(data) for data in data_list] 150 | self.data, self.slices = self.collate(data_list) 151 | 152 | torch.save((self.data, self.slices), self.processed_paths[0]) 153 | 154 | def __repr__(self): 155 | return 
'{}({})'.format(self.name, len(self)) 156 | 157 | class TUEvaluator: 158 | def __init__(self): 159 | self.num_tasks = 1 160 | self.eval_metric = 'accuracy' 161 | 162 | def _parse_and_check_input(self, input_dict): 163 | if self.eval_metric == 'accuracy': 164 | if not 'y_true' in input_dict: 165 | raise RuntimeError('Missing key of y_true') 166 | if not 'y_pred' in input_dict: 167 | raise RuntimeError('Missing key of y_pred') 168 | 169 | y_true, y_pred = input_dict['y_true'], input_dict['y_pred'] 170 | 171 | ''' 172 | y_true: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 173 | y_pred: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 174 | ''' 175 | 176 | # converting to torch.Tensor to numpy on cpu 177 | if torch is not None and isinstance(y_true, torch.Tensor): 178 | y_true = y_true.detach().cpu().numpy() 179 | 180 | if torch is not None and isinstance(y_pred, torch.Tensor): 181 | y_pred = y_pred.detach().cpu().numpy() 182 | 183 | ## check type 184 | if not (isinstance(y_true, np.ndarray) and isinstance(y_true, np.ndarray)): 185 | raise RuntimeError('Arguments to Evaluator need to be either numpy ndarray or torch tensor') 186 | 187 | if not y_true.shape == y_pred.shape: 188 | raise RuntimeError('Shape of y_true and y_pred must be the same') 189 | 190 | if not y_true.ndim == 2: 191 | raise RuntimeError('y_true and y_pred mush to 2-dim arrray, {}-dim array given'.format(y_true.ndim)) 192 | 193 | if not y_true.shape[1] == self.num_tasks: 194 | raise RuntimeError('Number of tasks should be {} but {} given'.format(self.num_tasks, 195 | y_true.shape[1])) 196 | 197 | return y_true, y_pred 198 | else: 199 | raise ValueError('Undefined eval metric %s ' % self.eval_metric) 200 | 201 | def _eval_accuracy(self, y_true, y_pred): 202 | ''' 203 | compute Accuracy score averaged across tasks 204 | ''' 205 | acc_list = [] 206 | 207 | for i in range(y_true.shape[1]): 208 | # ignore nan values 209 | is_labeled = y_true[:, i] == y_true[:, i] 210 | acc = accuracy_score(y_true[is_labeled], y_pred[is_labeled]) 211 | acc_list.append(acc) 212 | 213 | return {'accuracy': sum(acc_list) / len(acc_list)} 214 | 215 | def eval(self, input_dict): 216 | y_true, y_pred = self._parse_and_check_input(input_dict) 217 | return self._eval_accuracy(y_true, y_pred) 218 | 219 | 220 | class ZINC(InMemoryDataset): 221 | 222 | url = 'https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1' 223 | split_url = ('https://raw.githubusercontent.com/graphdeeplearning/' 224 | 'benchmarking-gnns/master/data/molecules/{}.index') 225 | 226 | def __init__(self, root, subset=False, split='all', transform=None, 227 | pre_transform=None, pre_filter=None): 228 | self.subset = subset 229 | self.num_atom_type = 28 # known meta-info about the zinc dataset; can be calculated as well 230 | self.num_bond_type = 4 # known meta-info about the zinc dataset; can be calculated as well 231 | self.num_tasks = 1 232 | self.task_type = 'regression' 233 | self.eval_metric = 'mae' 234 | assert split in ['all', 'train', 'val', 'test'] 235 | super(ZINC, self).__init__(root, transform, pre_transform, pre_filter) 236 | path = osp.join(self.processed_dir, f'{split}.pt') 237 | self.data, self.slices = torch.load(path) 238 | 239 | 240 | @property 241 | def raw_file_names(self): 242 | return [ 243 | 'train.pickle', 'val.pickle', 'test.pickle', 'train.index', 244 | 'val.index', 'test.index', 'atom_dict.pickle', 'bond_dict.pickle' 245 | ] 246 | 247 | @property 248 | def processed_dir(self): 249 | name = 'subset' if self.subset else 'full' 250 
| return osp.join(self.root, name, 'processed') 251 | 252 | @property 253 | def processed_file_names(self): 254 | return ['all.pt', 'train.pt', 'val.pt', 'test.pt'] 255 | 256 | def download(self): 257 | shutil.rmtree(self.raw_dir) 258 | path = download_url(self.url, self.root) 259 | extract_zip(path, self.root) 260 | os.rename(osp.join(self.root, 'molecules'), self.raw_dir) 261 | os.unlink(path) 262 | 263 | for split in ['train', 'val', 'test']: 264 | download_url(self.split_url.format(split), self.raw_dir) 265 | 266 | def process(self): 267 | all_data_list = [] 268 | for split in ['train', 'val', 'test']: 269 | with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f: 270 | mols = pickle.load(f) 271 | 272 | indices = range(len(mols)) 273 | 274 | if self.subset: 275 | with open(osp.join(self.raw_dir, f'{split}.index'), 'r') as f: 276 | indices = [int(x) for x in f.read()[:-1].split(',')] 277 | 278 | pbar = tqdm(total=len(indices)) 279 | pbar.set_description(f'Processing {split} dataset') 280 | 281 | data_list = [] 282 | for idx in indices: 283 | mol = mols[idx] 284 | 285 | x = mol['atom_type'].to(torch.long).view(-1, 1) 286 | y = mol['logP_SA_cycle_normalized'].to(torch.float) 287 | y = y.unsqueeze(self.num_tasks) 288 | adj = mol['bond_type'] 289 | edge_index = adj.nonzero(as_tuple=False).t().contiguous() 290 | edge_attr = adj[edge_index[0], edge_index[1]].to(torch.long) 291 | 292 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, 293 | y=y) 294 | 295 | if self.pre_filter is not None and not self.pre_filter(data): 296 | continue 297 | 298 | if self.pre_transform is not None: 299 | data = self.pre_transform(data) 300 | 301 | data_list.append(data) 302 | all_data_list.append(data) 303 | pbar.update(1) 304 | 305 | 306 | pbar.close() 307 | 308 | torch.save(self.collate(data_list), 309 | osp.join(self.processed_dir, f'{split}.pt')) 310 | 311 | torch.save(self.collate(all_data_list), 312 | osp.join(self.processed_dir, 'all.pt')) 313 | 314 | class ZINCEvaluator: 315 | def __init__(self): 316 | self.num_tasks = 1 317 | self.eval_metric = 'mae' 318 | 319 | def _parse_and_check_input(self, input_dict): 320 | if self.eval_metric == 'mae': 321 | if not 'y_true' in input_dict: 322 | raise RuntimeError('Missing key of y_true') 323 | if not 'y_pred' in input_dict: 324 | raise RuntimeError('Missing key of y_pred') 325 | 326 | y_true, y_pred = input_dict['y_true'], input_dict['y_pred'] 327 | 328 | ''' 329 | y_true: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 330 | y_pred: numpy ndarray or torch tensor of shape (num_graph, num_tasks) 331 | ''' 332 | 333 | # converting to torch.Tensor to numpy on cpu 334 | if torch is not None and isinstance(y_true, torch.Tensor): 335 | y_true = y_true.detach().cpu().numpy() 336 | 337 | if torch is not None and isinstance(y_pred, torch.Tensor): 338 | y_pred = y_pred.detach().cpu().numpy() 339 | 340 | ## check type 341 | if not (isinstance(y_true, np.ndarray) and isinstance(y_true, np.ndarray)): 342 | raise RuntimeError('Arguments to Evaluator need to be either numpy ndarray or torch tensor') 343 | 344 | if not y_true.shape == y_pred.shape: 345 | raise RuntimeError('Shape of y_true and y_pred must be the same') 346 | 347 | if not y_true.ndim == 2: 348 | raise RuntimeError('y_true and y_pred mush to 2-dim arrray, {}-dim array given'.format(y_true.ndim)) 349 | 350 | if not y_true.shape[1] == self.num_tasks: 351 | raise RuntimeError('Number of tasks should be {} but {} given'.format(self.num_tasks, 352 | y_true.shape[1])) 353 | 354 | return 
y_true, y_pred 355 | else: 356 | raise ValueError('Undefined eval metric %s ' % self.eval_metric) 357 | 358 | def _eval_mae(self, y_true, y_pred): 359 | ''' 360 | compute MAE score averaged across tasks 361 | ''' 362 | mae_list = [] 363 | 364 | for i in range(y_true.shape[1]): 365 | # ignore nan values 366 | is_labeled = y_true[:, i] == y_true[:, i] 367 | mae_list.append(np.absolute(y_true[is_labeled] - y_pred[is_labeled]).mean()) 368 | 369 | return {'mae': sum(mae_list) / len(mae_list)} 370 | 371 | def eval(self, input_dict): 372 | y_true, y_pred = self._parse_and_check_input(input_dict) 373 | return self._eval_mae(y_true, y_pred) -------------------------------------------------------------------------------- /model/GCS.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Sequential, Linear, ReLU 5 | from torch_geometric.nn import global_add_pool 6 | 7 | from typing import Callable, Union 8 | from torch import Tensor 9 | from torch_geometric.nn.conv import MessagePassing 10 | from torch_geometric.typing import OptPairTensor, Adj, Size 11 | from torch_scatter import scatter, scatter_max, scatter_min, scatter_add 12 | 13 | from cam_util.aug import permute_edges 14 | from copy import deepcopy 15 | from torch_geometric.utils import subgraph 16 | 17 | def normalize(cam, batch, eps=1e-20): 18 | cam = cam.clone() 19 | # batch_num 20 | batch_max, _ = scatter_max(cam.squeeze(), batch) 21 | batch_min, _ = scatter_min(cam.squeeze(), batch) 22 | batch_max_expand = [] 23 | batch_min_expand = [] 24 | for i in batch: 25 | batch_max_expand.append(batch_max[i]) 26 | batch_min_expand.append(batch_min[i]) 27 | 28 | batch_max_expand = torch.tensor(batch_max_expand).unsqueeze(1).to(cam.device) 29 | batch_min_expand = torch.tensor(batch_min_expand).unsqueeze(1).to(cam.device) 30 | normalized_cam = (cam - batch_min_expand) / (batch_max_expand + eps) 31 | normalized_cam = normalized_cam.clamp_min(0) 32 | normalized_cam = normalized_cam.clamp_max(1) 33 | 34 | return normalized_cam 35 | 36 | def reset(nn): 37 | def _reset(item): 38 | if hasattr(item, 'reset_parameters'): 39 | item.reset_parameters() 40 | 41 | if nn is not None: 42 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 43 | for item in nn.children(): 44 | _reset(item) 45 | else: 46 | _reset(nn) 47 | 48 | 49 | class WGINConv(MessagePassing): 50 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, 51 | **kwargs): 52 | kwargs.setdefault('aggr', 'add') 53 | super(WGINConv, self).__init__(**kwargs) 54 | self.nn = nn 55 | self.initial_eps = eps 56 | if train_eps: 57 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 58 | else: 59 | self.register_buffer('eps', torch.Tensor([eps])) 60 | self.reset_parameters() 61 | 62 | def reset_parameters(self): 63 | reset(self.nn) 64 | self.eps.data.fill_(self.initial_eps) 65 | 66 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_weight = None, 67 | size: Size = None) -> Tensor: 68 | """""" 69 | if isinstance(x, Tensor): 70 | x: OptPairTensor = (x, x) 71 | 72 | # propagate_type: (x: OptPairTensor) 73 | out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size) 74 | 75 | x_r = x[1] 76 | if x_r is not None: 77 | out += (1 + self.eps) * x_r 78 | 79 | return self.nn(out) 80 | 81 | def message(self, x_j: Tensor, edge_weight) -> Tensor: 82 | return x_j if edge_weight is None else x_j * edge_weight.view(-1, 1) 83 | 84 | 85 
| def __repr__(self): 86 | return '{}(nn={})'.format(self.__class__.__name__, self.nn) 87 | 88 | 89 | class TUEncoder(torch.nn.Module): 90 | def __init__(self, num_dataset_features, emb_dim=300, num_gc_layers=5, drop_ratio=0.0, pooling_type="layerwise", is_infograph=False): 91 | super(TUEncoder, self).__init__() 92 | 93 | self.pooling_type = pooling_type 94 | self.emb_dim = emb_dim 95 | self.num_gc_layers = num_gc_layers 96 | self.drop_ratio = drop_ratio 97 | self.is_infograph = is_infograph 98 | 99 | self.out_node_dim = self.emb_dim 100 | if self.pooling_type == "standard": 101 | self.out_graph_dim = self.emb_dim 102 | elif self.pooling_type == "layerwise": 103 | self.out_graph_dim = self.emb_dim * self.num_gc_layers 104 | else: 105 | raise NotImplementedError 106 | 107 | self.convs = torch.nn.ModuleList() 108 | self.bns = torch.nn.ModuleList() 109 | 110 | for i in range(num_gc_layers): 111 | 112 | if i: 113 | nn = Sequential(Linear(emb_dim, emb_dim), ReLU(), Linear(emb_dim, emb_dim)) 114 | else: 115 | nn = Sequential(Linear(num_dataset_features, emb_dim), ReLU(), Linear(emb_dim, emb_dim)) 116 | conv = WGINConv(nn) 117 | bn = torch.nn.BatchNorm1d(emb_dim) 118 | 119 | self.convs.append(conv) 120 | self.bns.append(bn) 121 | 122 | 123 | def forward(self, batch, x, edge_index, edge_attr=None, edge_weight=None): 124 | xs = [] 125 | for i in range(self.num_gc_layers): 126 | x = self.convs[i](x, edge_index, edge_weight) 127 | x = self.bns[i](x) 128 | if i == self.num_gc_layers - 1: 129 | # remove relu for the last layer 130 | x = F.dropout(x, self.drop_ratio, training=self.training) 131 | else: 132 | x = F.dropout(F.relu(x), self.drop_ratio, training=self.training) 133 | xs.append(x) 134 | # compute graph embedding using pooling 135 | if self.pooling_type == "standard": 136 | xpool = global_add_pool(x, batch) 137 | return xpool, x 138 | 139 | elif self.pooling_type == "layerwise": 140 | xpool = [global_add_pool(x, batch) for x in xs] 141 | xpool = torch.cat(xpool, 1) 142 | if self.is_infograph: 143 | return xpool, torch.cat(xs, 1) 144 | else: 145 | return xpool, x 146 | else: 147 | raise NotImplementedError 148 | 149 | def get_embeddings(self, loader, device, is_rand_label=False): 150 | ret = [] 151 | y = [] 152 | with torch.no_grad(): 153 | for data in loader: 154 | if isinstance(data, list): 155 | data = data[0].to(device) 156 | data = data.to(device) 157 | batch, x, edge_index = data.batch, data.x, data.edge_index 158 | edge_weight = data.edge_weight if hasattr(data, 'edge_weight') else None 159 | 160 | if x is None: 161 | x = torch.ones((batch.shape[0], 1)).to(device) 162 | x, _ = self.forward(batch, x, edge_index, edge_weight) 163 | 164 | ret.append(x.cpu().numpy()) 165 | if is_rand_label: 166 | y.append(data.rand_label.cpu().numpy()) 167 | else: 168 | y.append(data.y.cpu().numpy()) 169 | ret = np.concatenate(ret, 0) 170 | y = np.concatenate(y, 0) 171 | return ret, y 172 | 173 | 174 | 175 | class Model(torch.nn.Module): 176 | def __init__(self, meta_info, opt, device): 177 | super(Model, self).__init__() 178 | self.name = 'GCSå' 179 | self.device = device 180 | self.warm_ratio = opt.warm_ratio 181 | self.inner_iter = opt.inner_iter 182 | self.thres = opt.thres 183 | 184 | self.encoder = TUEncoder(num_dataset_features=1, emb_dim=opt.emb_dim, num_gc_layers=opt.num_gc_layers, drop_ratio=opt.drop_ratio, pooling_type=opt.pooling_type) 185 | self.proj_head = Sequential(Linear(self.encoder.out_graph_dim, opt.emb_dim), ReLU(inplace=True), Linear(opt.emb_dim, opt.emb_dim)) 186 | 187 | 
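# The projection head is a SimCLR-style two-layer MLP: it maps the pooled graph
# embedding (emb_dim * num_gc_layers when pooling_type == "layerwise") down to
# emb_dim before the NT-Xent loss in calc_loss(); init_emb() below then
# re-initializes every Linear layer with Xavier-uniform weights and zero bias.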
self.init_emb() 188 | 189 | def init_emb(self): 190 | for m in self.modules(): 191 | if isinstance(m, Linear): 192 | torch.nn.init.xavier_uniform_(m.weight.data) 193 | if m.bias is not None: 194 | m.bias.data.fill_(0.0) 195 | 196 | def _contrastive_score(self, query, key, queues): 197 | eye = torch.eye(query.size(0)).type_as(query) 198 | pos = torch.einsum('nc,nc->n', [query, key]).unsqueeze(0) 199 | neg = torch.cat([torch.einsum('nc,kc->nk', [query, queue]) * (1 - eye) for queue in queues], dim=1) 200 | score = (pos.exp().sum(dim=1) / neg.exp().sum(dim=1)).log() 201 | return score 202 | 203 | 204 | def _compute_cam(self, feature, score, batch, clamp_negative_weights=True): 205 | 206 | 207 | # feature (batch_nodes, embedding_dim) 208 | grad = torch.autograd.grad(score.sum(), feature)[0] 209 | 210 | # (batch_nodes, 1) 211 | weight = torch.mean(grad, dim=-1, keepdim=True) 212 | 213 | if clamp_negative_weights: # positive weights only 214 | weight = weight.clamp_min(0) 215 | 216 | # (batch_nodes, 1) 217 | cam = torch.sum(weight * feature, dim=1, keepdim=True).detach() 218 | 219 | normalized_cam = normalize(cam, batch).squeeze().detach() 220 | return normalized_cam 221 | 222 | 223 | def get_features(self, batch, x, edge_index, edge_attr, edge_weight, keep_node): 224 | if keep_node is not None: 225 | edge_index, edge_attr, edge_mask = subgraph(keep_node, edge_index, edge_attr, return_edge_mask=True) 226 | edge_weight = torch.masked_select(edge_weight, edge_mask.to(edge_weight.device)) 227 | new_x = torch.zeros(x.shape).to(x.device) 228 | new_x[keep_node] = x[keep_node] 229 | x = new_x 230 | 231 | z, node_emb = self.encoder(batch, x, edge_index, edge_attr, edge_weight) 232 | return z, node_emb 233 | 234 | def get_projection(self, z): 235 | return self.proj_head(z) 236 | 237 | def get_contrastive_cam(self, batch, n_iters=1, return_intermediate=False): 238 | key, queues = None, [] 239 | _masks, _masked_images = [], [] 240 | 241 | mask_edge = torch.zeros(batch.edge_index.shape[1]) + 1e-10 242 | keep_indice = torch.arange(batch.x.shape[0]) 243 | mask_edge_list = [] 244 | keep_node_list = [] 245 | 246 | for it in range(n_iters): 247 | z, node_emb = self.get_features(batch.batch, batch.x, batch.edge_index, None, (1 - mask_edge).to(self.device), keep_indice) 248 | output = self.get_projection(z) 249 | 250 | if it == 0: 251 | key = output # original graph 252 | # queues.append(output.detach()) # masked images 253 | 254 | # score = self._contrastive_score(output, key, queues) 255 | score = self.calc_loss(key, output) 256 | 257 | # (batch_nodes, 1) 258 | node_cam = self._compute_cam(node_emb, score, batch.batch, clamp_negative_weights=True) 259 | mask_node = torch.max(mask_node, node_cam) if it > 0 else node_cam 260 | mask_node = mask_node.detach() 261 | indicater = torch.where(mask_node > self.thres, 1, 0) 262 | keep_node_list.append(indicater) 263 | 264 | src, dst = batch.edge_index[0], batch.edge_index[1] 265 | # batch_edge, 1 266 | edge_cam = (node_cam[src] + node_cam[dst]) /2 267 | mask_edge = torch.max(mask_edge, edge_cam) if it > 0 else edge_cam 268 | mask_edge = mask_edge.detach() 269 | edge_indicater = torch.where(mask_edge > self.thres, 1, 0) 270 | mask_edge_list.append(edge_indicater) 271 | 272 | return keep_node_list, mask_edge_list 273 | 274 | 275 | 276 | 277 | 278 | def contrast_train(self, batch, keep_node=None, mask_edge=None): 279 | if keep_node is None and mask_edge is None: 280 | aug = permute_edges(deepcopy(batch).cpu()).to(self.device) 281 | z_aug, _ = self.get_features(aug.batch, 
aug.x, aug.edge_index, None, None, None) 282 | x_aug = self.get_projection(z_aug) 283 | else: 284 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, mask_edge, keep_node) 285 | x_aug = self.get_projection(z_aug) 286 | 287 | z, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, None, None) 288 | x = self.get_projection(z) 289 | 290 | contrast_loss = self.calc_loss(x, x_aug) 291 | 292 | return contrast_loss 293 | 294 | 295 | 296 | def positive(self, batch, indicater, mask_edge): 297 | # view1 298 | env_indicator = indicater.new_ones(indicater.shape) * 0.5 299 | env_indicator = torch.bernoulli(env_indicator) 300 | keep_node = torch.nonzero(indicater + env_indicator, as_tuple=False).view(-1,) 301 | 302 | edge_env_indicator = mask_edge.new_ones(mask_edge.shape) * 0.5 303 | edge_env_indicator = torch.bernoulli(edge_env_indicator) 304 | new_mask_edge = mask_edge + edge_env_indicator 305 | new_mask_edge = new_mask_edge.clamp_max(1) 306 | 307 | new_mask_edge = new_mask_edge.to(self.device) 308 | keep_node = keep_node.to(self.device) 309 | 310 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, new_mask_edge, keep_node) 311 | x_aug_1 = self.get_projection(z_aug) 312 | 313 | 314 | # view2 315 | env_indicator = indicater.new_ones(indicater.shape) * 0.5 316 | env_indicator = torch.bernoulli(env_indicator) 317 | keep_node = torch.nonzero(indicater + env_indicator, as_tuple=False).view(-1,) 318 | 319 | edge_env_indicator = mask_edge.new_ones(mask_edge.shape) * 0.5 320 | edge_env_indicator = torch.bernoulli(edge_env_indicator) 321 | new_mask_edge = mask_edge + edge_env_indicator 322 | new_mask_edge = new_mask_edge.clamp_max(1) 323 | 324 | new_mask_edge = new_mask_edge.to(self.device) 325 | keep_node = keep_node.to(self.device) 326 | 327 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, new_mask_edge, keep_node) 328 | x_aug_2 = self.get_projection(z_aug) 329 | 330 | contrast_loss = self.calc_loss(x_aug_1, x_aug_2) 331 | 332 | return contrast_loss 333 | 334 | 335 | def negative(self, batch, indicater, mask_edge): 336 | # view1 337 | indicater = 1 - indicater 338 | env_indicator = indicater.new_ones(indicater.shape) * 0.1 339 | env_indicator = torch.bernoulli(env_indicator) 340 | keep_node = torch.nonzero(indicater + env_indicator, as_tuple=False).view(-1,) 341 | 342 | mask_edge = 1 - mask_edge 343 | edge_env_indicator = mask_edge.new_ones(mask_edge.shape) * 0.1 344 | edge_env_indicator = torch.bernoulli(edge_env_indicator) 345 | new_mask_edge = mask_edge + edge_env_indicator 346 | new_mask_edge = new_mask_edge.clamp_max(1) 347 | 348 | new_mask_edge = new_mask_edge.to(self.device) 349 | keep_node = keep_node.to(self.device) 350 | 351 | z_aug, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, new_mask_edge, keep_node) 352 | x_aug = self.get_projection(z_aug) 353 | 354 | z, _ = self.get_features(batch.batch, batch.x, batch.edge_index, None, None, None) 355 | x = self.get_projection(z) 356 | 357 | contrast_loss = self.calc_loss(x, x_aug) 358 | 359 | return contrast_loss 360 | 361 | @staticmethod 362 | def calc_loss(x, x_aug, temperature=0.2, sym=True): 363 | # x and x_aug shape -> Batch x proj_hidden_dim 364 | 365 | batch_size, _ = x.size() 366 | x_abs = x.norm(dim=1) 367 | x_aug_abs = x_aug.norm(dim=1) 368 | 369 | sim_matrix = torch.einsum('ik,jk->ij', x, x_aug) / torch.einsum('i,j->ij', x_abs, x_aug_abs) 370 | sim_matrix = torch.exp(sim_matrix / temperature) 371 | pos_sim = 
sim_matrix[range(batch_size), range(batch_size)] 372 | if sym: 373 | 374 | loss_0 = pos_sim / (sim_matrix.sum(dim=0) - pos_sim) 375 | loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 376 | 377 | loss_0 = - torch.log(loss_0).mean() 378 | loss_1 = - torch.log(loss_1).mean() 379 | loss = (loss_0 + loss_1)/2.0 380 | else: 381 | loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 382 | loss_1 = - torch.log(loss_1).mean() 383 | return loss_1 384 | 385 | return loss 386 | 387 | def reg_mask(self, mask, batch, size): 388 | 389 | key_num = scatter_add(mask, batch, dim=0, dim_size=size) 390 | env_num = scatter_add((1 - mask), batch, dim=0, dim_size=size) 391 | non_zero_mask = scatter_add((mask > 0).to(torch.float32), batch, dim=0, dim_size=size) 392 | all_mask = scatter_add(torch.ones_like(mask).to(torch.float32), batch, dim=0, dim_size=size) 393 | non_zero_ratio = non_zero_mask / (all_mask + 1e-8) 394 | return key_num + 1e-8, env_num + 1e-8, non_zero_ratio 395 | 396 | def forward(self, batch, progress): 397 | if progress < self.warm_ratio: 398 | return self.contrast_train(batch) 399 | else: 400 | keep_node_list, mask_edge_list = self.get_contrastive_cam(batch, self.inner_iter) 401 | indicater = keep_node_list[-1] 402 | mask_edge = mask_edge_list[-1] 403 | 404 | pos_score = self.positive(batch, indicater, mask_edge) 405 | neg_score = self.negative(batch, indicater, mask_edge) 406 | 407 | return pos_score - 0.1 * neg_score 408 | -------------------------------------------------------------------------------- /datasets/transfer_mol_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from itertools import repeat, chain 4 | 5 | import networkx as nx 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from rdkit import Chem 10 | from rdkit.Chem import AllChem 11 | from rdkit.Chem import Descriptors 12 | from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect 13 | from torch.utils import data 14 | from torch_geometric.data import Data 15 | from torch_geometric.data import InMemoryDataset 16 | 17 | # allowable node and edge features 18 | allowable_features = { 19 | 'possible_atomic_num_list': list(range(1, 119)), 20 | 'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], 21 | 'possible_chirality_list': [ 22 | Chem.rdchem.ChiralType.CHI_UNSPECIFIED, 23 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, 24 | Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, 25 | Chem.rdchem.ChiralType.CHI_OTHER 26 | ], 27 | 'possible_hybridization_list': [ 28 | Chem.rdchem.HybridizationType.S, 29 | Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, 30 | Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, 31 | Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED 32 | ], 33 | 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8], 34 | 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], 35 | 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 36 | 'possible_bonds': [ 37 | Chem.rdchem.BondType.SINGLE, 38 | Chem.rdchem.BondType.DOUBLE, 39 | Chem.rdchem.BondType.TRIPLE, 40 | Chem.rdchem.BondType.AROMATIC 41 | ], 42 | 'possible_bond_dirs': [ # only for double bond stereo information 43 | Chem.rdchem.BondDir.NONE, 44 | Chem.rdchem.BondDir.ENDUPRIGHT, 45 | Chem.rdchem.BondDir.ENDDOWNRIGHT 46 | ] 47 | } 48 | 49 | 50 | def mol_to_graph_data_obj_simple(mol): 51 | """ 52 | Converts rdkit mol object to graph Data object required by the pytorch 53 | 
geometric package. NB: Uses simplified atom and bond features, and represent 54 | as indices 55 | :param mol: rdkit mol object 56 | :return: graph data object with the attributes: x, edge_index, edge_attr 57 | """ 58 | # atoms 59 | num_atom_features = 2 # atom type, chirality tag 60 | atom_features_list = [] 61 | for atom in mol.GetAtoms(): 62 | atom_feature = [allowable_features['possible_atomic_num_list'].index( 63 | atom.GetAtomicNum())] + [allowable_features[ 64 | 'possible_chirality_list'].index(atom.GetChiralTag())] 65 | atom_features_list.append(atom_feature) 66 | x = torch.tensor(np.array(atom_features_list), dtype=torch.long) 67 | 68 | # bonds 69 | num_bond_features = 2 # bond type, bond direction 70 | if len(mol.GetBonds()) > 0: # mol has bonds 71 | edges_list = [] 72 | edge_features_list = [] 73 | for bond in mol.GetBonds(): 74 | i = bond.GetBeginAtomIdx() 75 | j = bond.GetEndAtomIdx() 76 | edge_feature = [allowable_features['possible_bonds'].index( 77 | bond.GetBondType())] + [allowable_features[ 78 | 'possible_bond_dirs'].index( 79 | bond.GetBondDir())] 80 | edges_list.append((i, j)) 81 | edge_features_list.append(edge_feature) 82 | edges_list.append((j, i)) 83 | edge_features_list.append(edge_feature) 84 | 85 | # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] 86 | edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) 87 | 88 | # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] 89 | edge_attr = torch.tensor(np.array(edge_features_list), 90 | dtype=torch.long) 91 | else: # mol has no bonds 92 | edge_index = torch.empty((2, 0), dtype=torch.long) 93 | edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) 94 | 95 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) 96 | 97 | return data 98 | 99 | 100 | def graph_data_obj_to_mol_simple(data_x, data_edge_index, data_edge_attr): 101 | """ 102 | Convert pytorch geometric data obj to rdkit mol object. NB: Uses simplified 103 | atom and bond features, and represent as indices. 
104 | :param: data_x: 105 | :param: data_edge_index: 106 | :param: data_edge_attr 107 | :return: 108 | """ 109 | mol = Chem.RWMol() 110 | 111 | # atoms 112 | atom_features = data_x.cpu().numpy() 113 | num_atoms = atom_features.shape[0] 114 | for i in range(num_atoms): 115 | atomic_num_idx, chirality_tag_idx = atom_features[i] 116 | atomic_num = allowable_features['possible_atomic_num_list'][atomic_num_idx] 117 | chirality_tag = allowable_features['possible_chirality_list'][chirality_tag_idx] 118 | atom = Chem.Atom(atomic_num) 119 | atom.SetChiralTag(chirality_tag) 120 | mol.AddAtom(atom) 121 | 122 | # bonds 123 | edge_index = data_edge_index.cpu().numpy() 124 | edge_attr = data_edge_attr.cpu().numpy() 125 | num_bonds = edge_index.shape[1] 126 | for j in range(0, num_bonds, 2): 127 | begin_idx = int(edge_index[0, j]) 128 | end_idx = int(edge_index[1, j]) 129 | bond_type_idx, bond_dir_idx = edge_attr[j] 130 | bond_type = allowable_features['possible_bonds'][bond_type_idx] 131 | bond_dir = allowable_features['possible_bond_dirs'][bond_dir_idx] 132 | mol.AddBond(begin_idx, end_idx, bond_type) 133 | # set bond direction 134 | new_bond = mol.GetBondBetweenAtoms(begin_idx, end_idx) 135 | new_bond.SetBondDir(bond_dir) 136 | 137 | # Chem.SanitizeMol(mol) # fails for COC1=CC2=C(NC(=N2)[S@@](=O)CC2=NC=C( 138 | # C)C(OC)=C2C)C=C1, when aromatic bond is possible 139 | # when we do not have aromatic bonds 140 | # Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 141 | 142 | return mol 143 | 144 | 145 | def graph_data_obj_to_nx_simple(data): 146 | """ 147 | Converts graph Data object required by the pytorch geometric package to 148 | network x data object. NB: Uses simplified atom and bond features, 149 | and represent as indices. NB: possible issues with recapitulating relative 150 | stereochemistry since the edges in the nx object are unordered. 151 | :param data: pytorch geometric Data object 152 | :return: network x object 153 | """ 154 | G = nx.Graph() 155 | 156 | # atoms 157 | atom_features = data.x.cpu().numpy() 158 | num_atoms = atom_features.shape[0] 159 | for i in range(num_atoms): 160 | atomic_num_idx, chirality_tag_idx = atom_features[i] 161 | G.add_node(i, atom_num_idx=atomic_num_idx, chirality_tag_idx=chirality_tag_idx) 162 | pass 163 | 164 | # bonds 165 | edge_index = data.edge_index.cpu().numpy() 166 | edge_attr = data.edge_attr.cpu().numpy() 167 | num_bonds = edge_index.shape[1] 168 | for j in range(0, num_bonds, 2): 169 | begin_idx = int(edge_index[0, j]) 170 | end_idx = int(edge_index[1, j]) 171 | bond_type_idx, bond_dir_idx = edge_attr[j] 172 | if not G.has_edge(begin_idx, end_idx): 173 | G.add_edge(begin_idx, end_idx, bond_type_idx=bond_type_idx, 174 | bond_dir_idx=bond_dir_idx) 175 | 176 | return G 177 | 178 | 179 | def nx_to_graph_data_obj_simple(G): 180 | """ 181 | Converts nx graph to pytorch geometric Data object. Assume node indices 182 | are numbered from 0 to num_nodes - 1. NB: Uses simplified atom and bond 183 | features, and represent as indices. NB: possible issues with 184 | recapitulating relative stereochemistry since the edges in the nx 185 | object are unordered. 
186 | :param G: nx graph obj 187 | :return: pytorch geometric Data object 188 | """ 189 | # atoms 190 | num_atom_features = 2 # atom type, chirality tag 191 | atom_features_list = [] 192 | for _, node in G.nodes(data=True): 193 | atom_feature = [node['atom_num_idx'], node['chirality_tag_idx']] 194 | atom_features_list.append(atom_feature) 195 | x = torch.tensor(np.array(atom_features_list), dtype=torch.long) 196 | 197 | # bonds 198 | num_bond_features = 2 # bond type, bond direction 199 | if len(G.edges()) > 0: # mol has bonds 200 | edges_list = [] 201 | edge_features_list = [] 202 | for i, j, edge in G.edges(data=True): 203 | edge_feature = [edge['bond_type_idx'], edge['bond_dir_idx']] 204 | edges_list.append((i, j)) 205 | edge_features_list.append(edge_feature) 206 | edges_list.append((j, i)) 207 | edge_features_list.append(edge_feature) 208 | 209 | # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] 210 | edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) 211 | 212 | # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] 213 | edge_attr = torch.tensor(np.array(edge_features_list), 214 | dtype=torch.long) 215 | else: # mol has no bonds 216 | edge_index = torch.empty((2, 0), dtype=torch.long) 217 | edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) 218 | 219 | data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) 220 | 221 | return data 222 | 223 | 224 | def get_gasteiger_partial_charges(mol, n_iter=12): 225 | """ 226 | Calculates list of gasteiger partial charges for each atom in mol object. 227 | :param mol: rdkit mol object 228 | :param n_iter: number of iterations. Default 12 229 | :return: list of computed partial charges for each atom. 230 | """ 231 | Chem.rdPartialCharges.ComputeGasteigerCharges(mol, nIter=n_iter, 232 | throwOnParamFailure=True) 233 | partial_charges = [float(a.GetProp('_GasteigerCharge')) for a in 234 | mol.GetAtoms()] 235 | return partial_charges 236 | 237 | 238 | def create_standardized_mol_id(smiles): 239 | """ 240 | 241 | :param smiles: 242 | :return: inchi 243 | """ 244 | if check_smiles_validity(smiles): 245 | # remove stereochemistry 246 | smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), 247 | isomericSmiles=False) 248 | mol = AllChem.MolFromSmiles(smiles) 249 | if mol != None: # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 250 | if '.' in smiles: # if multiple species, pick largest molecule 251 | mol_species_list = split_rdkit_mol_obj(mol) 252 | largest_mol = get_largest_mol(mol_species_list) 253 | inchi = AllChem.MolToInchi(largest_mol) 254 | else: 255 | inchi = AllChem.MolToInchi(mol) 256 | return inchi 257 | else: 258 | return 259 | else: 260 | return 261 | 262 | 263 | class MoleculeDataset_aug(InMemoryDataset): 264 | def __init__(self, 265 | root, 266 | # data = None, 267 | # slices = None, 268 | transform=None, 269 | pre_transform=None, 270 | pre_filter=None, 271 | dataset='zinc250k', 272 | empty=False, 273 | aug="none", aug_ratio=None): 274 | """ 275 | Adapted from qm9.py. Disabled the download functionality 276 | :param root: directory of the dataset, containing a raw and processed 277 | dir. The raw dir should contain the file containing the smiles, and the 278 | processed dir can either empty or a previously processed file 279 | :param dataset: name of the dataset. 
Currently only implemented for 280 | zinc250k, chembl_with_labels, tox21, hiv, bace, bbbp, clintox, esol, 281 | freesolv, lipophilicity, muv, pcba, sider, toxcast 282 | :param empty: if True, then will not load any data obj. For 283 | initializing empty dataset 284 | """ 285 | self.dataset = dataset 286 | self.root = root 287 | self.aug = aug 288 | self.aug_ratio = aug_ratio 289 | 290 | super(MoleculeDataset_aug, self).__init__(root, transform, pre_transform, 291 | pre_filter) 292 | self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter 293 | 294 | if not empty: 295 | self.data, self.slices = torch.load(self.processed_paths[0]) 296 | 297 | def get(self, idx): 298 | data = Data() 299 | for key in self.data.keys: 300 | item, slices = self.data[key], self.slices[key] 301 | s = list(repeat(slice(None), item.dim())) 302 | s[data.__cat_dim__(key, item)] = slice(slices[idx], 303 | slices[idx + 1]) 304 | data[key] = item[s] 305 | 306 | if self.aug == 'dropN': 307 | data = drop_nodes(data, self.aug_ratio) 308 | elif self.aug == 'permE': 309 | data = permute_edges(data, self.aug_ratio) 310 | elif self.aug == 'maskN': 311 | data = mask_nodes(data, self.aug_ratio) 312 | elif self.aug == 'subgraph': 313 | data = subgraph(data, self.aug_ratio) 314 | elif self.aug == 'random': 315 | n = np.random.randint(2) 316 | if n == 0: 317 | data = drop_nodes(data, self.aug_ratio) 318 | elif n == 1: 319 | data = subgraph(data, self.aug_ratio) 320 | # data = subgraph(data, 0.5) 321 | else: 322 | print('augmentation error') 323 | assert False 324 | elif self.aug == 'none': 325 | pass 326 | else: 327 | print('augmentation error') 328 | assert False 329 | 330 | return data 331 | 332 | @property 333 | def raw_file_names(self): 334 | file_name_list = os.listdir(self.raw_dir) 335 | # assert len(file_name_list) == 1 # currently assume we have a 336 | # # single raw file 337 | return file_name_list 338 | 339 | @property 340 | def processed_file_names(self): 341 | return 'geometric_data_processed.pt' 342 | 343 | def download(self): 344 | raise NotImplementedError('Must indicate valid location of raw data. ' 345 | 'No download allowed') 346 | 347 | def process(self): 348 | data_smiles_list = [] 349 | data_list = [] 350 | 351 | if self.dataset == 'zinc_standard_agent': 352 | input_path = self.raw_paths[0] 353 | input_df = pd.read_csv(input_path, sep=',', compression='gzip', 354 | dtype='str') 355 | smiles_list = list(input_df['smiles']) 356 | zinc_id_list = list(input_df['zinc_id']) 357 | for i in range(len(smiles_list)): 358 | print(i) 359 | s = smiles_list[i] 360 | # each example contains a single species 361 | try: 362 | rdkit_mol = AllChem.MolFromSmiles(s) 363 | if rdkit_mol != None: # ignore invalid mol objects 364 | # # convert aromatic bonds to double bonds 365 | # Chem.SanitizeMol(rdkit_mol, 366 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 367 | data = mol_to_graph_data_obj_simple(rdkit_mol) 368 | # manually add mol id 369 | id = int(zinc_id_list[i].split('ZINC')[1].lstrip('0')) 370 | data.id = torch.tensor( 371 | [id]) # id here is zinc id value, stripped of 372 | # leading zeros 373 | data_list.append(data) 374 | data_smiles_list.append(smiles_list[i]) 375 | except: 376 | continue # skip molecules that fail to parse or featurize 377 | 378 | elif self.dataset == 'chembl_filtered': 379 | ### get downstream test molecules.
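### the loop below scaffold-splits each downstream dataset (80/10/10), converts its
### validation and test SMILES to InChI via create_standardized_mol_id, and collects
### them in downstream_inchi_set; ChEMBL molecules whose InChI appears in that set are
### skipped further down, so pretraining does not leak downstream valid/test molecules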
380 | from splitters import scaffold_split 381 | 382 | ### 383 | downstream_dir = [ 384 | 'storage/datasets/transfer_dataset/bace', 385 | 'storage/datasets/transfer_dataset/bbbp', 386 | 'storage/datasets/transfer_dataset/clintox', 387 | 'storage/datasets/transfer_dataset/esol', 388 | 'storage/datasets/transfer_dataset/freesolv', 389 | 'storage/datasets/transfer_dataset/hiv', 390 | 'storage/datasets/transfer_dataset/lipophilicity', 391 | 'storage/datasets/transfer_dataset/muv', 392 | # 'storage/datasets/transfer_dataset/pcba/processed/smiles.csv', 393 | 'storage/datasets/transfer_dataset/sider', 394 | 'storage/datasets/transfer_dataset/tox21', 395 | 'storage/datasets/transfer_dataset/toxcast' 396 | ] 397 | 398 | downstream_inchi_set = set() 399 | for d_path in downstream_dir: 400 | print(d_path) 401 | dataset_name = d_path.split('/')[-1] 402 | downstream_dataset = MoleculeDataset(d_path, dataset=dataset_name) 403 | downstream_smiles = pd.read_csv(os.path.join(d_path, 404 | 'processed', 'smiles.csv'), 405 | header=None)[0].tolist() 406 | 407 | assert len(downstream_dataset) == len(downstream_smiles) 408 | 409 | _, _, _, (train_smiles, valid_smiles, test_smiles) = scaffold_split(downstream_dataset, 410 | downstream_smiles, task_idx=None, 411 | null_value=0, 412 | frac_train=0.8, frac_valid=0.1, 413 | frac_test=0.1, 414 | return_smiles=True) 415 | 416 | ### remove both test and validation molecules 417 | remove_smiles = test_smiles + valid_smiles 418 | 419 | downstream_inchis = [] 420 | for smiles in remove_smiles: 421 | species_list = smiles.split('.') 422 | for s in species_list: # record inchi for all species, not just 423 | # largest (by default in create_standardized_mol_id if input has 424 | # multiple species) 425 | inchi = create_standardized_mol_id(s) 426 | downstream_inchis.append(inchi) 427 | downstream_inchi_set.update(downstream_inchis) 428 | 429 | smiles_list, rdkit_mol_objs, folds, labels = \ 430 | _load_chembl_with_labels_dataset(os.path.join(self.root, 'raw')) 431 | 432 | print('processing') 433 | for i in range(len(rdkit_mol_objs)): 434 | print(i) 435 | rdkit_mol = rdkit_mol_objs[i] 436 | if rdkit_mol != None: 437 | # # convert aromatic bonds to double bonds 438 | # Chem.SanitizeMol(rdkit_mol, 439 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 440 | mw = Descriptors.MolWt(rdkit_mol) 441 | if 50 <= mw <= 900: 442 | inchi = create_standardized_mol_id(smiles_list[i]) 443 | if inchi != None and inchi not in downstream_inchi_set: 444 | data = mol_to_graph_data_obj_simple(rdkit_mol) 445 | # manually add mol id 446 | data.id = torch.tensor( 447 | [i]) # id here is the index of the mol in 448 | # the dataset 449 | data.y = torch.tensor(labels[i, :]) 450 | # fold information 451 | if i in folds[0]: 452 | data.fold = torch.tensor([0]) 453 | elif i in folds[1]: 454 | data.fold = torch.tensor([1]) 455 | else: 456 | data.fold = torch.tensor([2]) 457 | data_list.append(data) 458 | data_smiles_list.append(smiles_list[i]) 459 | 460 | elif self.dataset == 'tox21': 461 | smiles_list, rdkit_mol_objs, labels = \ 462 | _load_tox21_dataset(self.raw_paths[0]) 463 | for i in range(len(smiles_list)): 464 | print(i) 465 | rdkit_mol = rdkit_mol_objs[i] 466 | ## convert aromatic bonds to double bonds 467 | # Chem.SanitizeMol(rdkit_mol, 468 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 469 | data = mol_to_graph_data_obj_simple(rdkit_mol) 470 | # manually add mol id 471 | data.id = torch.tensor( 472 | [i]) # id here is the index of the mol in 473 | # the dataset 474 | data.y = 
torch.tensor(labels[i, :]) 475 | data_list.append(data) 476 | data_smiles_list.append(smiles_list[i]) 477 | 478 | elif self.dataset == 'hiv': 479 | smiles_list, rdkit_mol_objs, labels = \ 480 | _load_hiv_dataset(self.raw_paths[0]) 481 | for i in range(len(smiles_list)): 482 | print(i) 483 | rdkit_mol = rdkit_mol_objs[i] 484 | # # convert aromatic bonds to double bonds 485 | # Chem.SanitizeMol(rdkit_mol, 486 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 487 | data = mol_to_graph_data_obj_simple(rdkit_mol) 488 | # manually add mol id 489 | data.id = torch.tensor( 490 | [i]) # id here is the index of the mol in 491 | # the dataset 492 | data.y = torch.tensor([labels[i]]) 493 | data_list.append(data) 494 | data_smiles_list.append(smiles_list[i]) 495 | 496 | elif self.dataset == 'bace': 497 | smiles_list, rdkit_mol_objs, folds, labels = \ 498 | _load_bace_dataset(self.raw_paths[0]) 499 | for i in range(len(smiles_list)): 500 | print(i) 501 | rdkit_mol = rdkit_mol_objs[i] 502 | # # convert aromatic bonds to double bonds 503 | # Chem.SanitizeMol(rdkit_mol, 504 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 505 | data = mol_to_graph_data_obj_simple(rdkit_mol) 506 | # manually add mol id 507 | data.id = torch.tensor( 508 | [i]) # id here is the index of the mol in 509 | # the dataset 510 | data.y = torch.tensor([labels[i]]) 511 | data.fold = torch.tensor([folds[i]]) 512 | data_list.append(data) 513 | data_smiles_list.append(smiles_list[i]) 514 | 515 | elif self.dataset == 'bbbp': 516 | smiles_list, rdkit_mol_objs, labels = \ 517 | _load_bbbp_dataset(self.raw_paths[0]) 518 | for i in range(len(smiles_list)): 519 | print(i) 520 | rdkit_mol = rdkit_mol_objs[i] 521 | if rdkit_mol != None: 522 | # # convert aromatic bonds to double bonds 523 | # Chem.SanitizeMol(rdkit_mol, 524 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 525 | data = mol_to_graph_data_obj_simple(rdkit_mol) 526 | # manually add mol id 527 | data.id = torch.tensor( 528 | [i]) # id here is the index of the mol in 529 | # the dataset 530 | data.y = torch.tensor([labels[i]]) 531 | data_list.append(data) 532 | data_smiles_list.append(smiles_list[i]) 533 | 534 | elif self.dataset == 'clintox': 535 | smiles_list, rdkit_mol_objs, labels = \ 536 | _load_clintox_dataset(self.raw_paths[0]) 537 | for i in range(len(smiles_list)): 538 | print(i) 539 | rdkit_mol = rdkit_mol_objs[i] 540 | if rdkit_mol != None: 541 | # # convert aromatic bonds to double bonds 542 | # Chem.SanitizeMol(rdkit_mol, 543 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 544 | data = mol_to_graph_data_obj_simple(rdkit_mol) 545 | # manually add mol id 546 | data.id = torch.tensor( 547 | [i]) # id here is the index of the mol in 548 | # the dataset 549 | data.y = torch.tensor(labels[i, :]) 550 | data_list.append(data) 551 | data_smiles_list.append(smiles_list[i]) 552 | 553 | elif self.dataset == 'esol': 554 | smiles_list, rdkit_mol_objs, labels = \ 555 | _load_esol_dataset(self.raw_paths[0]) 556 | for i in range(len(smiles_list)): 557 | print(i) 558 | rdkit_mol = rdkit_mol_objs[i] 559 | # # convert aromatic bonds to double bonds 560 | # Chem.SanitizeMol(rdkit_mol, 561 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 562 | data = mol_to_graph_data_obj_simple(rdkit_mol) 563 | # manually add mol id 564 | data.id = torch.tensor( 565 | [i]) # id here is the index of the mol in 566 | # the dataset 567 | data.y = torch.tensor([labels[i]]) 568 | data_list.append(data) 569 | data_smiles_list.append(smiles_list[i]) 570 | 571 | elif self.dataset == 
'freesolv': 572 | smiles_list, rdkit_mol_objs, labels = \ 573 | _load_freesolv_dataset(self.raw_paths[0]) 574 | for i in range(len(smiles_list)): 575 | print(i) 576 | rdkit_mol = rdkit_mol_objs[i] 577 | # # convert aromatic bonds to double bonds 578 | # Chem.SanitizeMol(rdkit_mol, 579 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 580 | data = mol_to_graph_data_obj_simple(rdkit_mol) 581 | # manually add mol id 582 | data.id = torch.tensor( 583 | [i]) # id here is the index of the mol in 584 | # the dataset 585 | data.y = torch.tensor([labels[i]]) 586 | data_list.append(data) 587 | data_smiles_list.append(smiles_list[i]) 588 | 589 | elif self.dataset == 'lipophilicity': 590 | smiles_list, rdkit_mol_objs, labels = \ 591 | _load_lipophilicity_dataset(self.raw_paths[0]) 592 | for i in range(len(smiles_list)): 593 | print(i) 594 | rdkit_mol = rdkit_mol_objs[i] 595 | # # convert aromatic bonds to double bonds 596 | # Chem.SanitizeMol(rdkit_mol, 597 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 598 | data = mol_to_graph_data_obj_simple(rdkit_mol) 599 | # manually add mol id 600 | data.id = torch.tensor( 601 | [i]) # id here is the index of the mol in 602 | # the dataset 603 | data.y = torch.tensor([labels[i]]) 604 | data_list.append(data) 605 | data_smiles_list.append(smiles_list[i]) 606 | 607 | elif self.dataset == 'muv': 608 | smiles_list, rdkit_mol_objs, labels = \ 609 | _load_muv_dataset(self.raw_paths[0]) 610 | for i in range(len(smiles_list)): 611 | print(i) 612 | rdkit_mol = rdkit_mol_objs[i] 613 | # # convert aromatic bonds to double bonds 614 | # Chem.SanitizeMol(rdkit_mol, 615 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 616 | data = mol_to_graph_data_obj_simple(rdkit_mol) 617 | # manually add mol id 618 | data.id = torch.tensor( 619 | [i]) # id here is the index of the mol in 620 | # the dataset 621 | data.y = torch.tensor(labels[i, :]) 622 | data_list.append(data) 623 | data_smiles_list.append(smiles_list[i]) 624 | 625 | elif self.dataset == 'pcba': 626 | smiles_list, rdkit_mol_objs, labels = \ 627 | _load_pcba_dataset(self.raw_paths[0]) 628 | for i in range(len(smiles_list)): 629 | print(i) 630 | rdkit_mol = rdkit_mol_objs[i] 631 | # # convert aromatic bonds to double bonds 632 | # Chem.SanitizeMol(rdkit_mol, 633 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 634 | data = mol_to_graph_data_obj_simple(rdkit_mol) 635 | # manually add mol id 636 | data.id = torch.tensor( 637 | [i]) # id here is the index of the mol in 638 | # the dataset 639 | data.y = torch.tensor(labels[i, :]) 640 | data_list.append(data) 641 | data_smiles_list.append(smiles_list[i]) 642 | 643 | elif self.dataset == 'pcba_pretrain': 644 | smiles_list, rdkit_mol_objs, labels = \ 645 | _load_pcba_dataset(self.raw_paths[0]) 646 | downstream_inchi = set(pd.read_csv(os.path.join(self.root, 647 | 'downstream_mol_inchi_may_24_2019'), 648 | sep=',', header=None)[0]) 649 | for i in range(len(smiles_list)): 650 | print(i) 651 | if '.' 
not in smiles_list[i]: # remove examples with 652 | # multiples species 653 | rdkit_mol = rdkit_mol_objs[i] 654 | mw = Descriptors.MolWt(rdkit_mol) 655 | if 50 <= mw <= 900: 656 | inchi = create_standardized_mol_id(smiles_list[i]) 657 | if inchi != None and inchi not in downstream_inchi: 658 | # # convert aromatic bonds to double bonds 659 | # Chem.SanitizeMol(rdkit_mol, 660 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 661 | data = mol_to_graph_data_obj_simple(rdkit_mol) 662 | # manually add mol id 663 | data.id = torch.tensor( 664 | [i]) # id here is the index of the mol in 665 | # the dataset 666 | data.y = torch.tensor(labels[i, :]) 667 | data_list.append(data) 668 | data_smiles_list.append(smiles_list[i]) 669 | 670 | # elif self.dataset == '' 671 | 672 | elif self.dataset == 'sider': 673 | smiles_list, rdkit_mol_objs, labels = \ 674 | _load_sider_dataset(self.raw_paths[0]) 675 | for i in range(len(smiles_list)): 676 | print(i) 677 | rdkit_mol = rdkit_mol_objs[i] 678 | # # convert aromatic bonds to double bonds 679 | # Chem.SanitizeMol(rdkit_mol, 680 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 681 | data = mol_to_graph_data_obj_simple(rdkit_mol) 682 | # manually add mol id 683 | data.id = torch.tensor( 684 | [i]) # id here is the index of the mol in 685 | # the dataset 686 | data.y = torch.tensor(labels[i, :]) 687 | data_list.append(data) 688 | data_smiles_list.append(smiles_list[i]) 689 | 690 | elif self.dataset == 'toxcast': 691 | smiles_list, rdkit_mol_objs, labels = \ 692 | _load_toxcast_dataset(self.raw_paths[0]) 693 | for i in range(len(smiles_list)): 694 | print(i) 695 | rdkit_mol = rdkit_mol_objs[i] 696 | if rdkit_mol != None: 697 | # # convert aromatic bonds to double bonds 698 | # Chem.SanitizeMol(rdkit_mol, 699 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 700 | data = mol_to_graph_data_obj_simple(rdkit_mol) 701 | # manually add mol id 702 | data.id = torch.tensor( 703 | [i]) # id here is the index of the mol in 704 | # the dataset 705 | data.y = torch.tensor(labels[i, :]) 706 | data_list.append(data) 707 | data_smiles_list.append(smiles_list[i]) 708 | 709 | elif self.dataset == 'ptc_mr': 710 | input_path = self.raw_paths[0] 711 | input_df = pd.read_csv(input_path, sep=',', header=None, names=['id', 'label', 'smiles']) 712 | smiles_list = input_df['smiles'] 713 | labels = input_df['label'].values 714 | for i in range(len(smiles_list)): 715 | print(i) 716 | s = smiles_list[i] 717 | rdkit_mol = AllChem.MolFromSmiles(s) 718 | if rdkit_mol != None: # ignore invalid mol objects 719 | # # convert aromatic bonds to double bonds 720 | # Chem.SanitizeMol(rdkit_mol, 721 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 722 | data = mol_to_graph_data_obj_simple(rdkit_mol) 723 | # manually add mol id 724 | data.id = torch.tensor( 725 | [i]) 726 | data.y = torch.tensor([labels[i]]) 727 | data_list.append(data) 728 | data_smiles_list.append(smiles_list[i]) 729 | 730 | elif self.dataset == 'mutag': 731 | smiles_path = os.path.join(self.root, 'raw', 'mutag_188_data.can') 732 | # smiles_path = 'storage/datasets/transfer_dataset/mutag/raw/mutag_188_data.can' 733 | labels_path = os.path.join(self.root, 'raw', 'mutag_188_target.txt') 734 | # labels_path = 'storage/datasets/transfer_dataset/mutag/raw/mutag_188_target.txt' 735 | smiles_list = pd.read_csv(smiles_path, sep=' ', header=None)[0] 736 | labels = pd.read_csv(labels_path, header=None)[0].values 737 | for i in range(len(smiles_list)): 738 | print(i) 739 | s = smiles_list[i] 740 | rdkit_mol = 
AllChem.MolFromSmiles(s) 741 | if rdkit_mol != None: # ignore invalid mol objects 742 | # # convert aromatic bonds to double bonds 743 | # Chem.SanitizeMol(rdkit_mol, 744 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 745 | data = mol_to_graph_data_obj_simple(rdkit_mol) 746 | # manually add mol id 747 | data.id = torch.tensor( 748 | [i]) 749 | data.y = torch.tensor([labels[i]]) 750 | data_list.append(data) 751 | data_smiles_list.append(smiles_list[i]) 752 | 753 | 754 | else: 755 | raise ValueError('Invalid dataset name') 756 | 757 | if self.pre_filter is not None: 758 | data_list = [data for data in data_list if self.pre_filter(data)] 759 | 760 | if self.pre_transform is not None: 761 | data_list = [self.pre_transform(data) for data in data_list] 762 | 763 | # write data_smiles_list in processed paths 764 | data_smiles_series = pd.Series(data_smiles_list) 765 | data_smiles_series.to_csv(os.path.join(self.processed_dir, 766 | 'smiles.csv'), index=False, 767 | header=False) 768 | 769 | data, slices = self.collate(data_list) 770 | torch.save((data, slices), self.processed_paths[0]) 771 | 772 | 773 | def drop_nodes(data, aug_ratio): 774 | node_num, _ = data.x.size() 775 | _, edge_num = data.edge_index.size() 776 | drop_num = int(node_num * aug_ratio) 777 | 778 | idx_perm = np.random.permutation(node_num) 779 | 780 | idx_drop = idx_perm[:drop_num] 781 | idx_nondrop = idx_perm[drop_num:] 782 | idx_nondrop.sort() 783 | idx_dict = {idx_nondrop[n]: n for n in list(range(idx_nondrop.shape[0]))} 784 | 785 | edge_index = data.edge_index.numpy() 786 | edge_mask = np.array( 787 | [n for n in range(edge_num) if not (edge_index[0, n] in idx_drop or edge_index[1, n] in idx_drop)]) 788 | 789 | edge_index = [[idx_dict[edge_index[0, n]], idx_dict[edge_index[1, n]]] for n in range(edge_num) if 790 | (not edge_index[0, n] in idx_drop) and (not edge_index[1, n] in idx_drop)] 791 | try: 792 | data.edge_index = torch.tensor(edge_index).transpose_(0, 1) 793 | data.x = data.x[idx_nondrop] 794 | data.edge_attr = data.edge_attr[edge_mask] 795 | except: 796 | data = data 797 | 798 | return data 799 | 800 | 801 | def permute_edges(data, aug_ratio): 802 | node_num, _ = data.x.size() 803 | _, edge_num = data.edge_index.size() 804 | permute_num = int(edge_num * aug_ratio) 805 | edge_index = data.edge_index.numpy() 806 | 807 | idx_add = np.random.choice(node_num, (2, permute_num)) 808 | edge_index = np.concatenate( 809 | (edge_index[:, np.random.choice(edge_num, (edge_num - permute_num), replace=False)], idx_add), axis=1) 810 | data.edge_index = torch.tensor(edge_index) 811 | 812 | return data 813 | 814 | 815 | def mask_nodes(data, aug_ratio): 816 | node_num, feat_dim = data.x.size() 817 | mask_num = int(node_num * aug_ratio) 818 | 819 | token = data.x.mean(dim=0) 820 | idx_mask = np.random.choice(node_num, mask_num, replace=False) 821 | data.x[idx_mask] = torch.tensor(token, dtype=torch.float32) 822 | 823 | return data 824 | 825 | 826 | def subgraph(data, aug_ratio): 827 | node_num, _ = data.x.size() 828 | _, edge_num = data.edge_index.size() 829 | sub_num = int(node_num * aug_ratio) 830 | 831 | edge_index = data.edge_index.numpy() 832 | 833 | idx_sub = [np.random.randint(node_num, size=1)[0]] 834 | idx_neigh = set([n for n in edge_index[1][edge_index[0] == idx_sub[0]]]) 835 | 836 | count = 0 837 | while len(idx_sub) <= sub_num: 838 | count = count + 1 839 | if count > node_num: 840 | break 841 | if len(idx_neigh) == 0: 842 | break 843 | sample_node = np.random.choice(list(idx_neigh)) 844 | if sample_node in 
idx_sub: 845 | continue 846 | idx_sub.append(sample_node) 847 | idx_neigh = idx_neigh.union(set([n for n in edge_index[1][edge_index[0] == idx_sub[-1]]])) 848 | 849 | idx_drop = [n for n in range(node_num) if n not in idx_sub] 850 | idx_nondrop = idx_sub 851 | idx_dict = {idx_nondrop[n]: n for n in list(range(len(idx_nondrop)))} 852 | edge_mask = np.array( 853 | [n for n in range(edge_num) if (edge_index[0, n] in idx_nondrop and edge_index[1, n] in idx_nondrop)]) 854 | 855 | edge_index = data.edge_index.numpy() 856 | edge_index = [[idx_dict[edge_index[0, n]], idx_dict[edge_index[1, n]]] for n in range(edge_num) if 857 | (edge_index[0, n] not in idx_drop) and (edge_index[1, n] not in idx_drop)] 858 | try: 859 | data.edge_index = torch.tensor(edge_index).transpose_(0, 1) 860 | data.x = data.x[idx_nondrop] 861 | data.edge_attr = data.edge_attr[edge_mask] 862 | except: 863 | data = data # fall back to the unmodified graph if re-indexing fails (e.g. empty subgraph) 864 | 865 | return data 866 | 867 | 868 | class MoleculeDataset(InMemoryDataset): 869 | def __init__(self, 870 | root, 871 | # data = None, 872 | # slices = None, 873 | transform=None, 874 | pre_transform=None, 875 | pre_filter=None, 876 | dataset='zinc250k', 877 | empty=False): 878 | """ 879 | Adapted from qm9.py. Disabled the download functionality 880 | :param root: directory of the dataset, containing a raw and processed 881 | dir. The raw dir should contain the file containing the smiles, and the 882 | processed dir can either be empty or contain a previously processed file 883 | :param dataset: name of the dataset. Currently only implemented for 884 | zinc250k, chembl_with_labels, tox21, hiv, bace, bbbp, clintox, esol, 885 | freesolv, lipophilicity, muv, pcba, sider, toxcast 886 | :param empty: if True, then will not load any data obj. For 887 | initializing empty dataset 888 | """ 889 | self.dataset = dataset 890 | self.root = root 891 | 892 | super(MoleculeDataset, self).__init__(root, transform, pre_transform, 893 | pre_filter) 894 | self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter 895 | 896 | if not empty: 897 | self.data, self.slices = torch.load(self.processed_paths[0]) 898 | 899 | def get(self, idx): 900 | data = Data() 901 | for key in self.data.keys: 902 | item, slices = self.data[key], self.slices[key] 903 | s = list(repeat(slice(None), item.dim())) 904 | s[data.__cat_dim__(key, item)] = slice(slices[idx], 905 | slices[idx + 1]) 906 | data[key] = item[s] 907 | return data 908 | 909 | @property 910 | def raw_file_names(self): 911 | file_name_list = os.listdir(self.raw_dir) 912 | # assert len(file_name_list) == 1 # currently assume we have a 913 | # # single raw file 914 | return file_name_list 915 | 916 | @property 917 | def processed_file_names(self): 918 | return 'geometric_data_processed.pt' 919 | 920 | def download(self): 921 | raise NotImplementedError('Must indicate valid location of raw data. 
' 922 | 'No download allowed') 923 | 924 | def process(self): 925 | data_smiles_list = [] 926 | data_list = [] 927 | 928 | if self.dataset == 'zinc_standard_agent': 929 | input_path = self.raw_paths[0] 930 | input_df = pd.read_csv(input_path, sep=',', compression='gzip', 931 | dtype='str') 932 | smiles_list = list(input_df['smiles']) 933 | zinc_id_list = list(input_df['zinc_id']) 934 | for i in range(len(smiles_list)): 935 | print(i) 936 | s = smiles_list[i] 937 | # each example contains a single species 938 | try: 939 | rdkit_mol = AllChem.MolFromSmiles(s) 940 | if rdkit_mol != None: # ignore invalid mol objects 941 | # # convert aromatic bonds to double bonds 942 | # Chem.SanitizeMol(rdkit_mol, 943 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 944 | data = mol_to_graph_data_obj_simple(rdkit_mol) 945 | # manually add mol id 946 | id = int(zinc_id_list[i].split('ZINC')[1].lstrip('0')) 947 | data.id = torch.tensor( 948 | [id]) # id here is zinc id value, stripped of 949 | # leading zeros 950 | data_list.append(data) 951 | data_smiles_list.append(smiles_list[i]) 952 | except: 953 | continue 954 | 955 | elif self.dataset == 'chembl_filtered': 956 | ### get downstream test molecules. 957 | from splitters import scaffold_split 958 | 959 | ### 960 | downstream_dir = [ 961 | 'storage/datasets/transfer_dataset/bace', 962 | 'storage/datasets/transfer_dataset/bbbp', 963 | 'storage/datasets/transfer_dataset/clintox', 964 | 'storage/datasets/transfer_dataset/esol', 965 | 'storage/datasets/transfer_dataset/freesolv', 966 | 'storage/datasets/transfer_dataset/hiv', 967 | 'storage/datasets/transfer_dataset/lipophilicity', 968 | 'storage/datasets/transfer_dataset/muv', 969 | # 'storage/datasets/transfer_dataset/pcba/processed/smiles.csv', 970 | 'storage/datasets/transfer_dataset/sider', 971 | 'storage/datasets/transfer_dataset/tox21', 972 | 'storage/datasets/transfer_dataset/toxcast' 973 | ] 974 | 975 | downstream_inchi_set = set() 976 | for d_path in downstream_dir: 977 | print(d_path) 978 | dataset_name = d_path.split('/')[-1] 979 | downstream_dataset = MoleculeDataset(d_path, dataset=dataset_name) 980 | downstream_smiles = pd.read_csv(os.path.join(d_path, 981 | 'processed', 'smiles.csv'), 982 | header=None)[0].tolist() 983 | 984 | assert len(downstream_dataset) == len(downstream_smiles) 985 | 986 | _, _, _, (train_smiles, valid_smiles, test_smiles) = scaffold_split(downstream_dataset, 987 | downstream_smiles, task_idx=None, 988 | null_value=0, 989 | frac_train=0.8, frac_valid=0.1, 990 | frac_test=0.1, 991 | return_smiles=True) 992 | 993 | ### remove both test and validation molecules 994 | remove_smiles = test_smiles + valid_smiles 995 | 996 | downstream_inchis = [] 997 | for smiles in remove_smiles: 998 | species_list = smiles.split('.') 999 | for s in species_list: # record inchi for all species, not just 1000 | # largest (by default in create_standardized_mol_id if input has 1001 | # multiple species) 1002 | inchi = create_standardized_mol_id(s) 1003 | downstream_inchis.append(inchi) 1004 | downstream_inchi_set.update(downstream_inchis) 1005 | 1006 | smiles_list, rdkit_mol_objs, folds, labels = \ 1007 | _load_chembl_with_labels_dataset(os.path.join(self.root, 'raw')) 1008 | 1009 | print('processing') 1010 | for i in range(len(rdkit_mol_objs)): 1011 | print(i) 1012 | rdkit_mol = rdkit_mol_objs[i] 1013 | if rdkit_mol != None: 1014 | # # convert aromatic bonds to double bonds 1015 | # Chem.SanitizeMol(rdkit_mol, 1016 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1017 | mw = 
Descriptors.MolWt(rdkit_mol) 1018 | if 50 <= mw <= 900: 1019 | inchi = create_standardized_mol_id(smiles_list[i]) 1020 | if inchi != None and inchi not in downstream_inchi_set: 1021 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1022 | # manually add mol id 1023 | data.id = torch.tensor( 1024 | [i]) # id here is the index of the mol in 1025 | # the dataset 1026 | data.y = torch.tensor(labels[i, :]) 1027 | # fold information 1028 | if i in folds[0]: 1029 | data.fold = torch.tensor([0]) 1030 | elif i in folds[1]: 1031 | data.fold = torch.tensor([1]) 1032 | else: 1033 | data.fold = torch.tensor([2]) 1034 | data_list.append(data) 1035 | data_smiles_list.append(smiles_list[i]) 1036 | 1037 | elif self.dataset == 'tox21': 1038 | smiles_list, rdkit_mol_objs, labels = \ 1039 | _load_tox21_dataset(self.raw_paths[0]) 1040 | for i in range(len(smiles_list)): 1041 | print(i) 1042 | rdkit_mol = rdkit_mol_objs[i] 1043 | ## convert aromatic bonds to double bonds 1044 | # Chem.SanitizeMol(rdkit_mol, 1045 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1046 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1047 | # manually add mol id 1048 | data.id = torch.tensor( 1049 | [i]) # id here is the index of the mol in 1050 | # the dataset 1051 | data.y = torch.tensor(labels[i, :]) 1052 | data_list.append(data) 1053 | data_smiles_list.append(smiles_list[i]) 1054 | 1055 | elif self.dataset == 'hiv': 1056 | smiles_list, rdkit_mol_objs, labels = \ 1057 | _load_hiv_dataset(self.raw_paths[0]) 1058 | for i in range(len(smiles_list)): 1059 | print(i) 1060 | rdkit_mol = rdkit_mol_objs[i] 1061 | # # convert aromatic bonds to double bonds 1062 | # Chem.SanitizeMol(rdkit_mol, 1063 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1064 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1065 | # manually add mol id 1066 | data.id = torch.tensor( 1067 | [i]) # id here is the index of the mol in 1068 | # the dataset 1069 | data.y = torch.tensor([labels[i]]) 1070 | data_list.append(data) 1071 | data_smiles_list.append(smiles_list[i]) 1072 | 1073 | elif self.dataset == 'bace': 1074 | smiles_list, rdkit_mol_objs, folds, labels = \ 1075 | _load_bace_dataset(self.raw_paths[0]) 1076 | for i in range(len(smiles_list)): 1077 | print(i) 1078 | rdkit_mol = rdkit_mol_objs[i] 1079 | # # convert aromatic bonds to double bonds 1080 | # Chem.SanitizeMol(rdkit_mol, 1081 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1082 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1083 | # manually add mol id 1084 | data.id = torch.tensor( 1085 | [i]) # id here is the index of the mol in 1086 | # the dataset 1087 | data.y = torch.tensor([labels[i]]) 1088 | data.fold = torch.tensor([folds[i]]) 1089 | data_list.append(data) 1090 | data_smiles_list.append(smiles_list[i]) 1091 | 1092 | elif self.dataset == 'bbbp': 1093 | smiles_list, rdkit_mol_objs, labels = \ 1094 | _load_bbbp_dataset(self.raw_paths[0]) 1095 | for i in range(len(smiles_list)): 1096 | print(i) 1097 | rdkit_mol = rdkit_mol_objs[i] 1098 | if rdkit_mol != None: 1099 | # # convert aromatic bonds to double bonds 1100 | # Chem.SanitizeMol(rdkit_mol, 1101 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1102 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1103 | # manually add mol id 1104 | data.id = torch.tensor( 1105 | [i]) # id here is the index of the mol in 1106 | # the dataset 1107 | data.y = torch.tensor([labels[i]]) 1108 | data_list.append(data) 1109 | data_smiles_list.append(smiles_list[i]) 1110 | 1111 | elif self.dataset == 'clintox': 1112 | smiles_list, 
rdkit_mol_objs, labels = \ 1113 | _load_clintox_dataset(self.raw_paths[0]) 1114 | for i in range(len(smiles_list)): 1115 | print(i) 1116 | rdkit_mol = rdkit_mol_objs[i] 1117 | if rdkit_mol != None: 1118 | # # convert aromatic bonds to double bonds 1119 | # Chem.SanitizeMol(rdkit_mol, 1120 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1121 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1122 | # manually add mol id 1123 | data.id = torch.tensor( 1124 | [i]) # id here is the index of the mol in 1125 | # the dataset 1126 | data.y = torch.tensor(labels[i, :]) 1127 | data_list.append(data) 1128 | data_smiles_list.append(smiles_list[i]) 1129 | 1130 | elif self.dataset == 'esol': 1131 | smiles_list, rdkit_mol_objs, labels = \ 1132 | _load_esol_dataset(self.raw_paths[0]) 1133 | for i in range(len(smiles_list)): 1134 | print(i) 1135 | rdkit_mol = rdkit_mol_objs[i] 1136 | # # convert aromatic bonds to double bonds 1137 | # Chem.SanitizeMol(rdkit_mol, 1138 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1139 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1140 | # manually add mol id 1141 | data.id = torch.tensor( 1142 | [i]) # id here is the index of the mol in 1143 | # the dataset 1144 | data.y = torch.tensor([labels[i]]) 1145 | data_list.append(data) 1146 | data_smiles_list.append(smiles_list[i]) 1147 | 1148 | elif self.dataset == 'freesolv': 1149 | smiles_list, rdkit_mol_objs, labels = \ 1150 | _load_freesolv_dataset(self.raw_paths[0]) 1151 | for i in range(len(smiles_list)): 1152 | print(i) 1153 | rdkit_mol = rdkit_mol_objs[i] 1154 | # # convert aromatic bonds to double bonds 1155 | # Chem.SanitizeMol(rdkit_mol, 1156 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1157 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1158 | # manually add mol id 1159 | data.id = torch.tensor( 1160 | [i]) # id here is the index of the mol in 1161 | # the dataset 1162 | data.y = torch.tensor([labels[i]]) 1163 | data_list.append(data) 1164 | data_smiles_list.append(smiles_list[i]) 1165 | 1166 | elif self.dataset == 'lipophilicity': 1167 | smiles_list, rdkit_mol_objs, labels = \ 1168 | _load_lipophilicity_dataset(self.raw_paths[0]) 1169 | for i in range(len(smiles_list)): 1170 | print(i) 1171 | rdkit_mol = rdkit_mol_objs[i] 1172 | # # convert aromatic bonds to double bonds 1173 | # Chem.SanitizeMol(rdkit_mol, 1174 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1175 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1176 | # manually add mol id 1177 | data.id = torch.tensor( 1178 | [i]) # id here is the index of the mol in 1179 | # the dataset 1180 | data.y = torch.tensor([labels[i]]) 1181 | data_list.append(data) 1182 | data_smiles_list.append(smiles_list[i]) 1183 | 1184 | elif self.dataset == 'muv': 1185 | smiles_list, rdkit_mol_objs, labels = \ 1186 | _load_muv_dataset(self.raw_paths[0]) 1187 | for i in range(len(smiles_list)): 1188 | print(i) 1189 | rdkit_mol = rdkit_mol_objs[i] 1190 | # # convert aromatic bonds to double bonds 1191 | # Chem.SanitizeMol(rdkit_mol, 1192 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1193 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1194 | # manually add mol id 1195 | data.id = torch.tensor( 1196 | [i]) # id here is the index of the mol in 1197 | # the dataset 1198 | data.y = torch.tensor(labels[i, :]) 1199 | data_list.append(data) 1200 | data_smiles_list.append(smiles_list[i]) 1201 | 1202 | elif self.dataset == 'pcba': 1203 | smiles_list, rdkit_mol_objs, labels = \ 1204 | _load_pcba_dataset(self.raw_paths[0]) 1205 | for i in 
range(len(smiles_list)): 1206 | print(i) 1207 | rdkit_mol = rdkit_mol_objs[i] 1208 | # # convert aromatic bonds to double bonds 1209 | # Chem.SanitizeMol(rdkit_mol, 1210 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1211 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1212 | # manually add mol id 1213 | data.id = torch.tensor( 1214 | [i]) # id here is the index of the mol in 1215 | # the dataset 1216 | data.y = torch.tensor(labels[i, :]) 1217 | data_list.append(data) 1218 | data_smiles_list.append(smiles_list[i]) 1219 | 1220 | elif self.dataset == 'pcba_pretrain': 1221 | smiles_list, rdkit_mol_objs, labels = \ 1222 | _load_pcba_dataset(self.raw_paths[0]) 1223 | downstream_inchi = set(pd.read_csv(os.path.join(self.root, 1224 | 'downstream_mol_inchi_may_24_2019'), 1225 | sep=',', header=None)[0]) 1226 | for i in range(len(smiles_list)): 1227 | print(i) 1228 | if '.' not in smiles_list[i]: # remove examples with 1229 | # multiples species 1230 | rdkit_mol = rdkit_mol_objs[i] 1231 | mw = Descriptors.MolWt(rdkit_mol) 1232 | if 50 <= mw <= 900: 1233 | inchi = create_standardized_mol_id(smiles_list[i]) 1234 | if inchi != None and inchi not in downstream_inchi: 1235 | # # convert aromatic bonds to double bonds 1236 | # Chem.SanitizeMol(rdkit_mol, 1237 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1238 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1239 | # manually add mol id 1240 | data.id = torch.tensor( 1241 | [i]) # id here is the index of the mol in 1242 | # the dataset 1243 | data.y = torch.tensor(labels[i, :]) 1244 | data_list.append(data) 1245 | data_smiles_list.append(smiles_list[i]) 1246 | 1247 | # elif self.dataset == '' 1248 | 1249 | elif self.dataset == 'sider': 1250 | smiles_list, rdkit_mol_objs, labels = \ 1251 | _load_sider_dataset(self.raw_paths[0]) 1252 | for i in range(len(smiles_list)): 1253 | print(i) 1254 | rdkit_mol = rdkit_mol_objs[i] 1255 | # # convert aromatic bonds to double bonds 1256 | # Chem.SanitizeMol(rdkit_mol, 1257 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1258 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1259 | # manually add mol id 1260 | data.id = torch.tensor( 1261 | [i]) # id here is the index of the mol in 1262 | # the dataset 1263 | data.y = torch.tensor(labels[i, :]) 1264 | data_list.append(data) 1265 | data_smiles_list.append(smiles_list[i]) 1266 | 1267 | elif self.dataset == 'toxcast': 1268 | smiles_list, rdkit_mol_objs, labels = \ 1269 | _load_toxcast_dataset(self.raw_paths[0]) 1270 | for i in range(len(smiles_list)): 1271 | print(i) 1272 | rdkit_mol = rdkit_mol_objs[i] 1273 | if rdkit_mol != None: 1274 | # # convert aromatic bonds to double bonds 1275 | # Chem.SanitizeMol(rdkit_mol, 1276 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1277 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1278 | # manually add mol id 1279 | data.id = torch.tensor( 1280 | [i]) # id here is the index of the mol in 1281 | # the dataset 1282 | data.y = torch.tensor(labels[i, :]) 1283 | data_list.append(data) 1284 | data_smiles_list.append(smiles_list[i]) 1285 | 1286 | elif self.dataset == 'ptc_mr': 1287 | input_path = self.raw_paths[0] 1288 | input_df = pd.read_csv(input_path, sep=',', header=None, names=['id', 'label', 'smiles']) 1289 | smiles_list = input_df['smiles'] 1290 | labels = input_df['label'].values 1291 | for i in range(len(smiles_list)): 1292 | print(i) 1293 | s = smiles_list[i] 1294 | rdkit_mol = AllChem.MolFromSmiles(s) 1295 | if rdkit_mol != None: # ignore invalid mol objects 1296 | # # convert aromatic bonds to 
double bonds 1297 | # Chem.SanitizeMol(rdkit_mol, 1298 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1299 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1300 | # manually add mol id 1301 | data.id = torch.tensor( 1302 | [i]) 1303 | data.y = torch.tensor([labels[i]]) 1304 | data_list.append(data) 1305 | data_smiles_list.append(smiles_list[i]) 1306 | 1307 | elif self.dataset == 'mutag': 1308 | smiles_path = os.path.join(self.root, 'raw', 'mutag_188_data.can') 1309 | # smiles_path = 'storage/datasets/transfer_dataset/mutag/raw/mutag_188_data.can' 1310 | labels_path = os.path.join(self.root, 'raw', 'mutag_188_target.txt') 1311 | # labels_path = 'storage/datasets/transfer_dataset/mutag/raw/mutag_188_target.txt' 1312 | smiles_list = pd.read_csv(smiles_path, sep=' ', header=None)[0] 1313 | labels = pd.read_csv(labels_path, header=None)[0].values 1314 | for i in range(len(smiles_list)): 1315 | print(i) 1316 | s = smiles_list[i] 1317 | rdkit_mol = AllChem.MolFromSmiles(s) 1318 | if rdkit_mol != None: # ignore invalid mol objects 1319 | # # convert aromatic bonds to double bonds 1320 | # Chem.SanitizeMol(rdkit_mol, 1321 | # sanitizeOps=Chem.SanitizeFlags.SANITIZE_KEKULIZE) 1322 | data = mol_to_graph_data_obj_simple(rdkit_mol) 1323 | # manually add mol id 1324 | data.id = torch.tensor( 1325 | [i]) 1326 | data.y = torch.tensor([labels[i]]) 1327 | data_list.append(data) 1328 | data_smiles_list.append(smiles_list[i]) 1329 | 1330 | 1331 | else: 1332 | print(self.dataset) 1333 | raise ValueError('Invalid dataset name') 1334 | 1335 | if self.pre_filter is not None: 1336 | data_list = [data for data in data_list if self.pre_filter(data)] 1337 | 1338 | if self.pre_transform is not None: 1339 | data_list = [self.pre_transform(data) for data in data_list] 1340 | 1341 | # write data_smiles_list in processed paths 1342 | data_smiles_series = pd.Series(data_smiles_list) 1343 | data_smiles_series.to_csv(os.path.join(self.processed_dir, 1344 | 'smiles.csv'), index=False, 1345 | header=False) 1346 | 1347 | data, slices = self.collate(data_list) 1348 | torch.save((data, slices), self.processed_paths[0]) 1349 | 1350 | 1351 | # NB: only properly tested when dataset_1 is chembl_with_labels and dataset_2 1352 | # is pcba_pretrain 1353 | def merge_dataset_objs(dataset_1, dataset_2): 1354 | """ 1355 | Naively merge 2 molecule dataset objects, and ignore identities of 1356 | molecules. Assumes both datasets have multiple y labels, and will pad 1357 | accordingly. ie if dataset_1 has obj_1 with y dim 1310 and dataset_2 has 1358 | obj_2 with y dim 128, then the resulting obj_1 and obj_2 will have dim 1359 | 1438, where obj_1 have the last 128 cols with 0, and obj_2 have 1360 | the first 1310 cols with 0. 1361 | :return: pytorch geometric dataset obj, with the x, edge_attr, edge_index, 1362 | new y attributes only 1363 | """ 1364 | d_1_y_dim = dataset_1[0].y.size()[0] 1365 | d_2_y_dim = dataset_2[0].y.size()[0] 1366 | 1367 | data_list = [] 1368 | # keep only x, edge_attr, edge_index, padded_y then append 1369 | for d in dataset_1: 1370 | old_y = d.y 1371 | new_y = torch.cat([old_y, torch.zeros(d_2_y_dim, dtype=torch.long)]) 1372 | data_list.append(Data(x=d.x, edge_index=d.edge_index, 1373 | edge_attr=d.edge_attr, y=new_y)) 1374 | 1375 | for d in dataset_2: 1376 | old_y = d.y 1377 | new_y = torch.cat([torch.zeros(d_1_y_dim, dtype=torch.long), old_y.long()]) 1378 | data_list.append(Data(x=d.x, edge_index=d.edge_index, 1379 | edge_attr=d.edge_attr, y=new_y)) 1380 | 1381 | # create 'empty' dataset obj. 
Just randomly pick a dataset and root path 1382 | # that has already been processed 1383 | new_dataset = MoleculeDataset(root='storage/datasets/transfer_dataset/chembl_with_labels', 1384 | dataset='chembl_with_labels', empty=True) 1385 | # collate manually 1386 | new_dataset.data, new_dataset.slices = new_dataset.collate(data_list) 1387 | 1388 | return new_dataset 1389 | 1390 | 1391 | def create_circular_fingerprint(mol, radius, size, chirality): 1392 | """ 1393 | 1394 | :param mol: 1395 | :param radius: 1396 | :param size: 1397 | :param chirality: 1398 | :return: np array of morgan fingerprint 1399 | """ 1400 | fp = GetMorganFingerprintAsBitVect(mol, radius, 1401 | nBits=size, useChirality=chirality) 1402 | return np.array(fp) 1403 | 1404 | 1405 | class MoleculeFingerprintDataset(data.Dataset): 1406 | def __init__(self, root, dataset, radius, size, chirality=True): 1407 | """ 1408 | Create dataset object containing list of dicts, where each dict 1409 | contains the circular fingerprint of the molecule, label, id, 1410 | and possibly precomputed fold information 1411 | :param root: directory of the dataset, containing a raw and 1412 | processed_fp dir. The raw dir should contain the file containing the 1413 | smiles, and the processed_fp dir can either be empty or a 1414 | previously processed file 1415 | :param dataset: name of dataset. Currently only implemented for 1416 | tox21, hiv, chembl_with_labels 1417 | :param radius: radius of the circular fingerprints 1418 | :param size: size of the folded fingerprint vector 1419 | :param chirality: if True, fingerprint includes chirality information 1420 | """ 1421 | self.dataset = dataset 1422 | self.root = root 1423 | self.radius = radius 1424 | self.size = size 1425 | self.chirality = chirality 1426 | 1427 | self._load() 1428 | 1429 | def _process(self): 1430 | data_smiles_list = [] 1431 | data_list = [] 1432 | if self.dataset == 'chembl_with_labels': 1433 | smiles_list, rdkit_mol_objs, folds, labels = \ 1434 | _load_chembl_with_labels_dataset(os.path.join(self.root, 'raw')) 1435 | print('processing') 1436 | for i in range(len(rdkit_mol_objs)): 1437 | print(i) 1438 | rdkit_mol = rdkit_mol_objs[i] 1439 | if rdkit_mol != None: 1440 | # # convert aromatic bonds to double bonds 1441 | fp_arr = create_circular_fingerprint(rdkit_mol, 1442 | self.radius, 1443 | self.size, self.chirality) 1444 | fp_arr = torch.tensor(fp_arr) 1445 | # manually add mol id 1446 | id = torch.tensor([i]) # id here is the index of the mol in 1447 | # the dataset 1448 | y = torch.tensor(labels[i, :]) 1449 | # fold information 1450 | if i in folds[0]: 1451 | fold = torch.tensor([0]) 1452 | elif i in folds[1]: 1453 | fold = torch.tensor([1]) 1454 | else: 1455 | fold = torch.tensor([2]) 1456 | data_list.append({'fp_arr': fp_arr, 'id': id, 'y': y, 1457 | 'fold': fold}) 1458 | data_smiles_list.append(smiles_list[i]) 1459 | elif self.dataset == 'tox21': 1460 | smiles_list, rdkit_mol_objs, labels = \ 1461 | _load_tox21_dataset(os.path.join(self.root, 'raw/tox21.csv')) 1462 | print('processing') 1463 | for i in range(len(smiles_list)): 1464 | print(i) 1465 | rdkit_mol = rdkit_mol_objs[i] 1466 | ## convert aromatic bonds to double bonds 1467 | fp_arr = create_circular_fingerprint(rdkit_mol, 1468 | self.radius, 1469 | self.size, 1470 | self.chirality) 1471 | fp_arr = torch.tensor(fp_arr) 1472 | 1473 | # manually add mol id 1474 | id = torch.tensor([i]) # id here is the index of the mol in 1475 | # the dataset 1476 | y = torch.tensor(labels[i, :]) 1477 | data_list.append({'fp_arr': 
fp_arr, 'id': id, 'y': y}) 1478 | data_smiles_list.append(smiles_list[i]) 1479 | elif self.dataset == 'hiv': 1480 | smiles_list, rdkit_mol_objs, labels = \ 1481 | _load_hiv_dataset(os.path.join(self.root, 'raw/HIV.csv')) 1482 | print('processing') 1483 | for i in range(len(smiles_list)): 1484 | print(i) 1485 | rdkit_mol = rdkit_mol_objs[i] 1486 | # # convert aromatic bonds to double bonds 1487 | fp_arr = create_circular_fingerprint(rdkit_mol, 1488 | self.radius, 1489 | self.size, 1490 | self.chirality) 1491 | fp_arr = torch.tensor(fp_arr) 1492 | 1493 | # manually add mol id 1494 | id = torch.tensor([i]) # id here is the index of the mol in 1495 | # the dataset 1496 | y = torch.tensor([labels[i]]) 1497 | data_list.append({'fp_arr': fp_arr, 'id': id, 'y': y}) 1498 | data_smiles_list.append(smiles_list[i]) 1499 | else: 1500 | raise ValueError('Invalid dataset name') 1501 | 1502 | # save processed data objects and smiles 1503 | processed_dir = os.path.join(self.root, 'processed_fp') 1504 | data_smiles_series = pd.Series(data_smiles_list) 1505 | data_smiles_series.to_csv(os.path.join(processed_dir, 'smiles.csv'), 1506 | index=False, 1507 | header=False) 1508 | with open(os.path.join(processed_dir, 1509 | 'fingerprint_data_processed.pkl'), 1510 | 'wb') as f: 1511 | pickle.dump(data_list, f) 1512 | 1513 | def _load(self): 1514 | processed_dir = os.path.join(self.root, 'processed_fp') 1515 | # check if saved file exist. If so, then load from save 1516 | file_name_list = os.listdir(processed_dir) 1517 | if 'fingerprint_data_processed.pkl' in file_name_list: 1518 | with open(os.path.join(processed_dir, 1519 | 'fingerprint_data_processed.pkl'), 1520 | 'rb') as f: 1521 | self.data_list = pickle.load(f) 1522 | # if no saved file exist, then perform processing steps, save then 1523 | # reload 1524 | else: 1525 | self._process() 1526 | self._load() 1527 | 1528 | def __len__(self): 1529 | return len(self.data_list) 1530 | 1531 | def __getitem__(self, index): 1532 | ## if iterable class is passed, return dataset objection 1533 | if hasattr(index, "__iter__"): 1534 | dataset = MoleculeFingerprintDataset(self.root, self.dataset, self.radius, self.size, 1535 | chirality=self.chirality) 1536 | dataset.data_list = [self.data_list[i] for i in index] 1537 | return dataset 1538 | else: 1539 | return self.data_list[index] 1540 | 1541 | 1542 | def _load_tox21_dataset(input_path): 1543 | """ 1544 | 1545 | :param input_path: 1546 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1547 | labels 1548 | """ 1549 | input_df = pd.read_csv(input_path, sep=',') 1550 | smiles_list = input_df['smiles'] 1551 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1552 | tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 1553 | 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'] 1554 | labels = input_df[tasks] 1555 | # convert 0 to -1 1556 | labels = labels.replace(0, -1) 1557 | # convert nan to 0 1558 | labels = labels.fillna(0) 1559 | assert len(smiles_list) == len(rdkit_mol_objs_list) 1560 | assert len(smiles_list) == len(labels) 1561 | return smiles_list, rdkit_mol_objs_list, labels.values 1562 | 1563 | 1564 | def _load_hiv_dataset(input_path): 1565 | """ 1566 | :param input_path: 1567 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1568 | labels 1569 | """ 1570 | input_df = pd.read_csv(input_path, sep=',') 1571 | smiles_list = input_df['smiles'] 1572 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in 
smiles_list] 1573 | labels = input_df['HIV_active'] 1574 | # convert 0 to -1 1575 | labels = labels.replace(0, -1) 1576 | # there are no nans 1577 | assert len(smiles_list) == len(rdkit_mol_objs_list) 1578 | assert len(smiles_list) == len(labels) 1579 | return smiles_list, rdkit_mol_objs_list, labels.values 1580 | 1581 | 1582 | def _load_bace_dataset(input_path): 1583 | """ 1584 | 1585 | :param input_path: 1586 | :return: list of smiles, list of rdkit mol obj, np.array 1587 | containing indices for each of the 3 folds, np.array containing the 1588 | labels 1589 | """ 1590 | input_df = pd.read_csv(input_path, sep=',') 1591 | smiles_list = input_df['mol'] 1592 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1593 | labels = input_df['Class'] 1594 | # convert 0 to -1 1595 | labels = labels.replace(0, -1) 1596 | # there are no nans 1597 | folds = input_df['Model'] 1598 | folds = folds.replace('Train', 0) # 0 -> train 1599 | folds = folds.replace('Valid', 1) # 1 -> valid 1600 | folds = folds.replace('Test', 2) # 2 -> test 1601 | assert len(smiles_list) == len(rdkit_mol_objs_list) 1602 | assert len(smiles_list) == len(labels) 1603 | assert len(smiles_list) == len(folds) 1604 | return smiles_list, rdkit_mol_objs_list, folds.values, labels.values 1605 | 1606 | 1607 | def _load_bbbp_dataset(input_path): 1608 | """ 1609 | 1610 | :param input_path: 1611 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1612 | labels 1613 | """ 1614 | input_df = pd.read_csv(input_path, sep=',') 1615 | smiles_list = input_df['smiles'] 1616 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1617 | 1618 | preprocessed_rdkit_mol_objs_list = [m if m != None else None for m in 1619 | rdkit_mol_objs_list] 1620 | preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m != None else 1621 | None for m in preprocessed_rdkit_mol_objs_list] 1622 | labels = input_df['p_np'] 1623 | # convert 0 to -1 1624 | labels = labels.replace(0, -1) 1625 | # there are no nans 1626 | assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list) 1627 | assert len(smiles_list) == len(preprocessed_smiles_list) 1628 | assert len(smiles_list) == len(labels) 1629 | return preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list, \ 1630 | labels.values 1631 | 1632 | 1633 | def _load_clintox_dataset(input_path): 1634 | """ 1635 | 1636 | :param input_path: 1637 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1638 | labels 1639 | """ 1640 | input_df = pd.read_csv(input_path, sep=',') 1641 | smiles_list = input_df['smiles'] 1642 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1643 | 1644 | preprocessed_rdkit_mol_objs_list = [m if m != None else None for m in 1645 | rdkit_mol_objs_list] 1646 | preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m != None else 1647 | None for m in preprocessed_rdkit_mol_objs_list] 1648 | tasks = ['FDA_APPROVED', 'CT_TOX'] 1649 | labels = input_df[tasks] 1650 | # convert 0 to -1 1651 | labels = labels.replace(0, -1) 1652 | # there are no nans 1653 | assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list) 1654 | assert len(smiles_list) == len(preprocessed_smiles_list) 1655 | assert len(smiles_list) == len(labels) 1656 | return preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list, \ 1657 | labels.values 1658 | 1659 | 1660 | # input_path = 'storage/datasets/transfer_dataset/clintox/raw/clintox.csv' 1661 | # smiles_list, rdkit_mol_objs_list, labels = _load_clintox_dataset(input_path) 1662 | 
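The classification loaders above (tox21, hiv, bace, bbbp, clintox, and the muv/sider/toxcast loaders that follow) share one label convention that downstream code relies on: positives stay +1, zeros (inactives) are re-coded to -1, and missing entries become 0, so a label of 0 can be treated as "no measurement". A stand-alone editorial sketch of that encoding (the task columns here are invented for illustration):

```python
import numpy as np
import pandas as pd

# toy stand-in for a loaded CSV; 'task_a' and 'task_b' are hypothetical task columns
df = pd.DataFrame({'task_a': [1, 0, np.nan], 'task_b': [0, 1, 1]})
labels = df.replace(0, -1)  # inactive: 0 -> -1
labels = labels.fillna(0)   # missing: NaN -> 0, i.e. "no measurement"
print(labels.values)        # rows: [1, -1], [-1, 1], [0, 1]
```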
1663 | def _load_esol_dataset(input_path): 1664 | """ 1665 | 1666 | :param input_path: 1667 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1668 | labels (regression task) 1669 | """ 1670 | # NB: some examples have multiple species 1671 | input_df = pd.read_csv(input_path, sep=',') 1672 | smiles_list = input_df['smiles'] 1673 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1674 | labels = input_df['measured log solubility in mols per litre'] 1675 | assert len(smiles_list) == len(rdkit_mol_objs_list) 1676 | assert len(smiles_list) == len(labels) 1677 | return smiles_list, rdkit_mol_objs_list, labels.values 1678 | 1679 | 1680 | # input_path = 'storage/datasets/transfer_dataset/esol/raw/delaney-processed.csv' 1681 | # smiles_list, rdkit_mol_objs_list, labels = _load_esol_dataset(input_path) 1682 | 1683 | def _load_freesolv_dataset(input_path): 1684 | """ 1685 | 1686 | :param input_path: 1687 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1688 | labels (regression task) 1689 | """ 1690 | input_df = pd.read_csv(input_path, sep=',') 1691 | smiles_list = input_df['smiles'] 1692 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1693 | labels = input_df['expt'] 1694 | assert len(smiles_list) == len(rdkit_mol_objs_list) 1695 | assert len(smiles_list) == len(labels) 1696 | return smiles_list, rdkit_mol_objs_list, labels.values 1697 | 1698 | 1699 | def _load_lipophilicity_dataset(input_path): 1700 | """ 1701 | 1702 | :param input_path: 1703 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1704 | labels (regression task) 1705 | """ 1706 | input_df = pd.read_csv(input_path, sep=',') 1707 | smiles_list = input_df['smiles'] 1708 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1709 | labels = input_df['exp'] 1710 | assert len(smiles_list) == len(rdkit_mol_objs_list) 1711 | assert len(smiles_list) == len(labels) 1712 | return smiles_list, rdkit_mol_objs_list, labels.values 1713 | 1714 | 1715 | def _load_muv_dataset(input_path): 1716 | """ 1717 | 1718 | :param input_path: 1719 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1720 | labels 1721 | """ 1722 | input_df = pd.read_csv(input_path, sep=',') 1723 | smiles_list = input_df['smiles'] 1724 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1725 | tasks = ['MUV-466', 'MUV-548', 'MUV-600', 'MUV-644', 'MUV-652', 'MUV-689', 1726 | 'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810', 1727 | 'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'] 1728 | labels = input_df[tasks] 1729 | # convert 0 to -1 1730 | labels = labels.replace(0, -1) 1731 | # convert nan to 0 1732 | labels = labels.fillna(0) 1733 | assert len(smiles_list) == len(rdkit_mol_objs_list) 1734 | assert len(smiles_list) == len(labels) 1735 | return smiles_list, rdkit_mol_objs_list, labels.values 1736 | 1737 | 1738 | def _load_sider_dataset(input_path): 1739 | """ 1740 | 1741 | :param input_path: 1742 | :return: list of smiles, list of rdkit mol obj, np.array containing the 1743 | labels 1744 | """ 1745 | input_df = pd.read_csv(input_path, sep=',') 1746 | smiles_list = input_df['smiles'] 1747 | rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] 1748 | tasks = ['Hepatobiliary disorders', 1749 | 'Metabolism and nutrition disorders', 'Product issues', 'Eye disorders', 1750 | 'Investigations', 'Musculoskeletal and connective tissue disorders', 1751 | 'Gastrointestinal 
1738 | def _load_sider_dataset(input_path):
1739 |     """
1740 | 
1741 |     :param input_path:
1742 |     :return: list of smiles, list of rdkit mol obj, np.array containing the
1743 |         labels
1744 |     """
1745 |     input_df = pd.read_csv(input_path, sep=',')
1746 |     smiles_list = input_df['smiles']
1747 |     rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
1748 |     tasks = ['Hepatobiliary disorders',
1749 |              'Metabolism and nutrition disorders', 'Product issues', 'Eye disorders',
1750 |              'Investigations', 'Musculoskeletal and connective tissue disorders',
1751 |              'Gastrointestinal disorders', 'Social circumstances',
1752 |              'Immune system disorders', 'Reproductive system and breast disorders',
1753 |              'Neoplasms benign, malignant and unspecified (incl cysts and polyps)',
1754 |              'General disorders and administration site conditions',
1755 |              'Endocrine disorders', 'Surgical and medical procedures',
1756 |              'Vascular disorders', 'Blood and lymphatic system disorders',
1757 |              'Skin and subcutaneous tissue disorders',
1758 |              'Congenital, familial and genetic disorders',
1759 |              'Infections and infestations',
1760 |              'Respiratory, thoracic and mediastinal disorders',
1761 |              'Psychiatric disorders', 'Renal and urinary disorders',
1762 |              'Pregnancy, puerperium and perinatal conditions',
1763 |              'Ear and labyrinth disorders', 'Cardiac disorders',
1764 |              'Nervous system disorders',
1765 |              'Injury, poisoning and procedural complications']
1766 |     labels = input_df[tasks]
1767 |     # convert 0 to -1
1768 |     labels = labels.replace(0, -1)
1769 |     assert len(smiles_list) == len(rdkit_mol_objs_list)
1770 |     assert len(smiles_list) == len(labels)
1771 |     return smiles_list, rdkit_mol_objs_list, labels.values
1772 | 
1773 | 
1774 | def _load_toxcast_dataset(input_path):
1775 |     """
1776 | 
1777 |     :param input_path:
1778 |     :return: list of smiles, list of rdkit mol obj, np.array containing the
1779 |         labels
1780 |     """
1781 |     # NB: some examples have multiple species, some example smiles are invalid
1782 |     input_df = pd.read_csv(input_path, sep=',')
1783 |     smiles_list = input_df['smiles']
1784 |     rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
1785 |     # Some smiles could not be successfully converted to an rdkit mol object,
1786 |     # so they are kept as None
1787 |     preprocessed_rdkit_mol_objs_list = [m if m is not None else None for m in
1788 |                                         rdkit_mol_objs_list]
1789 |     preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m is not None else
1790 |                                 None for m in preprocessed_rdkit_mol_objs_list]
1791 |     tasks = list(input_df.columns)[1:]
1792 |     labels = input_df[tasks]
1793 |     # convert 0 to -1
1794 |     labels = labels.replace(0, -1)
1795 |     # convert nan to 0
1796 |     labels = labels.fillna(0)
1797 |     assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list)
1798 |     assert len(smiles_list) == len(preprocessed_smiles_list)
1799 |     assert len(smiles_list) == len(labels)
1800 |     return preprocessed_smiles_list, preprocessed_rdkit_mol_objs_list, \
1801 |         labels.values
1802 | 
1803 | 
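# --- Illustrative sketch (not wired into the pipeline) -------------------------
# ToxCast (like BBBP and ClinTox above) contains SMILES that RDKit cannot parse;
# those entries come back as None in both the mol list and the re-canonicalized
# SMILES list, and downstream processing typically has to skip them. A quick,
# hypothetical helper to see how many records survive parsing:
#
#   def _count_parsable(smiles_list, mols):
#       """Return (num_total, num_parsed) for a loader's output."""
#       return len(smiles_list), sum(m is not None for m in mols)
#
#   # smiles, mols, y = _load_toxcast_dataset(
#   #     'storage/datasets/transfer_dataset/toxcast/raw/toxcast_data.csv')
#   # print(_count_parsable(smiles, mols))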
1804 | def _load_chembl_with_labels_dataset(root_path):
1805 |     """
1806 |     Data from 'Large-scale comparison of machine learning methods for drug target prediction on ChEMBL'
1807 |     :param root_path: path to the folder containing the reduced chembl dataset
1808 |     :return: list of smiles, preprocessed rdkit mol obj list, list of np.array
1809 |         containing indices for each of the 3 folds, np.array containing the labels
1810 |     """
1811 |     # adapted from https://github.com/ml-jku/lsc/blob/master/pythonCode/lstm/loadData.py
1812 |     # first need to download the files and unzip:
1813 |     # wget http://bioinf.jku.at/research/lsc/chembl20/dataPythonReduced.zip
1814 |     # unzip and rename to chembl_with_labels
1815 |     # wget http://bioinf.jku.at/research/lsc/chembl20/dataPythonReduced/chembl20Smiles.pckl
1816 |     # into the dataPythonReduced directory
1817 |     # wget http://bioinf.jku.at/research/lsc/chembl20/dataPythonReduced/chembl20LSTM.pckl
1818 | 
1819 |     # 1. load folds and labels
1820 |     f = open(os.path.join(root_path, 'folds0.pckl'), 'rb')
1821 |     folds = pickle.load(f)
1822 |     f.close()
1823 | 
1824 |     f = open(os.path.join(root_path, 'labelsHard.pckl'), 'rb')
1825 |     targetMat = pickle.load(f)
1826 |     sampleAnnInd = pickle.load(f)
1827 |     targetAnnInd = pickle.load(f)
1828 |     f.close()
1829 | 
1830 |     targetMat = targetMat
1831 |     targetMat = targetMat.copy().tocsr()
1832 |     targetMat.sort_indices()
1833 |     targetAnnInd = targetAnnInd
1834 |     targetAnnInd = targetAnnInd - targetAnnInd.min()
1835 | 
1836 |     folds = [np.intersect1d(fold, sampleAnnInd.index.values).tolist() for fold in folds]
1837 |     targetMatTransposed = targetMat[sampleAnnInd[list(chain(*folds))]].T.tocsr()
1838 |     targetMatTransposed.sort_indices()
1839 |     # # num positive examples in each of the 1310 targets
1840 |     trainPosOverall = np.array([np.sum(targetMatTransposed[x].data > 0.5) for x in range(targetMatTransposed.shape[0])])
1841 |     # # num negative examples in each of the 1310 targets
1842 |     trainNegOverall = np.array(
1843 |         [np.sum(targetMatTransposed[x].data < -0.5) for x in range(targetMatTransposed.shape[0])])
1844 |     # dense array containing the labels for the 456331 molecules and 1310 targets
1845 |     denseOutputData = targetMat.A  # possible values are {-1, 0, 1}
1846 | 
1847 |     # 2. load structures
1848 |     f = open(os.path.join(root_path, 'chembl20LSTM.pckl'), 'rb')
1849 |     rdkitArr = pickle.load(f)
1850 |     f.close()
1851 | 
1852 |     assert len(rdkitArr) == denseOutputData.shape[0]
1853 |     assert len(rdkitArr) == len(folds[0]) + len(folds[1]) + len(folds[2])
1854 | 
1855 |     preprocessed_rdkitArr = []
1856 |     print('preprocessing')
1857 |     for i in range(len(rdkitArr)):
1858 |         print(i)
1859 |         m = rdkitArr[i]
1860 |         if m is None:
1861 |             preprocessed_rdkitArr.append(None)
1862 |         else:
1863 |             mol_species_list = split_rdkit_mol_obj(m)
1864 |             if len(mol_species_list) == 0:
1865 |                 preprocessed_rdkitArr.append(None)
1866 |             else:
1867 |                 largest_mol = get_largest_mol(mol_species_list)
1868 |                 if len(largest_mol.GetAtoms()) <= 2:
1869 |                     preprocessed_rdkitArr.append(None)
1870 |                 else:
1871 |                     preprocessed_rdkitArr.append(largest_mol)
1872 | 
1873 |     assert len(preprocessed_rdkitArr) == denseOutputData.shape[0]
1874 | 
1875 |     smiles_list = [AllChem.MolToSmiles(m) if m is not None else None for m in
1876 |                    preprocessed_rdkitArr]  # None for entries in rdkitArr that are
1877 |                                            # empty or could not be preprocessed
1878 | 
1879 |     assert len(preprocessed_rdkitArr) == len(smiles_list)
1880 | 
1881 |     return smiles_list, preprocessed_rdkitArr, folds, denseOutputData
1882 | 
1883 | 
1884 | # root_path = 'storage/datasets/transfer_dataset/chembl_with_labels'
1885 | 
1886 | def check_smiles_validity(smiles):
1887 |     try:
1888 |         m = Chem.MolFromSmiles(smiles)
1889 |         if m:
1890 |             return True
1891 |         else:
1892 |             return False
1893 |     except:
1894 |         return False
1895 | 
1896 | 
1897 | def split_rdkit_mol_obj(mol):
1898 |     """
1899 |     Split rdkit mol object containing multiple species or one species into a
1900 |     list of mol objects or a list containing a single object respectively
1901 |     :param mol:
1902 |     :return:
1903 |     """
1904 |     smiles = AllChem.MolToSmiles(mol, isomericSmiles=True)
1905 |     smiles_list = smiles.split('.')
1906 |     mol_species_list = []
1907 |     for s in smiles_list:
1908 |         if check_smiles_validity(s):
1909 |             mol_species_list.append(AllChem.MolFromSmiles(s))
1910 |     return mol_species_list
1911 | 
1912 | 
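# --- Illustrative sketch (not wired into the pipeline) -------------------------
# split_rdkit_mol_obj() relies on the '.' separator that SMILES uses for
# multi-component records (e.g. salts), and get_largest_mol() below keeps only
# the biggest fragment. A minimal example, assuming RDKit is available:
#
#   mol = AllChem.MolFromSmiles('CC(=O)[O-].[Na+]')   # sodium acetate, 2 species
#   fragments = split_rdkit_mol_obj(mol)              # -> 2 mol objects
#   main = get_largest_mol(fragments)                 # -> the acetate fragment
#   print(AllChem.MolToSmiles(main))                  # 'CC(=O)[O-]'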
1913 | def get_largest_mol(mol_list):
1914 |     """
1915 |     Given a list of rdkit mol objects, returns the mol object containing the
1916 |     largest num of atoms. If multiple mols contain the largest num of atoms,
1917 |     picks the first one
1918 |     :param mol_list:
1919 |     :return:
1920 |     """
1921 |     num_atoms_list = [len(m.GetAtoms()) for m in mol_list]
1922 |     largest_mol_idx = num_atoms_list.index(max(num_atoms_list))
1923 |     return mol_list[largest_mol_idx]
1924 | 
1925 | 
1926 | def create_all_datasets():
1927 |     #### create dataset
1928 |     downstream_dir = [
1929 |         'bace',
1930 |         'bbbp',
1931 |         'clintox',
1932 |         'esol',
1933 |         'freesolv',
1934 |         'hiv',
1935 |         'lipophilicity',
1936 |         'muv',
1937 |         'sider',
1938 |         'tox21',
1939 |         'toxcast'
1940 |     ]
1941 | 
1942 |     for dataset_name in downstream_dir:
1943 |         print(dataset_name)
1944 |         root = "storage/datasets/transfer_dataset/" + dataset_name
1945 |         os.makedirs(root + "/processed", exist_ok=True)
1946 |         dataset = MoleculeDataset(root, dataset=dataset_name)
1947 |         print(dataset)
1948 | 
1949 |     dataset = MoleculeDataset(root="storage/datasets/transfer_dataset/chembl_filtered", dataset="chembl_filtered")
1950 |     print(dataset)
1951 |     dataset = MoleculeDataset(root="storage/datasets/transfer_dataset/zinc_standard_agent", dataset="zinc_standard_agent")
1952 |     print(dataset)
1953 | 
1954 | 
1955 | # test MoleculeDataset object
1956 | if __name__ == "__main__":
1957 |     create_all_datasets()
1958 | 
--------------------------------------------------------------------------------