├── .gitignore
├── README.md
├── assets
│   └── equivset.png
├── data_loader
│   ├── __init__.py
│   ├── amazon.py
│   ├── bindingdb.py
│   ├── celeba.py
│   ├── gaussian.py
│   ├── moons.py
│   ├── pdbbind.py
│   └── set_pdbbind.py
├── main.py
├── model
│   ├── EquiVSet.py
│   ├── acnn.py
│   ├── base_model.py
│   ├── celebaCNN.py
│   ├── deepDTA.py
│   └── modules.py
└── utils
    ├── config.py
    ├── evaluation.py
    ├── logger.py
    └── pytorch_helper.py

/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | .idea/
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | 

:fire:EquiVSet:fire:

2 |

3 | License 4 | 5 | Colab 6 | Video 7 | Slides 8 | Poster 9 |

10 | 
11 | This repo contains the PyTorch implementation of the paper "[Learning Neural Set Functions Under the Optimal Subset Oracle](https://arxiv.org/abs/2203.01693)"
12 | 
13 | by [Zijing Ou](https://j-zin.github.io/), [Tingyang Xu](https://scholar.google.com.hk/citations?user=6gIs5YMAAAAJ&hl=en), [Qinliang Su](https://scholar.google.com/citations?user=cuIweygAAAAJ&hl=en), [Yingzhen Li](http://yingzhenli.net/home/en/), [Peilin Zhao](https://peilinzhao.github.io/), and [Yatao Bian](https://yataobian.com/).
14 | 
15 | > We propose a way to learn set functions when the optimal subsets are given by an optimal subset oracle. This setting differs from other works that learn set functions from a function value oracle, which provides utility values for each specific subset. Our setting is arguably more practically important, yet it is surprisingly overlooked by previous works. To learn set functions under the optimal subset oracle, we propose to cast the problem as maximum likelihood estimation by replacing the utility function with an energy-based model, such that it is proportional to the utility value and satisfies certain desiderata for set functions (e.g., permutation invariance). Mean-field variational inference and its amortized variants are then proposed to learn EBMs on sets. We evaluate our approach on a wide range of applications, including product recommendation, set anomaly detection, and compound selection in AI-aided drug discovery. The empirical results show that our approach is promising.
16 | 
17 | ![equivset](assets/equivset.png)
18 | 

Overview of the training and inference processes of EquiVSet.

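As a rough sketch of the formulation described in the abstract above (the notation is ours and hedged, not copied verbatim from the paper): the utility of a subset S of a ground set V is modeled by an energy-based set function, and training maximizes the likelihood of the subsets returned by the optimal subset oracle.

```
% Hedged sketch of the set-probability model described above.
% F_theta is the learned set function, V the ground set, S* the oracle's optimal subset.
p_\theta(S \mid V) = \frac{\exp\big(F_\theta(S; V)\big)}{\sum_{S' \subseteq V} \exp\big(F_\theta(S'; V)\big)},
\qquad
\max_\theta \; \mathbb{E}_{(V, S^*)}\big[\log p_\theta(S^* \mid V)\big]
```

Because the normalizer sums over all subsets of V, exact inference is intractable; the code instead runs mean-field updates over per-element Bernoulli probabilities `q` (see `EquiVSet.inference` in `model/EquiVSet.py`).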
19 | 
20 | ## Installation
21 | 
22 | Please ensure that:
23 | 
24 | - Python >= 3.6
25 | - PyTorch >= 1.8.0
26 | - dgl >= 0.7.0
27 | 
28 | The following packages are needed if you want to run the `compound selection` experiments:
29 | 
30 | - **rdkit**: We recommend installing it with `conda install -c rdkit rdkit==2018.09.3`. For other installation recipes, see the [official documentation](https://www.rdkit.org/docs/Install.html).
31 | - **dgllife**: We recommend installing it with `pip install dgllife`. More information is available in the [official documentation](https://lifesci.dgl.ai/install/index.html).
32 | - **tdc**: We recommend installing it with `pip install PyTDC`. See the [official documentation](https://tdc.readthedocs.io/en/main/install.html) for more information.
33 | 
34 | We provide step-by-step installation commands as follows:
35 | 
36 | ```
37 | conda create -n EquiVSet python=3.7
38 | source activate EquiVSet
39 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
40 | pip install dgl-cu110 dglgo -f https://data.dgl.ai/wheels/repo.html
41 | 
42 | # The following commands are used for compound selection:
43 | conda install -c rdkit rdkit==2018.09.3
44 | pip install dgllife
45 | pip install PyTDC
46 | ```
47 | 
48 | ## Datasets
49 | For the experiments, we use the following datasets:
50 | 
51 | - [Amazon baby registry dataset](https://www.kaggle.com/datasets/roopalik/amazon-baby-dataset) for the `product recommendation` experiments. The dataset is available [here](https://drive.google.com/file/d/1OLbCOTsRyowxw3_AzhxJPVB8VAgjt2Y6/view?usp=sharing).
52 | - [CelebA dataset](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) for the `set anomaly detection` experiments. The images are available [here](https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?usp=sharing&resourcekey=0-dYn9z10tMJOBAkviAcfdyQ) and the attribute labels are available [here](https://drive.google.com/file/d/0B7EVK8r0v71pblRyaVFSWGxPY0U/view?usp=sharing&resourcekey=0-YW2qIuRcWHy_1C2VaRGL3Q).
53 | - [PDBBind](http://www.pdbbind.org.cn/) and [BindingDB](https://www.bindingdb.org/bind/index.jsp) for the `compound selection` experiments. The PDBBind dataset is available [here](http://www.pdbbind.org.cn/index.php?newsid=20#news_section) and the BindingDB dataset is available [here](https://www.bindingdb.org/bind/index.jsp).
54 | 
55 | For all experiments, the dataset is automatically downloaded and preprocessed when you run the corresponding code. You can also download the datasets manually using the links provided above.
56 | 
57 | ## Experiments
58 | 
59 | This repository implements the experiments in Section 6 of the paper: synthetic experiments, product recommendation, set anomaly detection, and compound selection.
60 | 
61 | ### Synthetic Experiments
62 | 
63 | To run on the Two-Moons or Gaussian-Mixture dataset:
64 | ```
65 | python main.py equivset --train --cuda --data_name <dataset_name>
66 | ```
67 | `dataset_name` is chosen from ['moons', 'gaussian'].
68 | We also provide a Jupyter notebook [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_EI0BUjFzNAVxWS1ao-xia_UVmW4KLi4?usp=sharing) for running the synthetic experiments.
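For example, to train on the two-moons data (the remaining hyperparameters fall back to the dataset-specific defaults that `main.py` loads from `utils/config.py`):
```
python main.py equivset --train --cuda --data_name moons
```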
69 | 
70 | 
71 | ### Product Recommendation
72 | 
73 | To run on the Amazon baby registry dataset:
74 | ```
75 | python main.py equivset --train --cuda --data_name amazon --amazon_cat <category_name>
76 | ```
77 | `category_name` is chosen from ['toys', 'furniture', 'gear', 'carseats', 'bath', 'health', 'diaper', 'bedding', 'safety', 'feeding', 'apparel', 'media'].
78 | 
79 | ### Set Anomaly Detection
80 | 
81 | To run on the CelebA dataset:
82 | ```
83 | python main.py equivset --train --cuda --data_name celeba
84 | ```
85 | 
86 | ### Compound Selection
87 | 
88 | To run on the PDBBind or BindingDB dataset:
89 | ```
90 | python main.py equivset --train --cuda --data_name <dataset_name>
91 | ```
92 | `dataset_name` is chosen from ['pdbbind', 'bindingdb'].
93 | 
94 | ## Citation
95 | 
96 | :smile: If you find this repo useful, please consider citing our paper:
97 | ```
98 | @article{ou2022learning,
99 |   title={Learning Set Functions Under the Optimal Subset Oracle via Equivariant Variational Inference},
100 |   author={Ou, Zijing and Xu, Tingyang and Su, Qinliang and Li, Yingzhen and Zhao, Peilin and Bian, Yatao},
101 |   journal={arXiv preprint arXiv:2203.01693},
102 |   year={2022}
103 | }
104 | ```
105 | 
--------------------------------------------------------------------------------

/assets/equivset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SubsetSelection/EquiVSet/76713b8e61279639261fb8554d4223f86b84b6c0/assets/equivset.png
--------------------------------------------------------------------------------

/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | from .moons import TwoMoons
2 | from .gaussian import GaussianMixture
3 | from .amazon import Amazon
4 | from .celeba import CelebA
5 | from .set_pdbbind import SetPDBBind
6 | from .bindingdb import SetBindingDB
--------------------------------------------------------------------------------

/data_loader/amazon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import gdown
4 | import torch
5 | import zipfile
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset, DataLoader
8 | 
9 | from utils.pytorch_helper import read_from_pickle, find_not_in_set
10 | 
11 | class Data:
12 |     def __init__(self, params):
13 |         self.params = params
14 | 
15 |     def gen_datasets(self):
16 |         raise NotImplementedError
17 | 
18 |     def get_loaders(self, batch_size, num_workers, shuffle_train=False,
19 |                     get_test=True):
20 |         raise NotImplementedError
21 | 
22 | class Amazon(Data):
23 |     def __init__(self, params):
24 |         super().__init__(params)
25 |         data_root = self.download_amazon()
26 | 
27 |         torch.manual_seed(1)  # fix dataset
28 |         self.gen_datasets(data_root)
29 | 
30 |     def download_amazon(self):
31 |         url = 'https://drive.google.com/uc?id=1OLbCOTsRyowxw3_AzhxJPVB8VAgjt2Y6'
32 |         data_root = '/root/dataset/amazon'
33 |         download_path = f'{data_root}/amazon_baby_registry.zip'
34 |         if not os.path.exists(data_root):
35 |             os.makedirs(data_root)
36 |         if not os.listdir(data_root):
37 |             gdown.download(url, download_path, quiet=False)
38 |             with zipfile.ZipFile(download_path, 'r') as ziphandler:
39 |                 ziphandler.extractall(data_root)
40 |         return data_root
41 | 
42 |     def read_real_data(self, data_root):
43 |         pickle_filename = data_root + '/' + self.params.amazon_cat
44 |         dataset_ = read_from_pickle(pickle_filename)
45 |         for i in range(len(dataset_)):
46 |             dataset_[i+1] = torch.tensor(dataset_[i+1])
47 |         data_ 
= torch.zeros(len(dataset_), dataset_[1].shape[0]) 48 | for i in range(len(dataset_)): 49 | data_[i,:] = dataset_[i+1] 50 | 51 | csv_filename = data_root + '/' + self.params.amazon_cat + '.csv' 52 | with open(csv_filename, newline='') as csvfile: 53 | reader = csv.reader(csvfile, delimiter=',', quoting=csv.QUOTE_NONNUMERIC,quotechar='|') 54 | S = {} 55 | i=-1 56 | for row in reader: 57 | i=i+1 58 | S[i] = torch.tensor([int(row[x]) for x in range(len(row))]).long() 59 | return data_ , S 60 | 61 | def filter_S(self, data, S): 62 | S_list = [] 63 | V_list = [] 64 | for i in range(len(S)): 65 | if S[i].shape[0]>2 and S[i].shape[0] < self.params.v_size: 66 | Svar = S[i] - 1 # index starts froms 0 67 | sub_set, ground_set= self.construct_ground_set(data, Svar, V=self.params.v_size) 68 | S_list.append(sub_set); V_list.append(ground_set) 69 | S = S_list 70 | U = V_list 71 | return U, S 72 | 73 | def construct_ground_set(self, data, S, V): 74 | S_data = data[S] 75 | S_mean = S_data.mean(dim=0).unsqueeze(0) 76 | UnotS_data = find_not_in_set(data, S) 77 | S_mean_norm = F.normalize(S_mean, dim=-1) 78 | UnotS_data_norm = F.normalize(UnotS_data, dim=-1) 79 | 80 | cos_sim = (S_mean_norm @ UnotS_data_norm.T).squeeze(0) 81 | _, idx = torch.sort(cos_sim) 82 | UnotS_idx = idx[:V-S.shape[0]] 83 | UnotS_data = UnotS_data[UnotS_idx] 84 | 85 | S = torch.randperm(V)[:S.shape[0]] 86 | UnotS_idx = torch.ones(V,dtype=bool) 87 | UnotS_idx[S] = False 88 | 89 | U = torch.zeros([V, data.shape[-1]]) 90 | U[S] = S_data 91 | U[UnotS_idx] = UnotS_data 92 | 93 | return S, U 94 | 95 | def split_into_training_test(self, data_mat, S): 96 | folds = [0.33, 0.33, 0.33] 97 | num_elem = len(data_mat) 98 | tr_size = int(folds[0]* num_elem) 99 | dev_size = int((folds[1]+folds[0])* num_elem) 100 | test_size = num_elem 101 | 102 | V_train = data_mat[0:tr_size] 103 | V_dev = data_mat[tr_size:dev_size] 104 | V_test = data_mat[dev_size:test_size] 105 | 106 | S_train = S[0:tr_size] 107 | S_dev = S[tr_size:dev_size] 108 | S_test = S[dev_size:test_size] 109 | 110 | V_sets = (V_train, V_dev, V_test) 111 | S_sets = (S_train, S_dev, S_test) 112 | return V_sets, S_sets 113 | 114 | def gen_datasets(self, data_root): 115 | data, S = self.read_real_data(data_root) 116 | data, S = self.filter_S(data, S) 117 | V_sets, S_sets = self.split_into_training_test(data, S) 118 | 119 | self.V_train, self.V_val, self.V_test = V_sets 120 | self.S_train, self.S_val, self.S_test = S_sets 121 | 122 | self.fea_size = self.V_train[0].shape[-1] 123 | 124 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True): 125 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True) 126 | val_dataset = SetDataset(self.V_val, self.S_val, self.params) 127 | test_dataset = SetDataset(self.V_test, self.S_test, self.params) 128 | 129 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 130 | shuffle=shuffle_train, num_workers=num_workers) 131 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, 132 | shuffle=False, num_workers=num_workers) 133 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, 134 | shuffle=False, num_workers=num_workers) if get_test else None 135 | return train_loader, val_loader, test_loader 136 | 137 | class SetDataset(Dataset): 138 | def __init__(self, V, S, params, is_train=False): 139 | self.data = V 140 | self.labels = S 141 | self.is_train = is_train 142 | self.neg_num = params.neg_num 143 | self.v_size = params.v_size 144 | 145 | def 
__getitem__(self, index): 146 | V = self.data[index] 147 | S = self.labels[index] 148 | 149 | S_mask = torch.zeros([self.v_size]) 150 | S_mask[S] = 1 151 | if self.is_train: 152 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0] 153 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]] 154 | neg_S_mask = torch.zeros([self.v_size]) 155 | neg_S_mask[S] = 1 156 | neg_S_mask[neg_S] = 1 157 | return V, S_mask, neg_S_mask 158 | 159 | return V, S_mask 160 | 161 | def __len__(self): 162 | return len(self.data) 163 | -------------------------------------------------------------------------------- /data_loader/bindingdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | from tqdm import tqdm 6 | from rdkit import Chem 7 | from tdc.multi_pred import DTI 8 | from multiprocessing import Pool 9 | from rdkit.DataStructs import FingerprintSimilarity 10 | from sklearn.cluster import AffinityPropagation 11 | 12 | from torch.utils.data import Dataset, DataLoader 13 | 14 | acid_list = ['H', 'M', 'C', 'P', 'L', 'A', 'R', 'F', 'D', 'T', 'K', 'E', 'S', 'V', 'G', 'Y', 'N', 'W', 'I', 'Q'] 15 | CHARPROTLEN = len(acid_list) 16 | CHARPROTDIC = { acid_list[idx]: idx for idx in range(len(acid_list))} 17 | 18 | smile_list = ['o', 'N', '/', 'H', '#', 'C', 'i', '+', 'l', '@', '8', '-', '6', '3', '\\', '2', 'B', 'P', '.', 19 | 'e', '9', '7', 'a', 's', 'O', ')', '0', 'n', '1', '4', 'I', 'F', ']', 'S', '5', '(', '[', '=', '%', 'c', 'r'] 20 | CHARCANSMILEN = len(smile_list) 21 | CHARCANSMIDIC = { smile_list[idx]: idx for idx in range(len(smile_list))} 22 | 23 | class Tokenizer(): 24 | @staticmethod 25 | def seq_tokenizer(seq, type_): 26 | if type_ == 'drug': 27 | max_length = 100 28 | mask = torch.zeros([max_length, CHARCANSMILEN]) 29 | seq = np.array([CHARCANSMIDIC[item] for item in seq.split(" ")[:max_length]]) 30 | elif type_ == 'protein': 31 | max_length = 1000 32 | mask = torch.zeros([max_length, CHARPROTLEN]) 33 | seq = np.array([CHARPROTDIC[item] for item in seq.split(" ")[:max_length]]) 34 | 35 | length = seq.shape[0] 36 | mask[range(length), seq] = 1 37 | return mask.transpose_(0, 1).unsqueeze(0) 38 | 39 | @staticmethod 40 | def tokenizer(bt, type_): 41 | bt = [Tokenizer.seq_tokenizer(seq, type_) for seq in bt] 42 | return bt 43 | 44 | class SetBindingDB(object): 45 | def __init__(self, params): 46 | super().__init__() 47 | self.params = params 48 | self.gen_datasets() 49 | 50 | def gen_datasets(self): 51 | np.random.seed(1) # fix dataset 52 | V_size, S_size = self.params.v_size, self.params.s_size 53 | self.dataset = load_bindingdb(self.params) 54 | 55 | data_root = '/root/dataset/bindingdb' 56 | data_path = os.path.join(data_root, 'bindingdb_set_data.pkl') 57 | if os.path.exists(data_path): 58 | print(f'load data from {data_path}') 59 | trainData, valData, testData = pickle.load(open(data_path, "rb")) 60 | self.V_train, self.S_train = trainData['V_train'], trainData['S_train'] 61 | self.V_val, self.S_val = valData['V_train'], valData['S_train'] 62 | self.V_test, self.S_test = testData['V_train'], testData['S_train'] 63 | else: 64 | self.V_train, self.S_train = get_set_bindingdb_dataset_activate(self.dataset, V_size, S_size, self.params, size=1000) 65 | self.V_val, self.S_val = get_set_bindingdb_dataset_activate(self.dataset, V_size, S_size, self.params, size=100) 66 | self.V_test, self.S_test = get_set_bindingdb_dataset_activate(self.dataset, V_size, S_size, self.params, size=100) 67 | 68 
| trainData = {'V_train': self.V_train, 'S_train': self.S_train} 69 | valData = {'V_train': self.V_val, 'S_train': self.S_val} 70 | testData = {'V_train': self.V_test, 'S_train': self.S_test} 71 | if not os.path.exists(data_root): 72 | os.makedirs(data_root) 73 | pickle.dump((trainData, valData, testData), open(data_path, "wb")) 74 | 75 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True): 76 | train_dataset = SetDataset(self.dataset, self.V_train, self.S_train, self.params, is_train=True) 77 | val_dataset = SetDataset(self.dataset, self.V_val, self.S_val, self.params) 78 | test_dataset = SetDataset(self.dataset, self.V_test, self.S_test, self.params) 79 | 80 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 81 | collate_fn=collate_train, pin_memory=True, shuffle=shuffle_train, num_workers=num_workers) 82 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, 83 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers) 84 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, 85 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers) if get_test else None 86 | 87 | return train_loader, val_loader, test_loader 88 | 89 | def collate_train(data): 90 | V_drug, V_target, S, neg_S = map(list, zip(*data)) 91 | bs, vs = len(V_drug), V_drug[0].shape[0] 92 | 93 | b_D = [ gs[idx] for gs in V_drug for idx in range(vs)] 94 | b_P = [ gt[idx] for gt in V_target for idx in range(vs)] 95 | b_D = Tokenizer.tokenizer(b_D, 'drug') 96 | b_P = Tokenizer.tokenizer(b_P, 'protein') 97 | 98 | b_D = torch.cat(b_D, dim=0) 99 | b_P = torch.cat(b_P, dim=0) 100 | S = torch.cat(S, dim=0).reshape(bs, -1) 101 | neg_S = torch.cat(neg_S, dim=0).reshape(bs, -1) 102 | return (b_D, b_P), S, neg_S 103 | 104 | def collate_val_and_test(data): 105 | V_drug, V_target, S = map(list, zip(*data)) 106 | bs, vs = len(V_drug), V_drug[0].shape[0] 107 | 108 | b_D = [ gs[idx] for gs in V_drug for idx in range(vs)] 109 | b_P = [ gt[idx] for gt in V_target for idx in range(vs)] 110 | b_D = Tokenizer.tokenizer(b_D, 'drug') 111 | b_P = Tokenizer.tokenizer(b_P, 'protein') 112 | 113 | b_D = torch.cat(b_D, dim=0) 114 | b_P = torch.cat(b_P, dim=0) 115 | S = torch.cat(S, dim=0).reshape(bs, -1) 116 | return (b_D, b_P), S 117 | 118 | class SetDataset(Dataset): 119 | def __init__(self, dataset, V_idxs, S_idxs, params, is_train=False): 120 | self.drugs, self.targets = dataset['Drug'], dataset['Target'] 121 | self.V_idxs, self.S_idxs = V_idxs, S_idxs 122 | self.is_train = is_train 123 | self.neg_num = params.neg_num 124 | self.v_size = params.v_size 125 | 126 | def __getitem__(self, index): 127 | V_idxs, S = np.array(self.V_idxs[index]), np.array(self.S_idxs[index]) 128 | V_drug = np.array([" ".join(item) for item in self.drugs[V_idxs].tolist()]) 129 | V_target = np.array([" ".join(item) for item in self.targets[V_idxs].tolist()]) 130 | 131 | S_mask = torch.zeros([self.v_size]) 132 | S_mask[S] = 1 133 | if self.is_train: 134 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0] 135 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]] 136 | neg_S_mask = torch.zeros([self.v_size]) 137 | neg_S_mask[S] = 1 138 | neg_S_mask[neg_S] = 1 139 | return V_drug, V_target, S_mask, neg_S_mask 140 | 141 | return V_drug, V_target, S_mask 142 | 143 | def __len__(self): 144 | return len(self.V_idxs) 145 | 146 | def load_bindingdb(params): 147 | data = DTI(name = 'BindingDB_Kd', 
path='/root/dataset/bindingdb') 148 | data.harmonize_affinities(mode = 'mean') 149 | return data.get_data() 150 | 151 | def load_dt_pair(dataset, setdata): 152 | drugs, targets = dataset['Drug'], dataset['Target'] 153 | V, S = setdata['V_train'], setdata['S_train'] 154 | 155 | V_drugs_list, V_targets_list = [], [] 156 | for V_idxs in V: 157 | V_drugs = [" ".join(item) for item in drugs[V_idxs].tolist()] 158 | V_drugs_list.append(V_drugs) 159 | V_targets = [" ".join(item) for item in targets[V_idxs].tolist()] 160 | V_targets_list.append(V_targets) 161 | 162 | V_drug = np.array(V_drugs_list) 163 | V_target = np.array(V_targets_list) 164 | S = torch.Tensor(S).type(torch.long) 165 | return (V_drug, V_target), S 166 | 167 | def get_set_bindingdb_dataset_activate(dataset, V_size, S_size, params, size=1000): 168 | """ 169 | Generate dataset for compound selection with only the bioactivity filter 170 | """ 171 | _, _, labels = dataset['Drug'], dataset['Target'], dataset['Y'] 172 | data_size = len(labels) 173 | 174 | V_list, S_list = [], [] 175 | for _ in tqdm(range(size)): 176 | V_idxs = np.random.permutation(data_size)[:V_size] 177 | sub_labels = torch.from_numpy(labels[V_idxs].to_numpy()) 178 | _, idxs = torch.topk(sub_labels, S_size) 179 | 180 | V_list.append(V_idxs) 181 | S_list.append(idxs) 182 | return np.array(V_list), np.array(S_list) 183 | 184 | def get_set_bindingdb_dataset(dataset, V_size, S_size, params, size=1000): 185 | """ 186 | Generate dataset for compound selection with bioactivity and diversity filters 187 | """ 188 | drugs, targets, labels = dataset['Drug'], dataset['Target'], dataset['Y'] 189 | data_size = len(labels) 190 | 191 | V_list, S_list = [], [] 192 | pbar = tqdm(total=size) 193 | num = 0 194 | while True: 195 | if num == size: break 196 | 197 | V_idxs = np.random.permutation(data_size)[:V_size] 198 | sub_labels = torch.from_numpy(labels[V_idxs].to_numpy()) 199 | _, idxs = torch.topk(sub_labels, V_size // 3) 200 | filter_idxs = V_idxs[idxs] 201 | S_idxs = get_os_oracle(drugs, filter_idxs) 202 | 203 | if len(S_idxs) == 0: continue 204 | 205 | V_list.append(V_idxs) 206 | S_list.append([np.where(V_idxs == item)[0][0] for item in S_idxs]) 207 | 208 | num += 1 209 | pbar.update(1) 210 | return np.array(V_list), np.array(S_list) 211 | 212 | def get_os_oracle(drugs, filter_idxs): 213 | smile_list = drugs[filter_idxs].to_numpy() 214 | n = len(smile_list) 215 | 216 | sm = np.zeros((n, n)) 217 | max_cpu = os.cpu_count() 218 | ij_list = [(i, j, smile_list) for i in range(n) for j in range(i+1, n)] 219 | with Pool(max_cpu) as p: 220 | similarity = p.starmap(cal_fingerprint_similarity, ij_list) 221 | i, j, _ = zip(*ij_list) 222 | sm[i,j] = similarity 223 | sm = sm + sm.T + np.eye(n) 224 | 225 | af = AffinityPropagation().fit(sm) 226 | cluster_centers_indices = af.cluster_centers_indices_ 227 | return filter_idxs[cluster_centers_indices] 228 | 229 | def cal_fingerprint_similarity(i, j, smiles): 230 | m1, m2 = Chem.MolFromSmiles(smiles[i]), Chem.MolFromSmiles(smiles[j]) 231 | return FingerprintSimilarity(Chem.RDKFingerprint(m1), Chem.RDKFingerprint(m2)) 232 | -------------------------------------------------------------------------------- /data_loader/celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gdown 4 | import pickle 5 | import zipfile 6 | import numpy as np 7 | import pandas as pd 8 | from PIL import Image 9 | from tqdm import tqdm 10 | from torchvision import transforms 11 | from 
torch.utils.data import Dataset, DataLoader 12 | 13 | image_size = 64 14 | img_transform = transforms.Compose([ 15 | transforms.Resize(image_size), 16 | transforms.CenterCrop(image_size), 17 | transforms.ToTensor(), 18 | transforms.Normalize(mean=[0.5, 0.5, 0.5], 19 | std=[0.5, 0.5, 0.5]) 20 | ]) 21 | img_root_path = '/root/dataset/celeba/img_align_celeba/' 22 | 23 | class Data: 24 | def __init__(self, params): 25 | self.params = params 26 | 27 | def gen_datasets(self): 28 | raise NotImplementedError 29 | 30 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, 31 | get_test=True): 32 | raise NotImplementedError 33 | 34 | class CelebA(Data): 35 | def __init__(self, params): 36 | super().__init__(params) 37 | data_root = self.download_celeba() 38 | 39 | torch.manual_seed(1) # fix dataset 40 | np.random.seed(1) 41 | self.gen_datasets(data_root) 42 | 43 | def download_celeba(self): 44 | url_img = 'https://drive.google.com/uc?id=1iBJh4vHuE9h-eMOqVis94QccxCT_LPFW' 45 | url_anno = 'https://drive.google.com/uc?id=1p0-TEiW4HgT8MblB399ep4YM3u5A0Edc' 46 | data_root = '/root/dataset/celeba' 47 | download_path_img = f'{data_root}/img_align_celeba.zip' 48 | download_path_anno = f'{data_root}/list_attr_celeba.txt' 49 | if not os.path.exists(data_root): 50 | os.makedirs(data_root) 51 | 52 | if not os.listdir(data_root): 53 | gdown.download(url_anno, download_path_anno, quiet=False) 54 | gdown.download(url_img, download_path_img, quiet=False) 55 | with zipfile.ZipFile(download_path_img, 'r') as ziphandler: 56 | ziphandler.extractall(data_root) 57 | return data_root 58 | 59 | def load_data(self, data_root): 60 | data_path = data_root + '/list_attr_celeba.txt' 61 | df = pd.read_csv(data_path, sep="\s+", skiprows=1) 62 | label_names = list(df.columns)[:-1] 63 | df = df.to_numpy()[:, :-1] 64 | df = np.maximum(df, 0) # -1 -> 0 65 | return df, label_names 66 | 67 | def gen_datasets(self, data_root): 68 | data_path = data_root + '/celebA_set_data.pkl' 69 | if os.path.exists(data_path): 70 | print(f'load data from {data_path}') 71 | label_names, trainData, valData, testData = pickle.load(open(data_path, "rb")) 72 | self.V_train, self.S_train, self.labels_train = trainData['V_train'], trainData['S_train'], trainData['labels_train'] 73 | self.V_val, self.S_val, self.labels_val = valData['V_train'], valData['S_train'], valData['labels_train'] 74 | self.V_test, self.S_test, self.labels_test = testData['V_train'], testData['S_train'], testData['labels_train'] 75 | else: 76 | data, label_names = self.load_data(data_root) 77 | self.V_train, self.S_train, self.labels_train = get_set_celeba_dataset(data, data_size=10000, v_size=self.params.v_size) 78 | self.V_val, self.S_val, self.labels_val = get_set_celeba_dataset(data, data_size=1000, v_size=self.params.v_size) 79 | self.V_test, self.S_test, self.labels_test = get_set_celeba_dataset(data, data_size=1000, v_size=self.params.v_size) 80 | 81 | trainData = {'V_train': self.V_train, 'S_train': self.S_train, 'labels_train': self.labels_train} 82 | valData = {'V_train': self.V_val, 'S_train': self.S_val, 'labels_train': self.labels_val} 83 | testData = {'V_train': self.V_test, 'S_train': self.S_test, 'labels_train': self.labels_test} 84 | pickle.dump((label_names, trainData, valData, testData), open(data_path, "wb")) 85 | 86 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True): 87 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True) 88 | val_dataset = SetDataset(self.V_val, self.S_val, 
self.params) 89 | test_dataset = SetDataset(self.V_test, self.S_test, self.params) 90 | 91 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 92 | collate_fn=collate_train, shuffle=shuffle_train, num_workers=num_workers) 93 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, 94 | collate_fn=collate_val_and_test, shuffle=False, num_workers=num_workers) 95 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, 96 | collate_fn=collate_val_and_test, shuffle=False, num_workers=num_workers) if get_test else None 97 | return train_loader, val_loader, test_loader 98 | 99 | class SetDataset(Dataset): 100 | def __init__(self, U, S, params, is_train=False): 101 | self.data = U 102 | self.labels = S 103 | self.is_train = is_train 104 | self.neg_num = params.neg_num 105 | self.v_size = params.v_size 106 | 107 | def __getitem__(self, index): 108 | V_id = self.data[index] 109 | S = self.labels[index] 110 | V = torch.cat([ load_img(idx.item()) for idx in V_id ], dim=0) 111 | 112 | S_mask = torch.zeros([self.v_size]) 113 | S_mask[S] = 1 114 | if self.is_train: 115 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0] 116 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]] 117 | neg_S_mask = torch.zeros([self.v_size]) 118 | neg_S_mask[S] = 1 119 | neg_S_mask[neg_S] = 1 120 | return V, S_mask, neg_S_mask 121 | 122 | return V, S_mask 123 | 124 | def __len__(self): 125 | return len(self.data) 126 | 127 | def get_set_celeba_dataset(data, data_size, v_size): 128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 129 | data = torch.Tensor(data).to(device) 130 | img_nums = data.shape[0] 131 | 132 | V_list = [] 133 | S_list = [] 134 | label_list = [] 135 | cur_size = 0 136 | pbar = tqdm(total=data_size) 137 | while True: 138 | if cur_size == data_size: break 139 | 140 | nor_id = np.random.randint(img_nums) 141 | nor_label = data[nor_id] 142 | if torch.sum(nor_label).item() < 2: 143 | continue 144 | nor_lable_idxs = torch.nonzero(nor_label).reshape(-1) 145 | perm = torch.randperm(nor_lable_idxs.size(0)) 146 | nor_lable_idxs = nor_lable_idxs[perm[:2]] 147 | nor_label = torch.zeros(nor_label.shape).to(device) 148 | nor_label[nor_lable_idxs] = 1 149 | nor_label = nor_label.reshape(-1, 1) 150 | 151 | s_size = np.random.randint(2, 4) 152 | nor_res = (torch.nonzero((data @ nor_label).squeeze(-1) == 2)).reshape(-1) 153 | ano_res = (torch.nonzero((data @ nor_label).squeeze(-1) == 0)).reshape(-1) 154 | if (nor_res.shape[0] < v_size) or (ano_res.shape[0] < s_size): 155 | continue 156 | 157 | perm = torch.randperm(nor_res.size(0)) 158 | U = nor_res[perm[:v_size]].cpu() 159 | perm = torch.randperm(ano_res.size(0)) 160 | S = ano_res[perm[:s_size]].cpu() 161 | 162 | S_idx = np.random.choice(list(range(v_size)), s_size, replace=False) 163 | U[S_idx] = S 164 | S = torch.Tensor(S_idx).type(torch.int64) 165 | lable_idxs = nor_lable_idxs.cpu() 166 | 167 | V_list.append(U) 168 | S_list.append(S) 169 | label_list.append(lable_idxs) 170 | 171 | cur_size += 1 172 | pbar.update(1) 173 | pbar.close() 174 | return V_list, S_list, label_list 175 | 176 | def collate_train(data): 177 | V, S, neg_S = map(list, zip(*data)) 178 | bs = len(V) 179 | 180 | V = torch.cat(V, dim=0) 181 | S = torch.cat(S, dim=0).reshape(bs, -1) 182 | neg_S = torch.cat(neg_S, dim=0).reshape(bs, -1) 183 | return V, S, neg_S 184 | 185 | def collate_val_and_test(data): 186 | V, S = map(list, zip(*data)) 187 | bs = len(V) 188 | 189 | V = torch.cat(V, dim=0) 190 | S = torch.cat(S, 
dim=0).reshape(bs, -1) 191 | return V, S 192 | 193 | def load_img(img_id): 194 | img_id = str(img_id + 1) # 0 -> 1 195 | img_path = img_root_path + ( '0' * (6 - len(img_id)) ) + img_id + '.jpg' 196 | img = Image.open(img_path).convert('RGB') 197 | img = img_transform(img).unsqueeze(0) 198 | return img 199 | -------------------------------------------------------------------------------- /data_loader/gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn import datasets 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | class Data: 7 | def __init__(self, params): 8 | self.params = params 9 | self.gen_datasets() 10 | 11 | def gen_datasets(self): 12 | raise NotImplementedError 13 | 14 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, 15 | get_test=True): 16 | raise NotImplementedError 17 | 18 | 19 | class GaussianMixture(Data): 20 | def __init__(self, params): 21 | super().__init__(params) 22 | 23 | def gen_datasets(self): 24 | np.random.seed(1) 25 | V_size, S_size = self.params.v_size, self.params.s_size 26 | 27 | self.V_train, self.S_train = get_gaussian_mixture_dataset(V_size, S_size) 28 | self.V_val, self.S_val = get_gaussian_mixture_dataset(V_size, S_size) 29 | self.V_test, self.S_test = get_gaussian_mixture_dataset(V_size, S_size) 30 | 31 | self.fea_size = 2 32 | self.x_lim, self.y_lim = 2, 2 33 | 34 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True): 35 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True) 36 | val_dataset = SetDataset(self.V_val, self.S_val, self.params) 37 | test_dataset = SetDataset(self.V_test, self.S_test, self.params) 38 | 39 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 40 | shuffle=shuffle_train, num_workers=num_workers) 41 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, 42 | shuffle=False, num_workers=num_workers) 43 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, 44 | shuffle=False, num_workers=num_workers) if get_test else None 45 | return train_loader, val_loader, test_loader 46 | 47 | class SetDataset(Dataset): 48 | def __init__(self, V, S, params, is_train=False): 49 | self.data = V 50 | self.labels = S 51 | self.is_train = is_train 52 | self.neg_num = params.neg_num 53 | self.v_size = params.v_size 54 | 55 | def __getitem__(self, index): 56 | V = self.data[index] 57 | S = self.labels[index] 58 | 59 | S_mask = torch.zeros([self.v_size]) 60 | S_mask[S] = 1 61 | if self.is_train: 62 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0] 63 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]] 64 | neg_S_mask = torch.zeros([self.v_size]) 65 | neg_S_mask[S] = 1 66 | neg_S_mask[neg_S] = 1 67 | return V, S_mask, neg_S_mask 68 | 69 | return V, S_mask 70 | 71 | def __len__(self): 72 | return len(self.data) 73 | 74 | def gen_gaussian_mixture(batch_size): 75 | scale = 1 76 | centers = [ (1. / np.sqrt(2), 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. 
/ np.sqrt(2))] 77 | centers = [(scale * x, scale * y) for x, y in centers] 78 | 79 | data, labels = [], [] 80 | for i in range(batch_size): 81 | point = np.random.randn(2) * 0.5 82 | idx = np.random.randint(2) 83 | center = centers[idx] 84 | point[0] += center[0] 85 | point[1] += center[1] 86 | data.append(point) 87 | labels.append(idx) 88 | data = np.array(data, dtype="float32") 89 | labels = np.array(labels, dtype="int32") 90 | 91 | noise_label = np.random.randint(2) 92 | noise = data[labels == noise_label] 93 | data = data[labels == (1 - noise_label)] 94 | # dataset /= 1.414 95 | return data, noise 96 | 97 | def get_gaussian_mixture_dataset(V_size, S_size): 98 | V_list, S_list = [], [] 99 | for _ in range(1000): 100 | data, noise = gen_gaussian_mixture(V_size*4) 101 | 102 | V = data[:V_size] 103 | S = np.random.choice(list(range(0,V_size)), S_size, replace=False) 104 | V[S, :] = noise[:S_size] 105 | 106 | V_list.append(V) 107 | S_list.append(S) 108 | 109 | V = torch.FloatTensor(V_list) 110 | S = torch.LongTensor(S_list) 111 | 112 | return V, S 113 | -------------------------------------------------------------------------------- /data_loader/moons.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn import datasets 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | class Data: 7 | def __init__(self, params): 8 | self.params = params 9 | self.gen_datasets() 10 | 11 | def gen_datasets(self): 12 | raise NotImplementedError 13 | 14 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, 15 | get_test=True): 16 | raise NotImplementedError 17 | 18 | class TwoMoons(Data): 19 | def __init__(self, params): 20 | super().__init__(params) 21 | 22 | def gen_datasets(self): 23 | np.random.seed(1) # fix dataset 24 | V_size, S_size = self.params.v_size, self.params.s_size 25 | 26 | self.V_train, self.S_train = get_two_moons_dataset(V_size, S_size, rand_seed=0) 27 | self.V_val, self.S_val = get_two_moons_dataset(V_size, S_size, rand_seed=1) 28 | self.V_test, self.S_test = get_two_moons_dataset(V_size, S_size, rand_seed=2) 29 | 30 | self.fea_size = 2 31 | self.x_lim, self.y_lim = 4, 2 32 | 33 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True): 34 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True) 35 | val_dataset = SetDataset(self.V_val, self.S_val, self.params) 36 | test_dataset = SetDataset(self.V_test, self.S_test, self.params) 37 | 38 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 39 | shuffle=shuffle_train, num_workers=num_workers) 40 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, 41 | shuffle=False, num_workers=num_workers) 42 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, 43 | shuffle=False, num_workers=num_workers) if get_test else None 44 | return train_loader, val_loader, test_loader 45 | 46 | class SetDataset(Dataset): 47 | def __init__(self, V, S, params, is_train=False): 48 | self.data = V 49 | self.labels = S 50 | self.is_train = is_train 51 | self.neg_num = params.neg_num 52 | self.v_size = params.v_size 53 | 54 | def __getitem__(self, index): 55 | V = self.data[index] 56 | S = self.labels[index] 57 | 58 | S_mask = torch.zeros([self.v_size]) 59 | S_mask[S] = 1 60 | if self.is_train: 61 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0] 62 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]] 63 | neg_S_mask = 
torch.zeros([self.v_size]) 64 | neg_S_mask[S] = 1 65 | neg_S_mask[neg_S] = 1 66 | return V, S_mask, neg_S_mask 67 | 68 | return V, S_mask 69 | 70 | def __len__(self): 71 | return len(self.data) 72 | 73 | def gen_moons(batch_size, rand_seed): 74 | data, Y = datasets.make_moons(n_samples=batch_size, noise=0.1, random_state=rand_seed) 75 | data = data.astype("float32") 76 | data = data * 2 + np.array([-1, -0.2]) 77 | 78 | noise_label = np.random.randint(2) 79 | noise = data[Y == noise_label] 80 | data = data[Y == (1 - noise_label)] 81 | return data, noise 82 | 83 | def get_two_moons_dataset(V_size, S_size, rand_seed): 84 | V_list, S_list = [], [] 85 | for idx in range(1000): 86 | data, noise = gen_moons(V_size*2, rand_seed*1000+idx) 87 | 88 | V = data[:V_size] 89 | S = np.random.choice(list(range(0,V_size)), S_size, replace=False) 90 | V[S, :] = noise[:S_size] 91 | 92 | V_list.append(V) 93 | S_list.append(S) 94 | 95 | V = torch.FloatTensor(np.array(V_list)) 96 | S = torch.LongTensor(np.array(S_list)) 97 | 98 | return V, S 99 | -------------------------------------------------------------------------------- /data_loader/pdbbind.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # 6 | # PDBBind dataset processed by moleculenet. 7 | 8 | import dgl.backend as F 9 | import numpy as np 10 | import multiprocessing 11 | import os 12 | import glob 13 | from functools import partial 14 | import pandas as pd 15 | 16 | from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive 17 | 18 | from dgllife.utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization 19 | 20 | __all__ = ['PDBBind'] 21 | 22 | class PDBBind(object): 23 | """PDBbind dataset processed by moleculenet. 24 | 25 | The description below is mainly based on 26 | `[1] `__. 27 | The PDBBind database consists of experimentally measured binding affinities for 28 | bio-molecular complexes `[2] `__, 29 | `[3] `__. It provides detailed 30 | 3D Cartesian coordinates of both ligands and their target proteins derived from experimental 31 | (e.g., X-ray crystallography) measurements. The availability of coordinates of the 32 | protein-ligand complexes permits structure-based featurization that is aware of the 33 | protein-ligand binding geometry. The authors of 34 | `[1] `__ use the 35 | "refined" and "core" subsets of the database 36 | `[4] `__, more carefully 37 | processed for data artifacts, as additional benchmarking targets. 38 | 39 | References: 40 | 41 | * [1] moleculenet: a benchmark for molecular machine learning 42 | * [2] The PDBbind database: collection of binding affinities for protein-ligand complexes 43 | with known three-dimensional structures 44 | * [3] The PDBbind database: methodologies and updates 45 | * [4] PDB-wide collection of binding data: current status of the PDBbind database 46 | 47 | Parameters 48 | ---------- 49 | subset : str 50 | In moleculenet, we can use either the "refined" subset or the "core" subset. We can 51 | retrieve them by setting ``subset`` to be ``'refined'`` or ``'core'``. The size 52 | of the ``'core'`` set is 195 and the size of the ``'refined'`` set is 3706. 53 | pdb_version : str 54 | The version of PDBBind dataset. Currently implemented: ``'v2007'``, ``'v2015'``. 55 | Default to ``'v2015'``. User should not specify the version if using local PDBBind data. 
56 | load_binding_pocket : bool 57 | Whether to load binding pockets or full proteins. Default to True. 58 | remove_coreset_from_refinedset: bool 59 | Whether to remove core set from refined set when training with refined set and test with core set. 60 | Default to True. 61 | sanitize : bool 62 | Whether sanitization is performed in initializing RDKit molecule instances. See 63 | https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization. 64 | Default to False. 65 | calc_charges : bool 66 | Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce 67 | ``sanitize`` to be True. Default to False. 68 | remove_hs : bool 69 | Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite 70 | slow for large molecules. Default to False. 71 | use_conformation : bool 72 | Whether we need to extract molecular conformation from proteins and ligands. 73 | Default to True. 74 | construct_graph_and_featurize : callable 75 | Construct a DGLGraph for the use of GNNs. Mapping ``self.ligand_mols[i]``, 76 | ``self.protein_mols[i]``, ``self.ligand_coordinates[i]`` and 77 | ``self.protein_coordinates[i]`` to a DGLGraph. 78 | Default to :func:`dgllife.utils.ACNN_graph_construction_and_featurization`. 79 | zero_padding : bool 80 | Whether to perform zero padding. While DGL does not necessarily require zero padding, 81 | pooling operations for variable length inputs can introduce stochastic behaviour, which 82 | is not desired for sensitive scenarios. Default to True. 83 | num_processes : int or None 84 | Number of worker processes to use. If None, 85 | then we will use the number of CPUs in the system. Default None. 86 | local_path : str or None 87 | Local path of existing PDBBind dataset. 88 | Default None, and PDBBind dataset will be downloaded from DGL database. 89 | Specify this argument to a local path of customized dataset, which should follow the structure and the naming format of PDBBind v2015. 
90 | """ 91 | def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False, 92 | calc_charges=False, remove_hs=False, use_conformation=True, 93 | construct_graph_and_featurize=ACNN_graph_construction_and_featurization, 94 | zero_padding=True, num_processes=None, local_path=None): 95 | self.task_names = ['-logKd/Ki'] 96 | self.n_tasks = len(self.task_names) 97 | self._read_data_files(pdb_version, subset, load_binding_pocket, remove_coreset_from_refinedset, local_path) 98 | self._preprocess(load_binding_pocket, 99 | sanitize, calc_charges, remove_hs, use_conformation, 100 | construct_graph_and_featurize, zero_padding, num_processes) 101 | # Prepare for Refined, Agglomerative Sequence Split and Agglomerative Structure Split 102 | if pdb_version == 'v2007' and not local_path: 103 | merged_df = self.df.merge(self.agg_split, on='PDB_code') 104 | self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index) 105 | for target_set in ['train', 'valid', 'test']] 106 | self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index) 107 | for target_set in ['train', 'valid', 'test']] 108 | 109 | def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_coreset_from_refinedset, local_path): 110 | """Download and extract pdbbind data files specified by the version""" 111 | # root_dir_path = get_download_dir() 112 | root_dir_path = '/root/dataset/.dgl' 113 | if local_path: 114 | if local_path[-1] != '/': 115 | local_path += '/' 116 | index_label_file = glob.glob(local_path + '*' + subset + '*data*')[0] 117 | elif pdb_version == 'v2015': 118 | self._url = 'dataset/pdbbind_v2015.tar.gz' 119 | data_path = root_dir_path + '/pdbbind_v2015.tar.gz' 120 | extracted_data_path = root_dir_path + '/pdbbind_v2015' 121 | download(_get_dgl_url(self._url), path=data_path, overwrite=False) 122 | extract_archive(data_path, extracted_data_path) 123 | 124 | if subset == 'core': 125 | index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013' 126 | elif subset == 'refined': 127 | index_label_file = extracted_data_path + '/v2015/INDEX_refined_data.2015' 128 | else: 129 | raise ValueError('Expect the subset_choice to be either core or refined, got {}'.format(subset)) 130 | elif pdb_version == 'v2007': 131 | self._url = 'dataset/pdbbind_v2007.tar.gz' 132 | data_path = root_dir_path + '/pdbbind_v2007.tar.gz' 133 | extracted_data_path = root_dir_path + '/pdbbind_v2007' 134 | download(_get_dgl_url(self._url), path=data_path, overwrite=False) 135 | extract_archive(data_path, extracted_data_path, overwrite=False) 136 | extracted_data_path += '/home/ubuntu' # extra layer 137 | 138 | # DataFrame containing the pdbbind_2007_agglomerative_split.txt 139 | self.agg_split = pd.read_csv(extracted_data_path + '/v2007/pdbbind_2007_agglomerative_split.txt') 140 | self.agg_split.rename(columns={'PDB ID':'PDB_code', 'Sequence-based assignment':'sequence', 'Structure-based assignment':'structure'}, inplace=True) 141 | self.agg_split.loc[self.agg_split['PDB_code']=='1.00E+66', 'PDB_code'] = '1e66' # fix typo 142 | if subset == 'core': 143 | index_label_file = extracted_data_path + '/v2007/INDEX.2007.core.data' 144 | elif subset == 'refined': 145 | index_label_file = extracted_data_path + '/v2007/INDEX.2007.refined.data' 146 | else: 147 | raise ValueError('Expect the subset_choice to be either core or refined, got {}'.format(subset)) 148 | 149 | contents = [] 150 | with 
open(index_label_file, 'r') as f: 151 | for line in f.readlines(): 152 | if line[0] != "#": 153 | splitted_elements = line.split() 154 | if pdb_version == 'v2015': 155 | if len(splitted_elements) == 8: 156 | # Ignore "//" 157 | contents.append(splitted_elements[:5] + splitted_elements[6:]) 158 | else: 159 | print('Incorrect data format.') 160 | print(splitted_elements) 161 | elif pdb_version == 'v2007': 162 | if len(splitted_elements) == 6: 163 | contents.append(splitted_elements) 164 | else: 165 | contents.append(splitted_elements[:5] + [' '.join(splitted_elements[5:])]) 166 | 167 | if pdb_version == 'v2015': 168 | self.df = pd.DataFrame(contents, columns=( 169 | 'PDB_code', 'resolution', 'release_year', 170 | '-logKd/Ki', 'Kd/Ki', 'reference', 'ligand_name')) 171 | elif pdb_version == 'v2007': 172 | self.df = pd.DataFrame(contents, columns=( 173 | 'PDB_code', 'resolution', 'release_year', 174 | '-logKd/Ki', 'Kd/Ki', 'cluster_ID')) 175 | 176 | pdbs = self.df['PDB_code'].tolist() 177 | 178 | # remove core set from refined set if using refined 179 | if remove_coreset_from_refinedset and subset == 'refined': 180 | if local_path: 181 | core_path = glob.glob(local_path + '*core*data*')[0] 182 | elif pdb_version == 'v2015': 183 | core_path = extracted_data_path + '/v2015/INDEX_core_data.2013' 184 | elif pdb_version == 'v2007': 185 | core_path = extracted_data_path + '/v2007/INDEX.2007.core.data' 186 | 187 | with open(core_path,'r') as f: 188 | for line in f: 189 | fields = line.strip().split() 190 | if fields[0] != "#" and fields[0] in pdbs: 191 | pdbs.remove(fields[0]) 192 | 193 | if local_path: 194 | pdb_path = local_path 195 | else: 196 | pdb_path = os.path.join(extracted_data_path, pdb_version) 197 | print('Loading PDBBind data from', pdb_path) 198 | self.ligand_files = [os.path.join(pdb_path, pdb, '{}_ligand.sdf'.format(pdb)) for pdb in pdbs] 199 | if load_binding_pocket: 200 | self.protein_files = [os.path.join(pdb_path, pdb, '{}_pocket.pdb'.format(pdb)) for pdb in pdbs] 201 | else: 202 | self.protein_files = [os.path.join(pdb_path, pdb, '{}_protein.pdb'.format(pdb)) for pdb in pdbs] 203 | 204 | def _filter_out_invalid(self, ligands_loaded, proteins_loaded, use_conformation): 205 | """Filter out invalid ligand-protein pairs. 206 | 207 | Parameters 208 | ---------- 209 | ligands_loaded : list 210 | Each element is a 2-tuple of the RDKit molecule instance and its associated atom 211 | coordinates. None is used to represent invalid/non-existing molecule or coordinates. 212 | proteins_loaded : list 213 | Each element is a 2-tuple of the RDKit molecule instance and its associated atom 214 | coordinates. None is used to represent invalid/non-existing molecule or coordinates. 215 | use_conformation : bool 216 | Whether we need conformation information (atom coordinates) and filter out molecules 217 | without valid conformation. 218 | """ 219 | num_pairs = len(proteins_loaded) 220 | self.indices, self.ligand_mols, self.protein_mols = [], [], [] 221 | if use_conformation: 222 | self.ligand_coordinates, self.protein_coordinates = [], [] 223 | else: 224 | # Use None for placeholders. 
225 | self.ligand_coordinates = [None for _ in range(num_pairs)] 226 | self.protein_coordinates = [None for _ in range(num_pairs)] 227 | 228 | for i in range(num_pairs): 229 | ligand_mol, ligand_coordinates = ligands_loaded[i] 230 | protein_mol, protein_coordinates = proteins_loaded[i] 231 | if (not use_conformation) and all(v is not None for v in [protein_mol, ligand_mol]): 232 | self.indices.append(i) 233 | self.ligand_mols.append(ligand_mol) 234 | self.protein_mols.append(protein_mol) 235 | elif all(v is not None for v in [ 236 | protein_mol, protein_coordinates, ligand_mol, ligand_coordinates]): 237 | self.indices.append(i) 238 | self.ligand_mols.append(ligand_mol) 239 | self.ligand_coordinates.append(ligand_coordinates) 240 | self.protein_mols.append(protein_mol) 241 | self.protein_coordinates.append(protein_coordinates) 242 | 243 | def _preprocess(self, load_binding_pocket, 244 | sanitize, calc_charges, remove_hs, use_conformation, 245 | construct_graph_and_featurize, zero_padding, num_processes): 246 | """Preprocess the dataset. 247 | 248 | The pre-processing proceeds as follows: 249 | 250 | 1. Load the dataset 251 | 2. Clean the dataset and filter out invalid pairs 252 | 3. Construct graphs 253 | 4. Prepare node and edge features 254 | 255 | Parameters 256 | ---------- 257 | load_binding_pocket : bool 258 | Whether to load binding pockets or full proteins. 259 | sanitize : bool 260 | Whether sanitization is performed in initializing RDKit molecule instances. See 261 | https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization. 262 | calc_charges : bool 263 | Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce 264 | ``sanitize`` to be True. 265 | remove_hs : bool 266 | Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite 267 | slow for large molecules. 268 | use_conformation : bool 269 | Whether we need to extract molecular conformation from proteins and ligands. 270 | construct_graph_and_featurize : callable 271 | Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i], 272 | self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i] 273 | to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`. 274 | zero_padding : bool 275 | Whether to perform zero padding. While DGL does not necessarily require zero padding, 276 | pooling operations for variable length inputs can introduce stochastic behaviour, which 277 | is not desired for sensitive scenarios. 278 | num_processes : int or None 279 | Number of worker processes to use. If None, 280 | then we will use the number of CPUs in the system. 
281 | """ 282 | if num_processes is None: 283 | num_processes = multiprocessing.cpu_count() - 1 284 | num_processes = min(num_processes, len(self.df)) 285 | 286 | print('Loading ligands...') 287 | ligands_loaded = multiprocess_load_molecules(self.ligand_files, 288 | sanitize=sanitize, 289 | calc_charges=calc_charges, 290 | remove_hs=remove_hs, 291 | use_conformation=use_conformation, 292 | num_processes=num_processes) 293 | 294 | print('Loading proteins...') 295 | proteins_loaded = multiprocess_load_molecules(self.protein_files, 296 | sanitize=sanitize, 297 | calc_charges=calc_charges, 298 | remove_hs=remove_hs, 299 | use_conformation=use_conformation, 300 | num_processes=num_processes) 301 | 302 | self._filter_out_invalid(ligands_loaded, proteins_loaded, use_conformation) 303 | self.df = self.df.iloc[self.indices] 304 | self.labels = F.zerocopy_from_numpy(self.df[self.task_names].values.astype(np.float32)) 305 | print('Finished cleaning the dataset, ' 306 | 'got {:d}/{:d} valid pairs'.format(len(self), len(self.ligand_files))) # account for the ones use_conformation failed 307 | 308 | # Prepare zero padding 309 | if zero_padding: 310 | max_num_ligand_atoms = 0 311 | max_num_protein_atoms = 0 312 | for i in range(len(self)): 313 | max_num_ligand_atoms = max( 314 | max_num_ligand_atoms, self.ligand_mols[i].GetNumAtoms()) 315 | max_num_protein_atoms = max( 316 | max_num_protein_atoms, self.protein_mols[i].GetNumAtoms()) 317 | else: 318 | max_num_ligand_atoms = None 319 | max_num_protein_atoms = None 320 | 321 | construct_graph_and_featurize = partial(construct_graph_and_featurize, 322 | max_num_ligand_atoms=max_num_ligand_atoms, 323 | max_num_protein_atoms=max_num_protein_atoms) 324 | 325 | print('Start constructing graphs and featurizing them.') 326 | num_mols = len(self) 327 | 328 | # construct graphs with multiprocessing 329 | # from signal import signal, SIGPIPE, SIG_DFL 330 | # signal(SIGPIPE,SIG_DFL) 331 | pool = multiprocessing.Pool(processes=num_processes) 332 | self.graphs = pool.starmap(construct_graph_and_featurize, 333 | zip(self.ligand_mols, self.protein_mols, 334 | self.ligand_coordinates, self.protein_coordinates)) 335 | print(f'Done constructing {len(self.graphs)} graphs.') 336 | 337 | 338 | def __len__(self): 339 | """Get the size of the dataset. 340 | 341 | Returns 342 | ------- 343 | int 344 | Number of valid ligand-protein pairs in the dataset. 345 | """ 346 | return len(self.indices) 347 | 348 | def __getitem__(self, item): 349 | """Get the datapoint associated with the index. 350 | 351 | Parameters 352 | ---------- 353 | item : int 354 | Index for the datapoint. 355 | 356 | Returns 357 | ------- 358 | int 359 | Index for the datapoint. 360 | rdkit.Chem.rdchem.Mol 361 | RDKit molecule instance for the ligand molecule. 362 | rdkit.Chem.rdchem.Mol 363 | RDKit molecule instance for the protein molecule. 364 | DGLGraph or tuple of DGLGraphs 365 | Pre-processed DGLGraph with features extracted. 366 | For ACNN, a single DGLGraph; 367 | For PotentialNet, a tuple of DGLGraphs that consists of a molecular graph and a KNN graph of the complex. 368 | Float32 tensor 369 | Label for the datapoint. 
370 | """ 371 | return item, self.ligand_mols[item], self.protein_mols[item], \ 372 | self.graphs[item], self.labels[item] 373 | -------------------------------------------------------------------------------- /data_loader/set_pdbbind.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dgl 3 | import torch 4 | import pickle 5 | import numpy as np 6 | from tqdm import tqdm 7 | from rdkit import Chem 8 | from sklearn.cluster import AffinityPropagation 9 | from rdkit.DataStructs import FingerprintSimilarity 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | from .pdbbind import PDBBind 13 | 14 | class SetPDBBind(object): 15 | def __init__(self, params): 16 | super().__init__() 17 | self.params = params 18 | self.gen_datasets() 19 | 20 | def gen_datasets(self): 21 | np.random.seed(1) # fix dataset 22 | V_size, S_size = self.params.v_size, self.params.s_size 23 | self.dataset = load_pdbbind(self.params) 24 | 25 | data_root = '/root/dataset/pdbbind' 26 | data_path = os.path.join(data_root, 'pdbbind_set_data.pkl') 27 | if os.path.exists(data_path): 28 | print(f'load data from {data_path}') 29 | trainData, valData, testData = pickle.load(open(data_path, "rb")) 30 | self.V_train, self.S_train = trainData['V_train'], trainData['S_train'] 31 | self.V_val, self.S_val = valData['V_train'], valData['S_train'] 32 | self.V_test, self.S_test = testData['V_train'], testData['S_train'] 33 | else: 34 | self.V_train, self.S_train = get_set_pdbbind_dataset_activate(self.dataset, V_size, S_size, size=1000) 35 | self.V_val, self.S_val = get_set_pdbbind_dataset_activate(self.dataset, V_size, S_size, size=100) 36 | self.V_test, self.S_test = get_set_pdbbind_dataset_activate(self.dataset, V_size, S_size, size=100) 37 | 38 | trainData = {'V_train': self.V_train, 'S_train': self.S_train} 39 | valData = {'V_train': self.V_val, 'S_train': self.S_val} 40 | testData = {'V_train': self.V_test, 'S_train': self.S_test} 41 | if not os.path.exists(data_root): 42 | os.makedirs(data_root) 43 | pickle.dump((trainData, valData, testData), open(data_path, "wb")) 44 | 45 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True): 46 | train_dataset = SetDataset(self.dataset, self.V_train, self.S_train, self.params, is_train=True) 47 | val_dataset = SetDataset(self.dataset, self.V_val, self.S_val, self.params) 48 | test_dataset = SetDataset(self.dataset, self.V_test, self.S_test, self.params) 49 | 50 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, 51 | collate_fn=collate_train, pin_memory=True, shuffle=shuffle_train, num_workers=num_workers) 52 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, 53 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers) 54 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, 55 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers) if get_test else None 56 | 57 | return train_loader, val_loader, test_loader 58 | 59 | def collate_train(data): 60 | U, S, neg_S = map(list, zip(*data)) 61 | bs, vs = len(U), U[0].shape[0] 62 | 63 | bg = dgl.batch([ gs[idx] for gs in U for idx in range(vs)]) 64 | for nty in bg.ntypes: 65 | bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty) 66 | for ety in bg.canonical_etypes: 67 | bg.set_e_initializer(dgl.init.zero_initializer, etype=ety) 68 | 69 | S = torch.cat(S, dim=0).reshape(bs, -1) 70 | neg_S = torch.cat(neg_S, dim=0).reshape(bs, 
-1) 71 | return bg, S, neg_S 72 | 73 | def collate_val_and_test(data): 74 | U, S = map(list, zip(*data)) 75 | bs, vs = len(U), U[0].shape[0] 76 | 77 | bg = dgl.batch([ gs[idx] for gs in U for idx in range(vs)]) 78 | for nty in bg.ntypes: 79 | bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty) 80 | for ety in bg.canonical_etypes: 81 | bg.set_e_initializer(dgl.init.zero_initializer, etype=ety) 82 | 83 | S = torch.cat(S, dim=0).reshape(bs, -1) 84 | return bg, S 85 | 86 | class SetDataset(Dataset): 87 | def __init__(self, dataset, V_idxs, S_idxs, params, is_train=False): 88 | _, self.graphs, _ = dataset 89 | self.V_idxs, self.S_idxs = V_idxs, S_idxs 90 | self.is_train = is_train 91 | self.neg_num = params.neg_num 92 | self.v_size = params.v_size 93 | 94 | def __getitem__(self, index): 95 | V_idxs, S = np.array(self.V_idxs[index]), np.array(self.S_idxs[index]) 96 | V_graphs = np.array([self.graphs[item] for item in V_idxs]) 97 | 98 | S_mask = torch.zeros([self.v_size]) 99 | S_mask[S] = 1 100 | if self.is_train: 101 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0] 102 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]] 103 | neg_S_mask = torch.zeros([self.v_size]) 104 | neg_S_mask[S] = 1 105 | neg_S_mask[neg_S] = 1 106 | return V_graphs, S_mask, neg_S_mask 107 | 108 | return V_graphs, S_mask 109 | 110 | def __len__(self): 111 | return len(self.V_idxs) 112 | 113 | def load_pdbbind(params): 114 | params.subset = 'core' 115 | dataset = PDBBind(subset=params.subset) 116 | 117 | # decompose dataset 118 | _, ligand_mols, protein_mols, graphs, labels = map(list, zip(*dataset)) 119 | ligand_mols = np.array(ligand_mols) 120 | graphs = np.array(graphs) 121 | labels = torch.stack(labels, dim=0) 122 | return (ligand_mols, graphs, labels) 123 | 124 | def get_set_pdbbind_dataset_activate(dataset, V_size, S_size, size=1000): 125 | """ 126 | Generate dataset for compound selection with only the bioactivity filter 127 | """ 128 | mols, _, labels = dataset 129 | data_size = len(mols) 130 | 131 | V_list, S_list = [], [] 132 | for _ in tqdm(range(size)): 133 | V_idxs = np.random.permutation(data_size)[:V_size] 134 | sub_labels = labels[V_idxs].squeeze(dim=-1) 135 | _, idxs = torch.topk(sub_labels, S_size) 136 | 137 | V_list.append(V_idxs) 138 | S_list.append(idxs) 139 | return np.array(V_list), np.array(S_list) 140 | 141 | def get_set_pdbbind_dataset(dataset, V_size, S_size, size=1000): 142 | """ 143 | Generate dataset for compound selection with bioactivity and diversity filters 144 | """ 145 | mols, _, labels = dataset 146 | data_size = len(mols) 147 | 148 | V_list, S_list = [], [] 149 | pbar = tqdm(total=size) 150 | num = 0 151 | while True: 152 | if num == size: break 153 | 154 | V_idxs = np.random.permutation(data_size)[:V_size] 155 | sub_labels = labels[V_idxs].squeeze(dim=-1) 156 | _, idxs = torch.topk(sub_labels, V_size // 3) 157 | filter_idxs = V_idxs[idxs] 158 | S_idxs = get_os_oracle(mols, filter_idxs) 159 | 160 | if len(S_idxs) == 0: continue 161 | 162 | V_list.append(V_idxs) 163 | S_list.append([np.where(V_idxs == item)[0][0] for item in S_idxs]) 164 | 165 | num += 1 166 | pbar.update(1) 167 | return np.array(V_list), np.array(S_list) 168 | 169 | def get_os_oracle(mols, filter_idxs): 170 | # reference: https://nyxflower.com/2020/07/17/molecule-clustering/ 171 | mol_list = mols[filter_idxs] 172 | n = len(mol_list) 173 | 174 | sm = np.zeros((n, n)) 175 | for i in range(n): 176 | for j in range(i, n): 177 | m1, m2 = mol_list[i], mol_list[j] 178 | sm[i, j] = 
FingerprintSimilarity(Chem.RDKFingerprint(m1), Chem.RDKFingerprint(m2)) 179 | sm = sm + sm.T - np.eye(n) 180 | af = AffinityPropagation().fit(sm) 181 | cluster_centers_indices = af.cluster_centers_indices_ 182 | return filter_idxs[cluster_centers_indices] 183 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from model.EquiVSet import EquiVSet 2 | from utils.config import MOONS_CONFIG, GAUSSIAN_CONFIG, BINDINGDB_CONFIG, AMAZON_CONFIG, CELEBA_CONFIG, PDBBIND_CONFIG 3 | 4 | if __name__ == "__main__": 5 | argparser = EquiVSet.get_model_specific_argparser() 6 | hparams = argparser.parse_args() 7 | 8 | data_name = hparams.data_name 9 | if data_name == 'moons': 10 | hparams.__dict__.update(MOONS_CONFIG) 11 | elif data_name == 'gaussian': 12 | hparams.__dict__.update(GAUSSIAN_CONFIG) 13 | elif data_name == 'amazon': 14 | hparams.__dict__.update(AMAZON_CONFIG) 15 | elif data_name == 'celeba': 16 | hparams.__dict__.update(CELEBA_CONFIG) 17 | elif data_name == 'pdbbind': 18 | hparams.__dict__.update(PDBBIND_CONFIG) 19 | elif data_name == 'bindingdb': 20 | hparams.__dict__.update(BINDINGDB_CONFIG) 21 | else: 22 | raise ValueError('invalid dataset...') 23 | 24 | model = EquiVSet(hparams) 25 | 26 | if hparams.train: 27 | model.run_training_sessions() 28 | else: 29 | model.load() 30 | print('Loaded model with: %s' % model.flag_hparams()) 31 | 32 | val_perf, test_perf = model.run_test() 33 | print('Val: {:8.2f}'.format(val_perf)) 34 | print('Test: {:8.2f}'.format(test_perf)) -------------------------------------------------------------------------------- /model/EquiVSet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from model.base_model import Base_Model 4 | from model.modules import SetFuction, RecNet 5 | 6 | 7 | class EquiVSet(Base_Model): 8 | def __init__(self, hparams): 9 | super().__init__(hparams=hparams) 10 | 11 | def define_parameters(self): 12 | self.set_func = SetFuction(params=self.hparams) 13 | self.rec_net = RecNet(params=self.hparams) if self.hparams.mode != 'diffMF' else None 14 | 15 | def configure_optimizers(self): 16 | optim_energy = torch.optim.Adam(self.set_func.parameters(), lr=self.hparams.lr, 17 | weight_decay=self.hparams.weight_decay) 18 | optim_var = torch.optim.Adam(self.rec_net.parameters(), lr=self.hparams.lr, 19 | weight_decay=self.hparams.weight_decay) if self.hparams.mode != 'diffMF' else None 20 | return optim_energy, optim_var 21 | 22 | def configure_gradient_clippers(self): 23 | return [(self.parameters(), self.hparams.clip)] 24 | 25 | def inference(self, V, bs): 26 | if self.hparams.mode == 'diffMF': 27 | bs, vs = V.shape[:2] 28 | q = .5 * torch.ones(bs, vs).to(V.device) 29 | else: 30 | # mode == 'ind' or 'copula' 31 | q = self.rec_net.get_vardist(V, bs) 32 | 33 | for i in range(self.hparams.RNN_steps): 34 | sample_matrix_1, sample_matrix_0 = self.set_func.MC_sampling(q, self.hparams.num_samples) 35 | q = self.set_func.mean_field_iteration(V, sample_matrix_1, sample_matrix_0) 36 | 37 | return q 38 | 39 | def get_hparams_grid(self): 40 | grid = Base_Model.get_general_hparams_grid() 41 | grid.update({ 42 | 'RNN_steps': [1], 43 | 'num_samples': [1, 5, 10, 15, 20], 44 | 'rank': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], 45 | 'tau': [0.01, 0.03, 0.05, 0.07, 0.09, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 5, 10], 46 | }) 47 | return grid 48 | 49 | 
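    # The grid above is consumed by Base_Model.run_training_session: with
    # --auto_repar and --num_runs > 1, one value per entry is drawn at random
    # for every run (random hyperparameter search).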
@staticmethod 50 | def get_model_specific_argparser(): 51 | parser = Base_Model.get_general_argparser() 52 | 53 | parser.add_argument('--mode', type=str, default='copula', 54 | choices=['diffMF', 'ind', 'copula'], 55 | help='name of the variant model [%(default)s]') 56 | 57 | parser.add_argument('--RNN_steps', type=int, default=1, 58 | help='num of RNN steps [%(default)d]') 59 | parser.add_argument('--num_samples', type=int, default=5, 60 | help='num of Monte Carlo samples [%(default)d]') 61 | parser.add_argument('--rank', type=int, default=5, 62 | help='rank of the perturbation matrix [%(default)d]') 63 | parser.add_argument('--tau', type=float, default=0.1, 64 | help='temperature of the relaxed multivariate bernoulli [%(default)g]') 65 | parser.add_argument('--neg_num', type=int, default=1, 66 | help='num of the negtive item [%(default)d]') 67 | 68 | return parser 69 | -------------------------------------------------------------------------------- /model/acnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | # 6 | # Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity""" 7 | # pylint: disable=C0103, C0123, W0221, E1101, R1721 8 | 9 | import itertools 10 | import dgl 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | 15 | from dgl.nn.pytorch import AtomicConv 16 | 17 | __all__ = ['ACNN'] 18 | 19 | def truncated_normal_(tensor, mean=0., std=1.): 20 | """Fills the given tensor in-place with elements sampled from the truncated normal 21 | distribution parameterized by mean and std. 22 | 23 | The generated values follow a normal distribution with specified mean and 24 | standard deviation, except that values whose magnitude is more than 2 std 25 | from the mean are dropped. 26 | 27 | We credit to Ruotian Luo for this implementation: 28 | https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15. 29 | 30 | Parameters 31 | ---------- 32 | tensor : Float32 tensor of arbitrary shape 33 | Tensor to be filled. 34 | mean : float 35 | Mean of the truncated normal distribution. 36 | std : float 37 | Standard deviation of the truncated normal distribution. 
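    Examples
    --------
    >>> w = torch.empty(3, 5)
    >>> truncated_normal_(w, mean=0., std=0.01)  # fills ``w`` in-place (illustrative)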
38 | """ 39 | shape = tensor.shape 40 | tmp = tensor.new_empty(shape + (4,)).normal_() 41 | valid = (tmp < 2) & (tmp > -2) 42 | ind = valid.max(-1, keepdim=True)[1] 43 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 44 | tensor.data.mul_(std).add_(mean) 45 | 46 | class MLP(nn.Module): 47 | """MLP with linear output""" 48 | 49 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 50 | """MLP layers construction 51 | 52 | Paramters 53 | --------- 54 | num_layers: int 55 | The number of linear layers 56 | input_dim: int 57 | The dimensionality of input features 58 | hidden_dim: int 59 | The dimensionality of hidden units at ALL layers 60 | output_dim: int 61 | The number of classes for prediction 62 | 63 | """ 64 | super(MLP, self).__init__() 65 | self.linear_or_not = True # default is linear model 66 | self.num_layers = num_layers 67 | self.output_dim = output_dim 68 | 69 | if num_layers < 1: 70 | raise ValueError("number of layers should be positive!") 71 | elif num_layers == 1: 72 | # Linear model 73 | self.linear = nn.Linear(input_dim, output_dim) 74 | else: 75 | # Multi-layer model 76 | self.linear_or_not = False 77 | self.linears = torch.nn.ModuleList() 78 | self.batch_norms = torch.nn.ModuleList() 79 | 80 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 81 | for layer in range(num_layers - 2): 82 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 83 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 84 | 85 | for layer in range(num_layers - 1): 86 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 87 | 88 | def forward(self, x): 89 | if self.linear_or_not: 90 | # If linear model 91 | return self.linear(x) 92 | else: 93 | # If MLP 94 | h = x 95 | for i in range(self.num_layers - 1): 96 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 97 | return self.linears[-1](h) 98 | 99 | class ACNNPredictor(nn.Module): 100 | """Predictor for ACNN. 101 | 102 | Parameters 103 | ---------- 104 | in_size : int 105 | Number of radial filters used. 106 | hidden_sizes : list of int 107 | Specifying the hidden sizes for all layers in the predictor. 108 | weight_init_stddevs : list of float 109 | Specifying the standard deviations to use for truncated normal 110 | distributions in initialzing weights for the predictor. 111 | dropouts : list of float 112 | Specifying the dropouts to use for all layers in the predictor. 113 | features_to_use : None or float tensor of shape (T) 114 | In the original paper, these are atomic numbers to consider, representing the types 115 | of atoms. T for the number of types of atomic numbers. Default to None. 116 | num_tasks : int 117 | Output size. 
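    Notes
    -----
    In this repository the projected ligand, protein and complex energies are
    concatenated and passed through ``fea_layer``, so ``forward`` returns a
    256-dimensional feature per complex (used as a set-element embedding by
    EquiVSet) rather than the scalar affinity of the original ACNN.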
118 | """ 119 | def __init__(self, in_size, hidden_sizes, weight_init_stddevs, 120 | dropouts, features_to_use, num_tasks): 121 | super(ACNNPredictor, self).__init__() 122 | 123 | if type(features_to_use) != type(None): 124 | in_size *= len(features_to_use) 125 | 126 | modules = [] 127 | for i, h in enumerate(hidden_sizes): 128 | linear_layer = nn.Linear(in_size, h) 129 | truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i]) 130 | modules.append(linear_layer) 131 | modules.append(nn.ReLU()) 132 | modules.append(nn.Dropout(dropouts[i])) 133 | in_size = h 134 | linear_layer = nn.Linear(in_size, num_tasks) 135 | truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1]) 136 | modules.append(linear_layer) 137 | self.project = nn.Sequential(*modules) 138 | self.fea_layer = MLP(num_layers=1, input_dim=1922, hidden_dim=2048, output_dim=256) 139 | 140 | def forward(self, batch_size, frag1_node_indices_in_complex, frag2_node_indices_in_complex, 141 | ligand_conv_out, protein_conv_out, complex_conv_out): 142 | """Perform the prediction. 143 | 144 | Parameters 145 | ---------- 146 | batch_size : int 147 | Number of datapoints in a batch. 148 | frag1_node_indices_in_complex : Int64 tensor of shape (V1) 149 | Indices for atoms in the first fragment (protein) in the batched complex. 150 | frag2_node_indices_in_complex : list of int of length V2 151 | Indices for atoms in the second fragment (ligand) in the batched complex. 152 | ligand_conv_out : Float32 tensor of shape (V2, K * T) 153 | Updated ligand node representations. V2 for the number of atoms in the 154 | ligand, K for the number of radial filters, and T for the number of types 155 | of atomic numbers. 156 | protein_conv_out : Float32 tensor of shape (V1, K * T) 157 | Updated protein node representations. V1 for the number of 158 | atoms in the protein, K for the number of radial filters, 159 | and T for the number of types of atomic numbers. 160 | complex_conv_out : Float32 tensor of shape (V1 + V2, K * T) 161 | Updated complex node representations. V1 and V2 separately 162 | for the number of atoms in the ligand and protein, K for 163 | the number of radial filters, and T for the number of 164 | types of atomic numbers. 165 | 166 | Returns 167 | ------- 168 | Float32 tensor of shape (B, O) 169 | Predicted protein-ligand binding affinity. B for the number 170 | of protein-ligand pairs in the batch and O for the number of tasks. 171 | """ 172 | 173 | ligand_feats = self.project(ligand_conv_out) # (V1, O) 174 | protein_feats = self.project(protein_conv_out) # (V2, O) 175 | complex_feats = self.project(complex_conv_out) # (V1+V2, O) 176 | 177 | ligand_energy = ligand_feats.reshape(batch_size, -1) # (B, O) 178 | protein_energy = protein_feats.reshape(batch_size, -1) # (B, O) 179 | 180 | complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape( 181 | batch_size, -1) 182 | complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape( 183 | batch_size, -1) 184 | complex_energy = torch.cat([complex_ligand_energy, complex_protein_energy], dim=-1) 185 | 186 | fea = torch.cat([ligand_energy, protein_energy, complex_energy], dim=-1) 187 | return self.fea_layer(fea) 188 | 189 | 190 | class ACNN(nn.Module): 191 | """Atomic Convolutional Networks. 192 | 193 | The model was proposed in `Atomic Convolutional Networks for 194 | Predicting Protein-Ligand Binding Affinity `__. 195 | 196 | The prediction proceeds as follows: 197 | 198 | 1. 
Perform message passing to update atom representations for the 199 | ligand, protein and protein-ligand complex. 200 | 2. Predict the energy of atoms from their representations with an MLP. 201 | 3. Take the sum of predicted energy of atoms within each molecule for 202 | predicted energy of the ligand, protein and protein-ligand complex. 203 | 4. Make the final prediction by subtracting the predicted ligand and protein 204 | energy from the predicted complex energy. 205 | 206 | Parameters 207 | ---------- 208 | hidden_sizes : list of int 209 | ``hidden_sizes[i]`` gives the size of hidden representations in the i-th 210 | hidden layer of the MLP. By Default, ``[32, 32, 16]`` will be used. 211 | weight_init_stddevs : list of float 212 | ``weight_init_stddevs[i]`` gives the std to initialize parameters in the 213 | i-th layer of the MLP. Note that ``len(weight_init_stddevs) == len(hidden_sizes) + 1`` 214 | due to the output layer. By default, we use ``1 / sqrt(hidden_sizes[i])`` for hidden 215 | layers and 0.01 for the output layer. 216 | dropouts : list of float 217 | ``dropouts[i]`` gives the dropout in the i-th hidden layer of the MLP. By default, 218 | no dropout is used. 219 | features_to_use : None or float tensor of shape (T) 220 | In the original paper, these are atomic numbers to consider, representing the types 221 | of atoms. T for the number of types of atomic numbers. If None, we use same parameters 222 | for all atoms regardless of their type. Default to None. 223 | radial : list 224 | The list consists of 3 sublists of floats, separately for the 225 | options of interaction cutoff, the options of rbf kernel mean and the 226 | options of rbf kernel scaling. By default, 227 | ``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used. 228 | num_tasks : int 229 | Number of output tasks. Default to 1. 230 | """ 231 | def __init__(self, hidden_sizes=None, weight_init_stddevs=None, dropouts=None, 232 | features_to_use=None, radial=None, num_tasks=1): 233 | super(ACNN, self).__init__() 234 | 235 | if hidden_sizes is None: 236 | hidden_sizes = [32, 32, 16] 237 | 238 | if weight_init_stddevs is None: 239 | weight_init_stddevs = [1. / float(np.sqrt(hidden_sizes[i])) 240 | for i in range(len(hidden_sizes))] 241 | weight_init_stddevs.append(0.01) 242 | 243 | if dropouts is None: 244 | dropouts = [0. for _ in range(len(hidden_sizes))] 245 | 246 | if radial is None: 247 | radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]] 248 | # Take the product of sets of options and get a list of 3-tuples. 249 | radial_params = [x for x in itertools.product(*radial)] 250 | radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1) 251 | 252 | interaction_cutoffs = radial_params[:, 0] 253 | rbf_kernel_means = radial_params[:, 1] 254 | rbf_kernel_scaling = radial_params[:, 2] 255 | 256 | self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, 257 | rbf_kernel_scaling, features_to_use) 258 | self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, 259 | rbf_kernel_scaling, features_to_use) 260 | self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, 261 | rbf_kernel_scaling, features_to_use) 262 | self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes, 263 | weight_init_stddevs, dropouts, features_to_use, num_tasks) 264 | 265 | def forward(self, graph): 266 | """Apply the model for prediction. 
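        For the compound-selection experiments, ``graph`` is the heterograph
        produced by ``data_loader.pdbbind.PDBBind`` and batched with
        ``dgl.batch`` in ``data_loader/set_pdbbind.py``.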
267 | 268 | Parameters 269 | ---------- 270 | graph : DGLHeteroGraph 271 | DGLHeteroGraph consisting of the ligand graph, the protein graph 272 | and the complex graph, along with preprocessed features. For a batch of 273 | protein-ligand pairs, we assume zero padding is performed so that the 274 | number of ligand and protein atoms is the same in all pairs. 275 | 276 | Returns 277 | ------- 278 | Float32 tensor of shape (B, O) 279 | Predicted protein-ligand binding affinity. B for the number 280 | of protein-ligand pairs in the batch and O for the number of tasks. 281 | """ 282 | ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')] 283 | ligand_graph_node_feats = ligand_graph.ndata['atomic_number'] 284 | assert ligand_graph_node_feats.shape[-1] == 1 285 | ligand_graph_distances = ligand_graph.edata['distance'] 286 | ligand_conv_out = self.ligand_conv(ligand_graph, 287 | ligand_graph_node_feats, 288 | ligand_graph_distances) 289 | 290 | protein_graph = graph[('protein_atom', 'protein', 'protein_atom')] 291 | protein_graph_node_feats = protein_graph.ndata['atomic_number'] 292 | assert protein_graph_node_feats.shape[-1] == 1 293 | protein_graph_distances = protein_graph.edata['distance'] 294 | protein_conv_out = self.protein_conv(protein_graph, 295 | protein_graph_node_feats, 296 | protein_graph_distances) 297 | 298 | complex_graph = dgl.edge_type_subgraph(graph, 299 | [('ligand_atom', 'complex', 'ligand_atom'), 300 | ('ligand_atom', 'complex', 'protein_atom'), 301 | ('protein_atom', 'complex', 'ligand_atom'), 302 | ('protein_atom', 'complex', 'protein_atom')]) 303 | complex_graph = dgl.to_homogeneous( 304 | complex_graph, ndata=['atomic_number'], edata=['distance']) 305 | complex_graph_node_feats = complex_graph.ndata['atomic_number'] 306 | assert complex_graph_node_feats.shape[-1] == 1 307 | complex_graph_distances = complex_graph.edata['distance'] 308 | complex_conv_out = self.complex_conv(complex_graph, 309 | complex_graph_node_feats, 310 | complex_graph_distances) 311 | 312 | frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0] 313 | frag2_node_indices_in_complex = list(set(range(complex_graph.num_nodes())) - 314 | set(frag1_node_indices_in_complex.tolist())) 315 | 316 | return self.predictor( 317 | graph.batch_size, 318 | frag1_node_indices_in_complex, 319 | frag2_node_indices_in_complex, 320 | ligand_conv_out, protein_conv_out, complex_conv_out) 321 | -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import argparse 5 | import datetime 6 | import numpy as np 7 | import torch.nn as nn 8 | from copy import deepcopy 9 | from datetime import timedelta 10 | from collections import OrderedDict 11 | from collections import defaultdict 12 | from timeit import default_timer as timer 13 | 14 | from utils.logger import Logger 15 | from utils.evaluation import compute_metrics 16 | from utils.pytorch_helper import move_to_device 17 | from data_loader import TwoMoons, GaussianMixture, Amazon, CelebA 18 | from data_loader import SetPDBBind, SetBindingDB 19 | 20 | 21 | class Base_Model(nn.Module): 22 | def __init__(self, hparams): 23 | super().__init__() 24 | self.hparams = hparams 25 | self.hparams.save_path = self.hparams.root_path + self.hparams.model_name 26 | self.load_data() 27 | 28 | def load_data(self): 29 | data_name = self.hparams.data_name 30 | if data_name == 
'moons': 31 | self.data = TwoMoons(self.hparams) 32 | elif data_name == 'gaussian': 33 | self.data = GaussianMixture(self.hparams) 34 | elif data_name == 'amazon': 35 | self.data = Amazon(self.hparams) 36 | elif data_name == 'celeba': 37 | self.data = CelebA(self.hparams) 38 | elif data_name == 'pdbbind': 39 | self.data = SetPDBBind(self.hparams) 40 | elif data_name == 'bindingdb': 41 | self.data = SetBindingDB(self.hparams) 42 | else: 43 | raise ValueError("invalid dataset...") 44 | 45 | def configure_optimizers(self): 46 | raise NotImplementedError 47 | 48 | def configure_gradient_clippers(self): 49 | raise NotImplementedError 50 | 51 | def run_training_sessions(self): 52 | logger = Logger(self.hparams.save_path + '.log', on=True) 53 | val_perfs = [] 54 | test_perfs = [] 55 | best_val_perf = float('-inf') 56 | start = timer() 57 | random.seed(self.hparams.seed) # For reproducible random runs 58 | 59 | for run_num in range(1, self.hparams.num_runs + 1): 60 | state_dict, val_perf, test_perf = self.run_training_session(run_num, logger) 61 | val_perfs.append(val_perf) 62 | test_perfs.append(test_perf) 63 | 64 | if val_perf > best_val_perf: 65 | best_val_perf = val_perf 66 | logger.log('----New best {:8.2f}, saving'.format(val_perf)) 67 | torch.save({'hparams': self.hparams, 68 | 'state_dict': state_dict}, self.hparams.save_path) 69 | 70 | logger.log('Time: %s' % str(timedelta(seconds=round(timer() - start)))) 71 | self.load() 72 | if self.hparams.num_runs > 1: 73 | logger.log_perfs(val_perfs) 74 | logger.log('best hparams: ' + self.flag_hparams()) 75 | if self.hparams.auto_repar == False: 76 | logger.log_test_perfs(test_perfs, self.hparams) 77 | 78 | val_perf, test_perf = self.run_test() 79 | logger.log('Val: {:8.2f}'.format(val_perf)) 80 | logger.log('Test: {:8.2f}'.format(test_perf)) 81 | 82 | def run_training_session(self, run_num, logger): 83 | self.train() 84 | 85 | # Scramble hyperparameters if number of runs is greater than 1. 
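        # (Descriptive note) Without --auto_repar only the random seed is
        # re-drawn below, so multiple runs evaluate the variance of a single
        # fixed configuration; the per-run test scores are then summarised by
        # Logger.log_test_perfs.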
86 | if self.hparams.num_runs > 1: 87 | logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs)) 88 | if self.hparams.auto_repar == True: 89 | for hparam, values in self.get_hparams_grid().items(): 90 | assert hasattr(self.hparams, hparam) 91 | self.hparams.__dict__[hparam] = random.choice(values) 92 | else: 93 | self.hparams.seed = np.random.randint(100000) 94 | 95 | np.random.seed(self.hparams.seed) 96 | random.seed(self.hparams.seed) 97 | torch.manual_seed(self.hparams.seed) 98 | torch.cuda.manual_seed_all(self.hparams.seed) 99 | 100 | self.define_parameters() 101 | logger.log(str(self)) 102 | logger.log('#params: %d' % sum([p.numel() for p in self.parameters()])) 103 | logger.log('hparams: %s' % self.flag_hparams()) 104 | 105 | device = torch.device('cuda' if self.hparams.cuda else 'cpu') 106 | self.to(device) 107 | 108 | optim_energy, optim_var = self.configure_optimizers() 109 | gradient_clippers = self.configure_gradient_clippers() 110 | train_loader, val_loader, test_loader = self.data.get_loaders( 111 | self.hparams.batch_size, self.hparams.num_workers, 112 | shuffle_train=True, get_test=True) 113 | best_val_perf = float('-inf') 114 | best_state_dict = None 115 | forward_sum = defaultdict(float) 116 | num_steps = 0 117 | bad_epochs = 0 118 | 119 | times = [] 120 | try: 121 | for epoch in range(1, self.hparams.epochs + 1): 122 | starttime = datetime.datetime.now() 123 | 124 | for batch_num, batch in enumerate(train_loader): 125 | V_set, S_set, neg_S_set = move_to_device(batch, device) 126 | 127 | if self.hparams.mode != 'diffMF': 128 | # optimize variational distribution (the q distribution) 129 | optim_var.zero_grad() 130 | neg_elbo = self.rec_net(V_set, self.set_func, bs=S_set.size(0)) 131 | neg_elbo.backward() 132 | for params, clip in gradient_clippers: 133 | nn.utils.clip_grad_norm_(params, clip) 134 | optim_var.step() 135 | 136 | if math.isnan(neg_elbo): 137 | logger.log('Stopping epoch because loss is NaN') 138 | break 139 | 140 | # optimize energy function (the p distribution) 141 | optim_energy.zero_grad() 142 | entropy_loss = self.set_func(V_set, S_set, neg_S_set, self.rec_net) 143 | entropy_loss.backward() 144 | for params, clip in gradient_clippers: 145 | nn.utils.clip_grad_norm_(params, clip) 146 | optim_energy.step() 147 | 148 | num_steps += 1 149 | forward_sum['neg_elbo'] += neg_elbo.item() if self.hparams.mode != 'diffMF' else 0 150 | forward_sum['entropy'] += entropy_loss.item() 151 | if math.isnan(entropy_loss): 152 | logger.log('Stopping epoch because loss is NaN') 153 | break 154 | 155 | endtime = datetime.datetime.now() 156 | times.append(endtime - starttime) 157 | 158 | if math.isnan(forward_sum['neg_elbo']) or math.isnan(forward_sum['entropy']): 159 | logger.log('Stopping training session because loss is NaN') 160 | break 161 | 162 | val_perf = self.evaluate(val_loader, device) 163 | logger.log('End of epoch {:3d}'.format(epoch), False) 164 | logger.log(' '.join([' | {:s} {:8.2f}'.format( 165 | key, forward_sum[key] / num_steps) 166 | for key in forward_sum]), False) 167 | logger.log(' | val perf {:8.2f}'.format(val_perf), False) 168 | 169 | if val_perf > best_val_perf: 170 | best_val_perf = val_perf 171 | bad_epochs = 0 172 | logger.log('\t\t*Best model so far, deep copying*') 173 | best_state_dict = deepcopy(self.state_dict()) 174 | test_perf = self.evaluate(test_loader, device) 175 | else: 176 | bad_epochs += 1 177 | logger.log('\t\tBad epoch %d' % bad_epochs) 178 | 179 | if bad_epochs > self.hparams.num_bad_epochs: 180 | break 181 | 182 | except 
KeyboardInterrupt: 183 | logger.log('-' * 89) 184 | logger.log('Exiting from training early') 185 | 186 | logger.log("time per training epoch: " + str(np.mean(times))) 187 | return best_state_dict, best_val_perf, test_perf 188 | 189 | def evaluate(self, eval_loader, device): 190 | self.eval() 191 | with torch.no_grad(): 192 | perf = compute_metrics(eval_loader, self.inference, self.hparams.v_size, device) 193 | self.train() 194 | return perf 195 | 196 | def run_test(self): 197 | device = torch.device('cuda' if self.hparams.cuda else 'cpu') 198 | _, val_loader, test_loader = self.data.get_loaders(self.hparams.batch_size, 199 | self.hparams.num_workers, shuffle_train=True, get_test=True) 200 | val_perf = self.evaluate(val_loader, device) 201 | test_perf = self.evaluate(test_loader, device) 202 | return val_perf, test_perf 203 | 204 | def load(self): 205 | device = torch.device('cuda' if self.hparams.cuda else 'cpu') 206 | checkpoint = torch.load(self.hparams.save_path) if self.hparams.cuda \ 207 | else torch.load(self.hparams.save_path, 208 | map_location=torch.device('cpu')) 209 | if checkpoint['hparams'].cuda and not self.hparams.cuda: 210 | checkpoint['hparams'].cuda = False 211 | self.hparams = checkpoint['hparams'] 212 | self.define_parameters() 213 | self.load_state_dict(checkpoint['state_dict']) 214 | self.to(device) 215 | 216 | def flag_hparams(self): 217 | flags = '%s' % (self.hparams.model_name) 218 | for hparam in vars(self.hparams): 219 | val = getattr(self.hparams, hparam) 220 | if str(val) == 'False': 221 | continue 222 | elif str(val) == 'True': 223 | flags += ' --%s' % (hparam) 224 | elif str(hparam) in {'model_name', 'data_path', 'num_runs', 225 | 'auto_repar', 'save_path'}: 226 | continue 227 | else: 228 | flags += ' --%s %s' % (hparam, val) 229 | return flags 230 | 231 | @staticmethod 232 | def get_general_hparams_grid(): 233 | grid = OrderedDict({ 234 | 'seed': list(range(100000)), 235 | 'lr': [0.003, 0.001, 0.0005, 0.0001], 236 | 'clip': [1, 5, 10], 237 | 'batch_size': [32, 64, 128], 238 | 'init': [0, 0.5, 0.1, 0.05, 0.01], 239 | }) 240 | return grid 241 | 242 | @staticmethod 243 | def get_general_argparser(): 244 | parser = argparse.ArgumentParser() 245 | 246 | parser.add_argument('model_name', type=str) 247 | parser.add_argument('--data_name', type=str, default='moons', 248 | choices=['moons', 'gaussian', 'amazon', 'celeba', 'pdbbind', 'bindingdb'], 249 | help='name of dataset [%(default)d]') 250 | parser.add_argument('--amazon_cat', type=str, default='toys', 251 | choices=['toys', 'furniture', 'gear', 'carseats', 'bath', 'health', 'diaper', 'bedding', 252 | 'safety', 'feeding', 'apparel', 'media'], 253 | help='category of amazon baby registry dataset [%(default)d]') 254 | parser.add_argument('--root_path', type=str, 255 | default='./') 256 | parser.add_argument('--train', action='store_true', 257 | help='train a model?') 258 | parser.add_argument('--auto_repar', action='store_true', 259 | help='use auto parameterization?') 260 | 261 | parser.add_argument('--v_size', type=int, default=30, 262 | help='size of ground set [%(default)d]') 263 | parser.add_argument('--s_size', type=int, default=10, 264 | help='size of subset [%(default)d]') 265 | parser.add_argument('--num_layers', type=int, default=2, 266 | help='num layers [%(default)d]') 267 | 268 | parser.add_argument('--batch_size', type=int, default=4, 269 | help='batch size [%(default)d]') 270 | parser.add_argument('--lr', type=float, default=0.0001, 271 | help='initial learning rate [%(default)g]') 272 | 
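        # Note: the dataset-specific values in utils/config.py (e.g. v_size,
        # s_size, batch_size) override these defaults when main.py dispatches
        # on --data_name.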
parser.add_argument("--weight_decay", type=float, default=1e-5, 273 | help='weight decay rate [%(default)g]') 274 | parser.add_argument('--init', type=float, default=0.05, 275 | help='unif init range (default if 0) [%(default)g]') 276 | parser.add_argument('--clip', type=float, default=10, 277 | help='gradient clipping [%(default)g]') 278 | parser.add_argument('--epochs', type=int, default=100, 279 | help='max number of epochs [%(default)d]') 280 | parser.add_argument('--num_runs', type=int, default=1, 281 | help='num random runs (not random if 1) ' 282 | '[%(default)d]') 283 | 284 | parser.add_argument('--num_bad_epochs', type=int, default=6, 285 | help='num indulged bad epochs [%(default)d]') 286 | parser.add_argument('--num_workers', type=int, default=2, 287 | help='num dataloader workers [%(default)d]') 288 | parser.add_argument('--cuda', action='store_true', 289 | help='use CUDA?') 290 | parser.add_argument('--seed', type=int, default=50971, 291 | help='random seed [%(default)d]') 292 | 293 | return parser 294 | -------------------------------------------------------------------------------- /model/celebaCNN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class celebaCNN(nn.Sequential): 5 | def __init__(self): 6 | super(celebaCNN, self).__init__() 7 | 8 | in_ch = [3] + [32,64,128] 9 | kernels = [3,4,5] 10 | strides = [2,2,2] 11 | layer_size = 3 12 | self.conv = nn.ModuleList([nn.Conv2d(in_channels = in_ch[i], 13 | out_channels = in_ch[i+1], 14 | kernel_size = kernels[i], 15 | stride = strides[i]) for i in range(layer_size)]) 16 | self.conv = self.conv.double() 17 | self.fc1 = nn.Linear(128, 256) 18 | 19 | def _forward_features(self, x): 20 | for l in self.conv: 21 | x = F.relu(l(x)) 22 | x = F.adaptive_max_pool2d(x, output_size=1) 23 | return x 24 | 25 | def forward(self, v): 26 | v = self._forward_features(v.double()) 27 | v = v.view(v.size(0), -1) 28 | v = self.fc1(v.float()) 29 | return v -------------------------------------------------------------------------------- /model/deepDTA.py: -------------------------------------------------------------------------------- 1 | # reference: https://github.com/mims-harvard/TDC/blob/main/examples/multi_pred/dti_dg/domainbed/networks.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class Identity(nn.Module): 8 | """An identity layer""" 9 | def __init__(self): 10 | super(Identity, self).__init__() 11 | 12 | def forward(self, x): 13 | return x 14 | 15 | class MLP(nn.Module): 16 | """Just an MLP""" 17 | def __init__(self, n_inputs, n_outputs, hparams): 18 | super(MLP, self).__init__() 19 | self.input = nn.Linear(n_inputs, hparams['mlp_width']) 20 | self.dropout = nn.Dropout(hparams['mlp_dropout']) 21 | self.hiddens = nn.ModuleList([ 22 | nn.Linear(hparams['mlp_width'], hparams['mlp_width']) 23 | for _ in range(hparams['mlp_depth']-2)]) 24 | self.output = nn.Linear(hparams['mlp_width'], n_outputs) 25 | self.n_outputs = n_outputs 26 | 27 | def forward(self, x): 28 | x = self.input(x) 29 | x = self.dropout(x) 30 | x = F.relu(x) 31 | for hidden in self.hiddens: 32 | x = hidden(x) 33 | x = self.dropout(x) 34 | x = F.relu(x) 35 | x = self.output(x) 36 | return x 37 | 38 | class CNN(nn.Sequential): 39 | def __init__(self, encoding): 40 | super(CNN, self).__init__() 41 | if encoding == 'drug': 42 | in_ch = [41] + [32,64,96] 43 | kernels = [4,6,8] 44 | layer_size = 3 45 
| self.conv = nn.ModuleList([nn.Conv1d(in_channels = in_ch[i], 46 | out_channels = in_ch[i+1], 47 | kernel_size = kernels[i]) for i in range(layer_size)]) 48 | self.conv = self.conv.double() 49 | n_size_d = self._get_conv_output((41, 100)) 50 | self.fc1 = nn.Linear(n_size_d, 256) 51 | elif encoding == 'protein': 52 | in_ch = [20] + [32,64,96] 53 | kernels = [4,8,12] 54 | layer_size = 3 55 | self.conv = nn.ModuleList([nn.Conv1d(in_channels = in_ch[i], 56 | out_channels = in_ch[i+1], 57 | kernel_size = kernels[i]) for i in range(layer_size)]) 58 | self.conv = self.conv.double() 59 | n_size_p = self._get_conv_output((20, 1000)) 60 | self.fc1 = nn.Linear(n_size_p, 256) 61 | 62 | def _get_conv_output(self, shape): 63 | bs = 1 64 | input = Variable(torch.rand(bs, *shape)) 65 | output_feat = self._forward_features(input.double()) 66 | n_size = output_feat.data.view(bs, -1).size(1) 67 | return n_size 68 | 69 | def _forward_features(self, x): 70 | for l in self.conv: 71 | x = F.relu(l(x)) 72 | x = F.adaptive_max_pool1d(x, output_size=1) 73 | return x 74 | 75 | def forward(self, v): 76 | v = self._forward_features(v.double()) 77 | v = v.view(v.size(0), -1) 78 | v = self.fc1(v.float()) 79 | return v 80 | 81 | class DeepDTA_Encoder(nn.Sequential): 82 | def __init__(self): 83 | super(DeepDTA_Encoder, self).__init__() 84 | self.input_dim_drug = 256 85 | self.input_dim_protein = 256 86 | 87 | self.model_drug = CNN('drug') 88 | self.model_protein = CNN('protein') 89 | self.predictor = nn.Linear(self.input_dim_drug + self.input_dim_protein, 256) 90 | 91 | def forward(self, V): 92 | v_D, v_P = V 93 | # each encoding 94 | v_D = self.model_drug(v_D) 95 | v_P = self.model_protein(v_P) 96 | # concatenate and output feature 97 | v_f = torch.cat((v_D, v_P), 1) 98 | v_f = self.predictor(v_f) 99 | return v_f 100 | -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from model.acnn import ACNN 7 | from utils.config import ACNN_CONFIG 8 | from model.celebaCNN import celebaCNN 9 | from model.deepDTA import DeepDTA_Encoder 10 | from utils.pytorch_helper import FF, normal_cdf 11 | 12 | 13 | class SetFuction(nn.Module): 14 | def __init__(self, params): 15 | super(SetFuction, self).__init__() 16 | self.params = params 17 | self.dim_feature = 256 18 | 19 | self.init_layer = self.define_init_layer() 20 | self.ff = FF(self.dim_feature, 500, 1, self.params.num_layers) 21 | 22 | def define_init_layer(self): 23 | data_name = self.params.data_name 24 | if data_name == 'moons': 25 | return nn.Linear(2, self.dim_feature) 26 | elif data_name == 'gaussian': 27 | return nn.Linear(2, self.dim_feature) 28 | elif data_name == 'amazon': 29 | return nn.Linear(768, self.dim_feature) 30 | elif data_name == 'celeba': 31 | return celebaCNN() 32 | elif data_name == 'pdbbind': 33 | return ACNN(hidden_sizes=ACNN_CONFIG['hidden_sizes'], 34 | weight_init_stddevs=ACNN_CONFIG['weight_init_stddevs'], 35 | dropouts=ACNN_CONFIG['dropouts'], 36 | features_to_use=ACNN_CONFIG['atomic_numbers_considered'], 37 | radial=ACNN_CONFIG['radial']) 38 | elif data_name == 'bindingdb': 39 | return DeepDTA_Encoder() 40 | else: 41 | raise ValueError("invalid dataset...") 42 | 43 | def MC_sampling(self, q, M): 44 | """ 45 | Bernoulli sampling using q as parameters. 
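        For every element i of the ground set, M subsets are first drawn from
        Bernoulli(q); `matrix_1` then forces element i into each sample and
        `matrix_0` forces it out, so that F(matrix_1) - F(matrix_0) gives a
        Monte Carlo estimate of the marginal gain used in `mean_field_iteration`.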
46 | Args: 47 | q: paramter of Bernoulli distribution 48 | M: number of samples 49 | 50 | Returns: 51 | Sampled subsets F(S+i), F(S) 52 | 53 | """ 54 | bs, vs = q.shape 55 | q = q.reshape(bs, 1, 1, vs).expand(bs, M, vs, vs) 56 | sample_matrix = torch.bernoulli(q) 57 | 58 | mask = torch.cat([torch.eye(vs, vs).unsqueeze(0) for _ in range(M)], dim=0).unsqueeze(0).to(q.device) 59 | matrix_0 = sample_matrix * (1 - mask) 60 | matrix_1 = matrix_0 + mask 61 | return matrix_1, matrix_0 62 | 63 | def mean_field_iteration(self, V, subset_i, subset_not_i): 64 | F_1 = self.F_S(V, subset_i, fpi=True).squeeze(-1) 65 | F_0 = self.F_S(V, subset_not_i, fpi=True).squeeze(-1) 66 | q = torch.sigmoid((F_1 - F_0).mean(1)) 67 | return q 68 | 69 | def cross_entropy(self, q, S, neg_S): 70 | loss = - torch.sum((S * torch.log(q + 1e-12) + (1 - S) * torch.log(1 - q + 1e-12)) * neg_S, dim=-1) 71 | return loss.mean() 72 | 73 | def forward(self, V, S, neg_S, rec_net): # return cross-entropy loss 74 | if self.params.mode == 'diffMF': 75 | bs, vs = V.shape[:2] 76 | q = .5 * torch.ones(bs, vs).to(V.device) 77 | else: 78 | # mode == 'ind' or 'copula' 79 | q = rec_net.get_vardist(V, S.shape[0]).detach() # notice the detach here 80 | 81 | for i in range(self.params.RNN_steps): 82 | sample_matrix_1, sample_matrix_0 = self.MC_sampling(q, self.params.num_samples) 83 | q = self.mean_field_iteration(V, sample_matrix_1, sample_matrix_0) 84 | 85 | loss = self.cross_entropy(q, S, neg_S) 86 | return loss 87 | 88 | def F_S(self, V, subset_mat, fpi=False): 89 | if fpi: 90 | # to fix point iteration (aka mean-field iteration) 91 | fea = self.init_layer(V).reshape(subset_mat.shape[0], 1, -1, self.dim_feature) 92 | else: 93 | # to encode variational dist 94 | fea = self.init_layer(V).reshape(subset_mat.shape[0], -1, self.dim_feature) 95 | fea = subset_mat @ fea 96 | fea = self.ff(fea) 97 | return fea 98 | 99 | 100 | class RecNet(nn.Module): 101 | def __init__(self, params): 102 | super(RecNet, self).__init__() 103 | self.params = params 104 | self.dim_feature = 256 105 | num_layers = self.params.num_layers 106 | 107 | self.init_layer = self.define_init_layer() 108 | self.ff = FF(self.dim_feature, 500, 500, num_layers - 1 if num_layers > 0 else 0) 109 | self.h_to_mu = nn.Linear(500, 1) 110 | if self.params.mode == 'copula': 111 | self.h_to_std = nn.Linear(500, 1) 112 | self.h_to_U = nn.ModuleList([nn.Linear(500, 1) for i in range(self.params.rank)]) 113 | 114 | def define_init_layer(self): 115 | data_name = self.params.data_name 116 | if data_name == 'moons': 117 | return nn.Linear(2, self.dim_feature) 118 | elif data_name == 'gaussian': 119 | return nn.Linear(2, self.dim_feature) 120 | elif data_name == 'amazon': 121 | return nn.Linear(768, self.dim_feature) 122 | elif data_name == 'celeba': 123 | return celebaCNN() 124 | elif data_name == 'pdbbind': 125 | return ACNN(hidden_sizes=ACNN_CONFIG['hidden_sizes'], 126 | weight_init_stddevs=ACNN_CONFIG['weight_init_stddevs'], 127 | dropouts=ACNN_CONFIG['dropouts'], 128 | features_to_use=ACNN_CONFIG['atomic_numbers_considered'], 129 | radial=ACNN_CONFIG['radial']) 130 | elif data_name == 'bindingdb': 131 | return DeepDTA_Encoder() 132 | else: 133 | raise ValueError("invalid dataset...") 134 | 135 | def encode(self, V, bs): 136 | """ 137 | 138 | Args: 139 | V: the ground set. [batch_size, v_size, fea_dim] 140 | bs: batch_size 141 | 142 | Returns: 143 | ber: predicted probabilities. 
[batch_size, v_size] 144 | std: the diagonal matrix D [batch_size, v_size] 145 | u_perturbation: the low rank perturbation matrix [batch_size, v_size, rank] 146 | 147 | """ 148 | fea = self.init_layer(V).reshape(bs, -1, self.dim_feature) 149 | h = torch.relu(self.ff(fea)) 150 | ber = torch.sigmoid(self.h_to_mu(h)).squeeze(-1) # [batch_size, v_size] 151 | 152 | if self.params.mode == 'copula': 153 | std = F.softplus(self.h_to_std(h)).squeeze(-1) # [batch_size, v_size] 154 | rs = [] 155 | for i in range(self.params.rank): 156 | rs.append(torch.tanh(self.h_to_U[i](h))) 157 | u_perturbation = torch.cat(rs, -1) # [batch_size, v_size, rank] 158 | 159 | return ber, std, u_perturbation 160 | return ber, None, None 161 | 162 | def MC_sampling(self, ber, std, u_pert, M): 163 | """ 164 | Sampling using CopulaBernoulli 165 | 166 | Args: 167 | ber: location parameter (0, 1) [batch_size, v_size] 168 | std: standard deviation (0, +infinity) [batch_size, v_size] 169 | u_pert: lower rank perturbation (-1, 1) [batch_size, v_size, rank] 170 | M: number of MC approximation 171 | 172 | Returns: 173 | Sampled subsets 174 | """ 175 | bs, vs = ber.shape 176 | 177 | if self.params.mode == 'copula': 178 | eps = torch.randn((bs, M, vs)).to(ber.device) 179 | eps_corr = torch.randn((bs, M, self.params.rank, 1)).to(ber.device) 180 | g = eps * std.unsqueeze(1) + torch.matmul(u_pert.unsqueeze(1), eps_corr).squeeze(-1) 181 | u = normal_cdf(g, 0, 1) 182 | else: 183 | # mode == 'ind' 184 | u = torch.rand((bs, M, vs)).to(ber.device) 185 | 186 | ber = ber.unsqueeze(1) 187 | l = torch.log(ber + 1e-12) - torch.log(1 - ber + 1e-12) + \ 188 | torch.log(u + 1e-12) - torch.log(1 - u + 1e-12) 189 | 190 | prob = torch.sigmoid(l / self.params.tau) 191 | r = torch.bernoulli(prob) # binary vector 192 | s = prob + (r - prob).detach() # straight through estimator 193 | return s 194 | 195 | def cal_elbo(self, V, sample_mat, set_func, q): 196 | f_mt = set_func.F_S(V, sample_mat).squeeze(-1).mean(-1) 197 | entropy = - torch.sum(q * torch.log(q + 1e-12) + (1 - q) * torch.log(1 - q + 1e-12), dim=-1) 198 | elbo = f_mt + entropy 199 | return elbo.mean() 200 | 201 | def forward(self, V, set_func, bs): # return negative ELBO 202 | ber, std, u_perturbation = self.encode(V, bs) 203 | sample_mat = self.MC_sampling(ber, std, u_perturbation, self.params.num_samples) 204 | elbo = self.cal_elbo(V, sample_mat, set_func, ber) 205 | return -elbo 206 | 207 | def get_vardist(self, V, bs): 208 | fea = self.init_layer(V).reshape(bs, -1, self.dim_feature) 209 | h = torch.relu(self.ff(fea)) 210 | ber = torch.sigmoid(self.h_to_mu(h)).squeeze(-1) 211 | return ber 212 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | MOONS_CONFIG = { 5 | 'data_name': 'moons', 6 | 'v_size': 100, 7 | 's_size': 10, 8 | 'batch_size': 128 9 | } 10 | 11 | GAUSSIAN_CONFIG = { 12 | 'data_name': 'gaussian', 13 | 'v_size': 100, 14 | 's_size': 10, 15 | 'batch_size': 128 16 | } 17 | 18 | AMAZON_CONFIG = { 19 | 'data_name': 'amazon', 20 | 'v_size': 30, 21 | 'batch_size': 128 22 | } 23 | 24 | CELEBA_CONFIG = { 25 | 'data_name': 'celeba', 26 | 'v_size': 8, 27 | 'batch_size': 128 28 | } 29 | 30 | PDBBIND_CONFIG = { 31 | 'data_name': 'pdbbind', 32 | 'v_size': 30, 33 | 's_size': 5, 34 | 'batch_size': 32 35 | } 36 | 37 | BINDINGDB_CONFIG = { 38 | 'data_name': 'bindingdb', 39 | 'v_size': 300, 40 | 's_size': 15, 41 | 
'batch_size': 4 42 | } 43 | 44 | ACNN_CONFIG = { 45 | 'hidden_sizes': [32, 32, 16], 46 | 'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)), 47 | 1. / float(np.sqrt(16)), 0.01], 48 | 'dropouts': [0., 0., 0.], 49 | 'atomic_numbers_considered': torch.tensor([ 50 | 1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]), 51 | 'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]], 52 | } -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.pytorch_helper import set_value_according_index, move_to_device 4 | 5 | def compute_metrics(loader, infer_func, v_size, device): 6 | jc_list = [] 7 | for batch_num, batch in enumerate(loader): 8 | V_set, S_set = move_to_device(batch, device) 9 | 10 | q = infer_func(V_set, S_set.shape[0]) 11 | _, idx = torch.topk(q, S_set.shape[-1], dim=1, largest=True) 12 | 13 | pre_list = [] 14 | for i in range(len(idx)): 15 | pre_mask = torch.zeros([S_set.shape[-1]]).to(device) 16 | ids = idx[i][:int(torch.sum(S_set[i]).item())] 17 | pre_mask[ids] = 1 18 | pre_list.append(pre_mask.unsqueeze(0)) 19 | pre_mask = torch.cat(pre_list, dim=0) 20 | true_mask = S_set 21 | 22 | intersection = true_mask * pre_mask 23 | union = true_mask + pre_mask - intersection 24 | jc = intersection.sum(dim=-1) / union.sum(dim=-1) 25 | jc_list.append(jc) 26 | 27 | jca = torch.cat(jc_list, 0).mean(0).item() 28 | return jca * 100 29 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import statistics as stat 4 | import sys 5 | 6 | class Logger(object): 7 | 8 | def __init__(self, log_path, on=True): 9 | self.log_path = log_path 10 | self.on = on 11 | # self.on = False 12 | 13 | if self.on: 14 | while os.path.isfile(self.log_path): 15 | self.log_path += '+' 16 | 17 | def log(self, string, newline=True): 18 | if self.on: 19 | with open(self.log_path, 'a') as logf: 20 | logf.write(string) 21 | if newline: logf.write('\n') 22 | 23 | sys.stdout.write(string) 24 | if newline: sys.stdout.write('\n') 25 | sys.stdout.flush() 26 | 27 | def log_perfs(self, perfs): 28 | valid_perfs = [perf for perf in perfs if not math.isinf(perf)] 29 | best_perf = max(valid_perfs) 30 | self.log('-' * 89) 31 | self.log('%d perfs: %s' % (len(perfs), str(perfs))) 32 | self.log('perf max: %g' % best_perf) 33 | self.log('perf min: %g' % min(valid_perfs)) 34 | self.log('perf avg: %g' % stat.mean(valid_perfs)) 35 | self.log('perf std: %g' % (stat.stdev(valid_perfs) 36 | if len(valid_perfs) > 1 else 0.0)) 37 | self.log('(excluded %d out of %d runs that produced -inf)' % 38 | (len(perfs) - len(valid_perfs), len(perfs))) 39 | self.log('-' * 89) 40 | 41 | def log_test_perfs(self, perfs, params): 42 | valid_perfs = [perf for perf in perfs if not math.isinf(perf)] 43 | mean = stat.mean(valid_perfs) 44 | std = stat.stdev(valid_perfs if len(valid_perfs) > 1 else 0.0) 45 | 46 | if self.on: 47 | name = params.root_path + 'EquiVSet.log' 48 | with open(name, 'a') as logf: 49 | logf.write( params.model_name + ': ' + '\n') 50 | logf.write(str(perfs) + '\n') 51 | string = f'avg: {mean} ' + f'std: {std}\n' 52 | logf.write(string) 53 | self.log('-' * 89) 54 | self.log('%d perfs: %s' % (len(perfs), str(perfs))) 55 | self.log('perf avg: %g' % mean) 56 | self.log('perf std: %g' % std) 57 | self.log('-' 
* 89) -------------------------------------------------------------------------------- /utils/pytorch_helper.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import math 3 | import torch 4 | import pickle 5 | import torch.nn as nn 6 | 7 | 8 | def set_value_according_index(tensor, idx, value): 9 | mask_val = torch.ones(idx.shape).to(tensor.device) * value 10 | tensor.scatter_(1, idx, mask_val) # fill the values along dimension 1 11 | return tensor 12 | 13 | def normal_cdf(value, loc, scale): 14 | return 0.5 * (1 + torch.erf( (value - loc) / (scale * math.sqrt(2)) )) 15 | 16 | def get_init_function(init_value): 17 | def init_function(m): 18 | if init_value > 0.: 19 | if hasattr(m, 'weight'): 20 | m.weight.data.uniform_(-init_value, init_value) 21 | if hasattr(m, 'bias'): 22 | m.bias.data.fill_(0.) 23 | 24 | return init_function 25 | 26 | 27 | def move_to_device(obj, device): 28 | if torch.is_tensor(obj): 29 | return obj.to(device) 30 | elif isinstance(obj, dgl.DGLGraph): 31 | return obj.to(device) 32 | elif isinstance(obj, dict): 33 | res = {} 34 | for k, v in obj.items(): 35 | res[k] = move_to_device(v, device) 36 | return res 37 | elif isinstance(obj, list): 38 | res = [] 39 | for v in obj: 40 | res.append(move_to_device(v, device)) 41 | return res 42 | elif isinstance(obj, tuple): 43 | res = () 44 | for v in obj: 45 | res += (move_to_device(v, device),) 46 | return res 47 | else: 48 | raise TypeError("Invalid type for move_to_device") 49 | 50 | class FF(nn.Module): 51 | def __init__(self, dim_input, dim_hidden, dim_output, num_layers, 52 | activation='relu', dropout_rate=0, layer_norm=False, 53 | residual_connection=False): 54 | super().__init__() 55 | 56 | assert num_layers >= 0 # 0 = Linear 57 | if num_layers > 0: 58 | assert dim_hidden > 0 59 | if residual_connection: 60 | assert dim_hidden == dim_input 61 | 62 | self.residual_connection = residual_connection 63 | self.stack = nn.ModuleList() 64 | for l in range(num_layers): 65 | layer = [] 66 | 67 | if layer_norm: 68 | layer.append(nn.LayerNorm(dim_input if l == 0 else dim_hidden)) 69 | 70 | layer.append(nn.Linear(dim_input if l == 0 else dim_hidden, 71 | dim_hidden)) 72 | layer.append({'tanh': nn.Tanh(), 'relu': nn.ReLU()}[activation]) 73 | 74 | if dropout_rate > 0: 75 | layer.append(nn.Dropout(dropout_rate)) 76 | 77 | self.stack.append(nn.Sequential(*layer)) 78 | 79 | self.out = nn.Linear(dim_input if num_layers < 1 else dim_hidden, 80 | dim_output) 81 | 82 | def forward(self, x): 83 | for layer in self.stack: 84 | x = x + layer(x) if self.residual_connection else layer(x) 85 | return self.out(x) 86 | 87 | def read_from_pickle(filename): 88 | filename = filename + '.pickle' 89 | with open(filename, 'rb') as f: 90 | x = pickle.load(f) 91 | return x 92 | 93 | def find_not_in_set(U,S): 94 | Ind = torch.ones(U.shape[0],dtype=bool) 95 | Ind[S] = False 96 | return U[Ind] 97 | --------------------------------------------------------------------------------
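For reference, below is a minimal usage sketch of the `FF` block defined in `utils/pytorch_helper.py` (assuming the repository root is on `PYTHONPATH` so that `utils.pytorch_helper` is importable). It builds the same 256-to-scalar utility head that `SetFuction` constructs with the default `--num_layers 2` and applies it to a random batch of set-element features:

```
import torch
from utils.pytorch_helper import FF

# Illustrative sketch, not part of the repository.
ff = FF(dim_input=256, dim_hidden=500, dim_output=1, num_layers=2)
x = torch.randn(4, 30, 256)       # 4 ground sets, 30 elements each, 256-d features
utilities = ff(x).squeeze(-1)     # one scalar score per element -> shape (4, 30)
print(utilities.shape)
```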