├── .gitignore
├── README.md
├── assets
│   └── equivset.png
├── data_loader
│   ├── __init__.py
│   ├── amazon.py
│   ├── bindingdb.py
│   ├── celeba.py
│   ├── gaussian.py
│   ├── moons.py
│   ├── pdbbind.py
│   └── set_pdbbind.py
├── main.py
├── model
│   ├── EquiVSet.py
│   ├── acnn.py
│   ├── base_model.py
│   ├── celebaCNN.py
│   ├── deepDTA.py
│   └── modules.py
└── utils
    ├── config.py
    ├── evaluation.py
    ├── logger.py
    └── pytorch_helper.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .idea/
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :fire:EquiVSet:fire:
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | This repo contains a PyTorch implementation of the paper "[Learning Neural Set Functions Under the Optimal Subset Oracle](https://arxiv.org/abs/2203.01693)"
12 |
13 | by [Zijing Ou](https://j-zin.github.io/), [Tingyang Xu](https://scholar.google.com.hk/citations?user=6gIs5YMAAAAJ&hl=en), [Qinliang Su](https://scholar.google.com/citations?user=cuIweygAAAAJ&hl=en), [Yingzhen Li](http://yingzhenli.net/home/en/), [Peilin Zhao](https://peilinzhao.github.io/), and [Yatao Bian](https://yataobian.com/).
14 |
15 | > We propose a way to learn set functions when the optimal subsets are given by an optimal subset oracle. This setting differs from prior work that learns set functions from a function value oracle, which provides a utility value for every queried subset; learning from the optimal subset oracle is arguably more practically important, yet it has been surprisingly overlooked. To learn set functions under the optimal subset oracle, we cast the problem as maximum likelihood estimation, replacing the utility function with an energy-based model that is proportional to the utility value and satisfies the desiderata of set functions (e.g., permutation invariance). Mean-field variational inference and its amortized variants are then proposed to learn the EBMs over sets. We evaluate our approach on a wide range of applications, including product recommendation, set anomaly detection, and compound selection in AI-aided drug discovery. The empirical results show that our approach is promising.
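Schematically (notation ours; see the paper for the exact objective): given a ground set $V$ and an optimal subset $S^*$ from the oracle, the set mass function is modeled with an energy-based model and fitted by maximum likelihood,

$$p_\theta(S \mid V) = \frac{\exp\big(F_\theta(S, V)\big)}{\sum_{S' \subseteq V} \exp\big(F_\theta(S', V)\big)}, \qquad \max_\theta \; \log p_\theta(S^* \mid V),$$

where $F_\theta$ is a permutation-invariant set function; the intractable inference over subsets is handled with mean-field variational inference or its amortized variant.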
16 |
17 | 
18 | Overview of the training and inference processes of EquiVSet.
19 |
20 | ## Installation
21 |
22 | Please ensure that:
23 |
24 | - Python >= 3.6
25 | - PyTorch >= 1.8.0
26 | - dgl >= 0.7.0
27 |
28 | The following packages are needed if you want to run the `compound selection` experiments:
29 |
30 | - **rdkit**: We recommend installing it with `conda install -c rdkit rdkit==2018.09.3`. For other installation recipes, see the [official documentation](https://www.rdkit.org/docs/Install.html).
31 | - **dgllife**: We recommend installing it with `pip install dgllife`. More information is available in the [official documentation](https://lifesci.dgl.ai/install/index.html).
32 | - **tdc**: We recommend installing it with `pip install PyTDC`. See the [official documentation](https://tdc.readthedocs.io/en/main/install.html) for more information.
33 |
34 | We provide step-by-step installation commands as follows:
35 |
36 | ```
37 | conda create -n EquiVSet python=3.7
38 | source activate EquiVSet
39 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
40 | pip install dgl-cu110 dglgo -f https://data.dgl.ai/wheels/repo.html
41 |
42 | # The following commands are used for compound selection:
43 | conda install -c rdkit rdkit==2018.09.3
44 | pip install dgllife
45 | pip install PyTDC
46 | ```
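
After installing, a quick sanity check (optional; the printed versions depend on the wheels you chose):

```python
import torch
import dgl

print("torch:", torch.__version__, "| dgl:", dgl.__version__)
print("CUDA available:", torch.cuda.is_available())
```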
47 |
48 | ## Datasets
49 | For the experiments, we use the following datasets:
50 |
51 | - [Amazon baby registry dataset](https://www.kaggle.com/datasets/roopalik/amazon-baby-dataset) for the `product recommendation` experiments. The dataset is available [here](https://drive.google.com/file/d/1OLbCOTsRyowxw3_AzhxJPVB8VAgjt2Y6/view?usp=sharing).
52 | - [CelebA dataset](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) for the `set anomaly detection` experiments. The images are available [here](https://drive.google.com/file/d/0B7EVK8r0v71pZjFTYXZWM3FlRnM/view?usp=sharing&resourcekey=0-dYn9z10tMJOBAkviAcfdyQ) and the attribute labels are available [here](https://drive.google.com/file/d/0B7EVK8r0v71pblRyaVFSWGxPY0U/view?usp=sharing&resourcekey=0-YW2qIuRcWHy_1C2VaRGL3Q).
53 | - [PDBBind](http://www.pdbbind.org.cn/) and [BindingDB](https://www.bindingdb.org/bind/index.jsp) for the `compound selection` experiments. The PDBBind dataset is available [here](http://www.pdbbind.org.cn/index.php?newsid=20#news_section) and the BindingDB dataset is available [here](https://www.bindingdb.org/bind/index.jsp).
54 |
55 | For all experiments, the datasets are downloaded and preprocessed automatically when you run the corresponding code. You can also download them manually using the links provided above.
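
The loaders cache everything under hard-coded paths such as `/root/dataset/amazon` (see `data_loader/*.py`). As a reference for the manual route, here is a minimal sketch for the Amazon data using the same Google Drive id as `data_loader/amazon.py` (the target directory below is only an example; adjust it to your setup):

```python
import os
import zipfile
import gdown

data_root = "/root/dataset/amazon"   # hard-coded default in data_loader/amazon.py
os.makedirs(data_root, exist_ok=True)

archive = os.path.join(data_root, "amazon_baby_registry.zip")
gdown.download("https://drive.google.com/uc?id=1OLbCOTsRyowxw3_AzhxJPVB8VAgjt2Y6", archive, quiet=False)
with zipfile.ZipFile(archive) as zf:
    zf.extractall(data_root)
```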
56 |
57 | ## Experiments
58 |
59 | This repository implements the synthetic, product recommendation, set anomaly detection, and compound selection experiments (Section 6 of the paper).
60 |
61 | ### Synthetic Experiments
62 |
63 | To run on the Two-Moons and Gaussian-Mixture datasets:
64 | ```
65 | python main.py equivset --train --cuda --data_name <dataset_name>
66 | ```
67 | `<dataset_name>` is chosen from `['moons', 'gaussian']`.
68 | We also provide a [Colab notebook](https://colab.research.google.com/drive/1_EI0BUjFzNAVxWS1ao-xia_UVmW4KLi4?usp=sharing) to run the synthetic experiments.
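
For a quick look at the data interface without the CLI, the loaders in `data_loader/` share the same pattern; below is a minimal sketch for the Two-Moons data (the `params` values are assumptions for illustration, not the settings used in the paper; importing from `data_loader.moons` avoids pulling in the drug-discovery dependencies):

```python
from types import SimpleNamespace

from data_loader.moons import TwoMoons

# hypothetical configuration; see utils/config.py for the defaults used by main.py
params = SimpleNamespace(v_size=100, s_size=10, neg_num=1)
data = TwoMoons(params)

train_loader, _, _ = data.get_loaders(batch_size=4, num_workers=0, shuffle_train=True)
V, S_mask, neg_S_mask = next(iter(train_loader))
print(V.shape, S_mask.shape, neg_S_mask.shape)   # [4, 100, 2], [4, 100], [4, 100]
```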
69 |
70 |
71 | ### Product Recommendation
72 |
73 | To run on the Amazon baby registry dataset:
74 | ```
75 | python main.py equivset --train --cuda --data_name amazon --amazon_cat <category_name>
76 | ```
77 | `<category_name>` is chosen from `['toys', 'furniture', 'gear', 'carseats', 'bath', 'health', 'diaper', 'bedding', 'safety', 'feeding', 'apparel', 'media']`.
78 |
79 | ### Set Anomaly Detection
80 |
81 | To run on the CelebA dataset:
82 | ```
83 | python main.py equivset --train --cuda --data_name celeba
84 | ```
85 |
86 | ### Compound Selection
87 |
88 | To run on the PDBBind and BindingDB datasets:
89 | ```
90 | python main.py equivset --train --cuda --data_name <dataset_name>
91 | ```
92 | `<dataset_name>` is chosen from `['pdbbind', 'bindingdb']`.
93 |
94 | ## Citation
95 |
96 | :smile:If you find this repo useful, please consider citing our paper:
97 | ```
98 | @article{ou2022learning,
99 | title={Learning Set Functions Under the Optimal Subset Oracle via Equivariant Variational Inference},
100 | author={Ou, Zijing and Xu, Tingyang and Su, Qinliang and Li, Yingzhen and Zhao, Peilin and Bian, Yatao},
101 | journal={arXiv preprint arXiv:2203.01693},
102 | year={2022}
103 | }
104 | ```
105 |
--------------------------------------------------------------------------------
/assets/equivset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SubsetSelection/EquiVSet/76713b8e61279639261fb8554d4223f86b84b6c0/assets/equivset.png
--------------------------------------------------------------------------------
/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | from .moons import TwoMoons
2 | from .gaussian import GaussianMixture
3 | from .amazon import Amazon
4 | from .celeba import CelebA
5 | from .set_pdbbind import SetPDBBind
6 | from .bindingdb import SetBindingDB
--------------------------------------------------------------------------------
/data_loader/amazon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import gdown
4 | import torch
5 | import zipfile
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | from utils.pytorch_helper import read_from_pickle, find_not_in_set
10 |
11 | class Data:
12 | def __init__(self, params):
13 | self.params = params
14 |
15 | def gen_datasets(self):
16 | raise NotImplementedError
17 |
18 | def get_loaders(self, batch_size, num_workers, shuffle_train=False,
19 | get_test=True):
20 | raise NotImplementedError
21 |
22 | class Amazon(Data):
23 | def __init__(self, params):
24 | super().__init__(params)
25 | data_root = self.download_amazon()
26 |
27 | torch.manual_seed(1) # fix dataset
28 | self.gen_datasets(data_root)
29 |
30 | def download_amazon(self):
31 | url = 'https://drive.google.com/uc?id=1OLbCOTsRyowxw3_AzhxJPVB8VAgjt2Y6'
32 | data_root = '/root/dataset/amazon'
33 | download_path = f'{data_root}/amazon_baby_registry.zip'
34 | if not os.path.exists(data_root):
35 | os.makedirs(data_root)
36 | if not os.listdir(data_root):
37 | gdown.download(url, download_path, quiet=False)
38 | with zipfile.ZipFile(download_path, 'r') as ziphandler:
39 | ziphandler.extractall(data_root)
40 | return data_root
41 |
42 | def read_real_data(self, data_root):
43 | pickle_filename = data_root + '/' + self.params.amazon_cat
44 | dataset_ = read_from_pickle(pickle_filename)
45 | for i in range(len(dataset_)):
46 | dataset_[i+1] = torch.tensor(dataset_[i+1])
47 | data_ = torch.zeros(len(dataset_), dataset_[1].shape[0])
48 | for i in range(len(dataset_)):
49 | data_[i,:] = dataset_[i+1]
50 |
51 | csv_filename = data_root + '/' + self.params.amazon_cat + '.csv'
52 | with open(csv_filename, newline='') as csvfile:
53 | reader = csv.reader(csvfile, delimiter=',', quoting=csv.QUOTE_NONNUMERIC,quotechar='|')
54 | S = {}
55 | i=-1
56 | for row in reader:
57 | i=i+1
58 | S[i] = torch.tensor([int(row[x]) for x in range(len(row))]).long()
59 | return data_ , S
60 |
61 | def filter_S(self, data, S):
62 | S_list = []
63 | V_list = []
64 | for i in range(len(S)):
65 | if S[i].shape[0]>2 and S[i].shape[0] < self.params.v_size:
66 | Svar = S[i] - 1 # index starts from 0
67 | sub_set, ground_set= self.construct_ground_set(data, Svar, V=self.params.v_size)
68 | S_list.append(sub_set); V_list.append(ground_set)
69 | S = S_list
70 | U = V_list
71 | return U, S
72 |
73 | def construct_ground_set(self, data, S, V):
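        # Build a ground set U of size V that contains the optimal subset S:
        # the remaining V - |S| slots are filled with the items least similar
        # (by cosine similarity) to the mean embedding of S, and the S items are
        # re-inserted at random positions, so the returned S holds indices into U.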
74 | S_data = data[S]
75 | S_mean = S_data.mean(dim=0).unsqueeze(0)
76 | UnotS_data = find_not_in_set(data, S)
77 | S_mean_norm = F.normalize(S_mean, dim=-1)
78 | UnotS_data_norm = F.normalize(UnotS_data, dim=-1)
79 |
80 | cos_sim = (S_mean_norm @ UnotS_data_norm.T).squeeze(0)
81 | _, idx = torch.sort(cos_sim)
82 | UnotS_idx = idx[:V-S.shape[0]]
83 | UnotS_data = UnotS_data[UnotS_idx]
84 |
85 | S = torch.randperm(V)[:S.shape[0]]
86 | UnotS_idx = torch.ones(V,dtype=bool)
87 | UnotS_idx[S] = False
88 |
89 | U = torch.zeros([V, data.shape[-1]])
90 | U[S] = S_data
91 | U[UnotS_idx] = UnotS_data
92 |
93 | return S, U
94 |
95 | def split_into_training_test(self, data_mat, S):
96 | folds = [0.33, 0.33, 0.33]
97 | num_elem = len(data_mat)
98 | tr_size = int(folds[0]* num_elem)
99 | dev_size = int((folds[1]+folds[0])* num_elem)
100 | test_size = num_elem
101 |
102 | V_train = data_mat[0:tr_size]
103 | V_dev = data_mat[tr_size:dev_size]
104 | V_test = data_mat[dev_size:test_size]
105 |
106 | S_train = S[0:tr_size]
107 | S_dev = S[tr_size:dev_size]
108 | S_test = S[dev_size:test_size]
109 |
110 | V_sets = (V_train, V_dev, V_test)
111 | S_sets = (S_train, S_dev, S_test)
112 | return V_sets, S_sets
113 |
114 | def gen_datasets(self, data_root):
115 | data, S = self.read_real_data(data_root)
116 | data, S = self.filter_S(data, S)
117 | V_sets, S_sets = self.split_into_training_test(data, S)
118 |
119 | self.V_train, self.V_val, self.V_test = V_sets
120 | self.S_train, self.S_val, self.S_test = S_sets
121 |
122 | self.fea_size = self.V_train[0].shape[-1]
123 |
124 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True):
125 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True)
126 | val_dataset = SetDataset(self.V_val, self.S_val, self.params)
127 | test_dataset = SetDataset(self.V_test, self.S_test, self.params)
128 |
129 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
130 | shuffle=shuffle_train, num_workers=num_workers)
131 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
132 | shuffle=False, num_workers=num_workers)
133 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
134 | shuffle=False, num_workers=num_workers) if get_test else None
135 | return train_loader, val_loader, test_loader
136 |
137 | class SetDataset(Dataset):
138 | def __init__(self, V, S, params, is_train=False):
139 | self.data = V
140 | self.labels = S
141 | self.is_train = is_train
142 | self.neg_num = params.neg_num
143 | self.v_size = params.v_size
144 |
145 | def __getitem__(self, index):
146 | V = self.data[index]
147 | S = self.labels[index]
148 |
149 | S_mask = torch.zeros([self.v_size])
150 | S_mask[S] = 1
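        # During training we additionally return a "negative" mask covering S plus
        # neg_num * |S| items drawn uniformly from V \ S; the SetDataset classes in
        # the other data loaders follow the same pattern.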
151 | if self.is_train:
152 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0]
153 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]]
154 | neg_S_mask = torch.zeros([self.v_size])
155 | neg_S_mask[S] = 1
156 | neg_S_mask[neg_S] = 1
157 | return V, S_mask, neg_S_mask
158 |
159 | return V, S_mask
160 |
161 | def __len__(self):
162 | return len(self.data)
163 |
--------------------------------------------------------------------------------
/data_loader/bindingdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pickle
4 | import numpy as np
5 | from tqdm import tqdm
6 | from rdkit import Chem
7 | from tdc.multi_pred import DTI
8 | from multiprocessing import Pool
9 | from rdkit.DataStructs import FingerprintSimilarity
10 | from sklearn.cluster import AffinityPropagation
11 |
12 | from torch.utils.data import Dataset, DataLoader
13 |
14 | acid_list = ['H', 'M', 'C', 'P', 'L', 'A', 'R', 'F', 'D', 'T', 'K', 'E', 'S', 'V', 'G', 'Y', 'N', 'W', 'I', 'Q']
15 | CHARPROTLEN = len(acid_list)
16 | CHARPROTDIC = { acid_list[idx]: idx for idx in range(len(acid_list))}
17 |
18 | smile_list = ['o', 'N', '/', 'H', '#', 'C', 'i', '+', 'l', '@', '8', '-', '6', '3', '\\', '2', 'B', 'P', '.',
19 | 'e', '9', '7', 'a', 's', 'O', ')', '0', 'n', '1', '4', 'I', 'F', ']', 'S', '5', '(', '[', '=', '%', 'c', 'r']
20 | CHARCANSMILEN = len(smile_list)
21 | CHARCANSMIDIC = { smile_list[idx]: idx for idx in range(len(smile_list))}
22 |
23 | class Tokenizer():
24 | @staticmethod
25 | def seq_tokenizer(seq, type_):
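        # One-hot encode a space-separated sequence (SMILES characters for drugs,
        # amino acids for proteins), truncated to max_length. Returns a tensor of
        # shape [1, vocab_size, max_length].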
26 | if type_ == 'drug':
27 | max_length = 100
28 | mask = torch.zeros([max_length, CHARCANSMILEN])
29 | seq = np.array([CHARCANSMIDIC[item] for item in seq.split(" ")[:max_length]])
30 | elif type_ == 'protein':
31 | max_length = 1000
32 | mask = torch.zeros([max_length, CHARPROTLEN])
33 | seq = np.array([CHARPROTDIC[item] for item in seq.split(" ")[:max_length]])
34 |
35 | length = seq.shape[0]
36 | mask[range(length), seq] = 1
37 | return mask.transpose_(0, 1).unsqueeze(0)
38 |
39 | @staticmethod
40 | def tokenizer(bt, type_):
41 | bt = [Tokenizer.seq_tokenizer(seq, type_) for seq in bt]
42 | return bt
43 |
44 | class SetBindingDB(object):
45 | def __init__(self, params):
46 | super().__init__()
47 | self.params = params
48 | self.gen_datasets()
49 |
50 | def gen_datasets(self):
51 | np.random.seed(1) # fix dataset
52 | V_size, S_size = self.params.v_size, self.params.s_size
53 | self.dataset = load_bindingdb(self.params)
54 |
55 | data_root = '/root/dataset/bindingdb'
56 | data_path = os.path.join(data_root, 'bindingdb_set_data.pkl')
57 | if os.path.exists(data_path):
58 | print(f'load data from {data_path}')
59 | trainData, valData, testData = pickle.load(open(data_path, "rb"))
60 | self.V_train, self.S_train = trainData['V_train'], trainData['S_train']
61 | self.V_val, self.S_val = valData['V_train'], valData['S_train']
62 | self.V_test, self.S_test = testData['V_train'], testData['S_train']
63 | else:
64 | self.V_train, self.S_train = get_set_bindingdb_dataset_activate(self.dataset, V_size, S_size, self.params, size=1000)
65 | self.V_val, self.S_val = get_set_bindingdb_dataset_activate(self.dataset, V_size, S_size, self.params, size=100)
66 | self.V_test, self.S_test = get_set_bindingdb_dataset_activate(self.dataset, V_size, S_size, self.params, size=100)
67 |
68 | trainData = {'V_train': self.V_train, 'S_train': self.S_train}
69 | valData = {'V_train': self.V_val, 'S_train': self.S_val}
70 | testData = {'V_train': self.V_test, 'S_train': self.S_test}
71 | if not os.path.exists(data_root):
72 | os.makedirs(data_root)
73 | pickle.dump((trainData, valData, testData), open(data_path, "wb"))
74 |
75 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True):
76 | train_dataset = SetDataset(self.dataset, self.V_train, self.S_train, self.params, is_train=True)
77 | val_dataset = SetDataset(self.dataset, self.V_val, self.S_val, self.params)
78 | test_dataset = SetDataset(self.dataset, self.V_test, self.S_test, self.params)
79 |
80 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
81 | collate_fn=collate_train, pin_memory=True, shuffle=shuffle_train, num_workers=num_workers)
82 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
83 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers)
84 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
85 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers) if get_test else None
86 |
87 | return train_loader, val_loader, test_loader
88 |
89 | def collate_train(data):
90 | V_drug, V_target, S, neg_S = map(list, zip(*data))
91 | bs, vs = len(V_drug), V_drug[0].shape[0]
92 |
93 | b_D = [ gs[idx] for gs in V_drug for idx in range(vs)]
94 | b_P = [ gt[idx] for gt in V_target for idx in range(vs)]
95 | b_D = Tokenizer.tokenizer(b_D, 'drug')
96 | b_P = Tokenizer.tokenizer(b_P, 'protein')
97 |
98 | b_D = torch.cat(b_D, dim=0)
99 | b_P = torch.cat(b_P, dim=0)
100 | S = torch.cat(S, dim=0).reshape(bs, -1)
101 | neg_S = torch.cat(neg_S, dim=0).reshape(bs, -1)
102 | return (b_D, b_P), S, neg_S
103 |
104 | def collate_val_and_test(data):
105 | V_drug, V_target, S = map(list, zip(*data))
106 | bs, vs = len(V_drug), V_drug[0].shape[0]
107 |
108 | b_D = [ gs[idx] for gs in V_drug for idx in range(vs)]
109 | b_P = [ gt[idx] for gt in V_target for idx in range(vs)]
110 | b_D = Tokenizer.tokenizer(b_D, 'drug')
111 | b_P = Tokenizer.tokenizer(b_P, 'protein')
112 |
113 | b_D = torch.cat(b_D, dim=0)
114 | b_P = torch.cat(b_P, dim=0)
115 | S = torch.cat(S, dim=0).reshape(bs, -1)
116 | return (b_D, b_P), S
117 |
118 | class SetDataset(Dataset):
119 | def __init__(self, dataset, V_idxs, S_idxs, params, is_train=False):
120 | self.drugs, self.targets = dataset['Drug'], dataset['Target']
121 | self.V_idxs, self.S_idxs = V_idxs, S_idxs
122 | self.is_train = is_train
123 | self.neg_num = params.neg_num
124 | self.v_size = params.v_size
125 |
126 | def __getitem__(self, index):
127 | V_idxs, S = np.array(self.V_idxs[index]), np.array(self.S_idxs[index])
128 | V_drug = np.array([" ".join(item) for item in self.drugs[V_idxs].tolist()])
129 | V_target = np.array([" ".join(item) for item in self.targets[V_idxs].tolist()])
130 |
131 | S_mask = torch.zeros([self.v_size])
132 | S_mask[S] = 1
133 | if self.is_train:
134 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0]
135 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]]
136 | neg_S_mask = torch.zeros([self.v_size])
137 | neg_S_mask[S] = 1
138 | neg_S_mask[neg_S] = 1
139 | return V_drug, V_target, S_mask, neg_S_mask
140 |
141 | return V_drug, V_target, S_mask
142 |
143 | def __len__(self):
144 | return len(self.V_idxs)
145 |
146 | def load_bindingdb(params):
147 | data = DTI(name = 'BindingDB_Kd', path='/root/dataset/bindingdb')
148 | data.harmonize_affinities(mode = 'mean')
149 | return data.get_data()
150 |
151 | def load_dt_pair(dataset, setdata):
152 | drugs, targets = dataset['Drug'], dataset['Target']
153 | V, S = setdata['V_train'], setdata['S_train']
154 |
155 | V_drugs_list, V_targets_list = [], []
156 | for V_idxs in V:
157 | V_drugs = [" ".join(item) for item in drugs[V_idxs].tolist()]
158 | V_drugs_list.append(V_drugs)
159 | V_targets = [" ".join(item) for item in targets[V_idxs].tolist()]
160 | V_targets_list.append(V_targets)
161 |
162 | V_drug = np.array(V_drugs_list)
163 | V_target = np.array(V_targets_list)
164 | S = torch.Tensor(S).type(torch.long)
165 | return (V_drug, V_target), S
166 |
167 | def get_set_bindingdb_dataset_activate(dataset, V_size, S_size, params, size=1000):
168 | """
169 | Generate dataset for compound selection with only the bioactivity filter
170 | """
171 | _, _, labels = dataset['Drug'], dataset['Target'], dataset['Y']
172 | data_size = len(labels)
173 |
174 | V_list, S_list = [], []
175 | for _ in tqdm(range(size)):
176 | V_idxs = np.random.permutation(data_size)[:V_size]
177 | sub_labels = torch.from_numpy(labels[V_idxs].to_numpy())
178 | _, idxs = torch.topk(sub_labels, S_size)
179 |
180 | V_list.append(V_idxs)
181 | S_list.append(idxs)
182 | return np.array(V_list), np.array(S_list)
183 |
184 | def get_set_bindingdb_dataset(dataset, V_size, S_size, params, size=1000):
185 | """
186 | Generate dataset for compound selection with bioactivity and diversity filters
187 | """
188 | drugs, targets, labels = dataset['Drug'], dataset['Target'], dataset['Y']
189 | data_size = len(labels)
190 |
191 | V_list, S_list = [], []
192 | pbar = tqdm(total=size)
193 | num = 0
194 | while True:
195 | if num == size: break
196 |
197 | V_idxs = np.random.permutation(data_size)[:V_size]
198 | sub_labels = torch.from_numpy(labels[V_idxs].to_numpy())
199 | _, idxs = torch.topk(sub_labels, V_size // 3)
200 | filter_idxs = V_idxs[idxs]
201 | S_idxs = get_os_oracle(drugs, filter_idxs)
202 |
203 | if len(S_idxs) == 0: continue
204 |
205 | V_list.append(V_idxs)
206 | S_list.append([np.where(V_idxs == item)[0][0] for item in S_idxs])
207 |
208 | num += 1
209 | pbar.update(1)
210 | return np.array(V_list), np.array(S_list)
211 |
212 | def get_os_oracle(drugs, filter_idxs):
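    # OS oracle for the diversity filter: compute pairwise RDKit fingerprint
    # similarities between the bioactivity-filtered compounds (in parallel),
    # cluster the similarity matrix with affinity propagation, and return the
    # cluster centers as the selected (diverse) subset.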
213 | smile_list = drugs[filter_idxs].to_numpy()
214 | n = len(smile_list)
215 |
216 | sm = np.zeros((n, n))
217 | max_cpu = os.cpu_count()
218 | ij_list = [(i, j, smile_list) for i in range(n) for j in range(i+1, n)]
219 | with Pool(max_cpu) as p:
220 | similarity = p.starmap(cal_fingerprint_similarity, ij_list)
221 | i, j, _ = zip(*ij_list)
222 | sm[i,j] = similarity
223 | sm = sm + sm.T + np.eye(n)
224 |
225 | af = AffinityPropagation().fit(sm)
226 | cluster_centers_indices = af.cluster_centers_indices_
227 | return filter_idxs[cluster_centers_indices]
228 |
229 | def cal_fingerprint_similarity(i, j, smiles):
230 | m1, m2 = Chem.MolFromSmiles(smiles[i]), Chem.MolFromSmiles(smiles[j])
231 | return FingerprintSimilarity(Chem.RDKFingerprint(m1), Chem.RDKFingerprint(m2))
232 |
--------------------------------------------------------------------------------
/data_loader/celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import gdown
4 | import pickle
5 | import zipfile
6 | import numpy as np
7 | import pandas as pd
8 | from PIL import Image
9 | from tqdm import tqdm
10 | from torchvision import transforms
11 | from torch.utils.data import Dataset, DataLoader
12 |
13 | image_size = 64
14 | img_transform = transforms.Compose([
15 | transforms.Resize(image_size),
16 | transforms.CenterCrop(image_size),
17 | transforms.ToTensor(),
18 | transforms.Normalize(mean=[0.5, 0.5, 0.5],
19 | std=[0.5, 0.5, 0.5])
20 | ])
21 | img_root_path = '/root/dataset/celeba/img_align_celeba/'
22 |
23 | class Data:
24 | def __init__(self, params):
25 | self.params = params
26 |
27 | def gen_datasets(self):
28 | raise NotImplementedError
29 |
30 | def get_loaders(self, batch_size, num_workers, shuffle_train=False,
31 | get_test=True):
32 | raise NotImplementedError
33 |
34 | class CelebA(Data):
35 | def __init__(self, params):
36 | super().__init__(params)
37 | data_root = self.download_celeba()
38 |
39 | torch.manual_seed(1) # fix dataset
40 | np.random.seed(1)
41 | self.gen_datasets(data_root)
42 |
43 | def download_celeba(self):
44 | url_img = 'https://drive.google.com/uc?id=1iBJh4vHuE9h-eMOqVis94QccxCT_LPFW'
45 | url_anno = 'https://drive.google.com/uc?id=1p0-TEiW4HgT8MblB399ep4YM3u5A0Edc'
46 | data_root = '/root/dataset/celeba'
47 | download_path_img = f'{data_root}/img_align_celeba.zip'
48 | download_path_anno = f'{data_root}/list_attr_celeba.txt'
49 | if not os.path.exists(data_root):
50 | os.makedirs(data_root)
51 |
52 | if not os.listdir(data_root):
53 | gdown.download(url_anno, download_path_anno, quiet=False)
54 | gdown.download(url_img, download_path_img, quiet=False)
55 | with zipfile.ZipFile(download_path_img, 'r') as ziphandler:
56 | ziphandler.extractall(data_root)
57 | return data_root
58 |
59 | def load_data(self, data_root):
60 | data_path = data_root + '/list_attr_celeba.txt'
61 | df = pd.read_csv(data_path, sep="\s+", skiprows=1)
62 | label_names = list(df.columns)[:-1]
63 | df = df.to_numpy()[:, :-1]
64 | df = np.maximum(df, 0) # -1 -> 0
65 | return df, label_names
66 |
67 | def gen_datasets(self, data_root):
68 | data_path = data_root + '/celebA_set_data.pkl'
69 | if os.path.exists(data_path):
70 | print(f'load data from {data_path}')
71 | label_names, trainData, valData, testData = pickle.load(open(data_path, "rb"))
72 | self.V_train, self.S_train, self.labels_train = trainData['V_train'], trainData['S_train'], trainData['labels_train']
73 | self.V_val, self.S_val, self.labels_val = valData['V_train'], valData['S_train'], valData['labels_train']
74 | self.V_test, self.S_test, self.labels_test = testData['V_train'], testData['S_train'], testData['labels_train']
75 | else:
76 | data, label_names = self.load_data(data_root)
77 | self.V_train, self.S_train, self.labels_train = get_set_celeba_dataset(data, data_size=10000, v_size=self.params.v_size)
78 | self.V_val, self.S_val, self.labels_val = get_set_celeba_dataset(data, data_size=1000, v_size=self.params.v_size)
79 | self.V_test, self.S_test, self.labels_test = get_set_celeba_dataset(data, data_size=1000, v_size=self.params.v_size)
80 |
81 | trainData = {'V_train': self.V_train, 'S_train': self.S_train, 'labels_train': self.labels_train}
82 | valData = {'V_train': self.V_val, 'S_train': self.S_val, 'labels_train': self.labels_val}
83 | testData = {'V_train': self.V_test, 'S_train': self.S_test, 'labels_train': self.labels_test}
84 | pickle.dump((label_names, trainData, valData, testData), open(data_path, "wb"))
85 |
86 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True):
87 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True)
88 | val_dataset = SetDataset(self.V_val, self.S_val, self.params)
89 | test_dataset = SetDataset(self.V_test, self.S_test, self.params)
90 |
91 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
92 | collate_fn=collate_train, shuffle=shuffle_train, num_workers=num_workers)
93 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
94 | collate_fn=collate_val_and_test, shuffle=False, num_workers=num_workers)
95 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
96 | collate_fn=collate_val_and_test, shuffle=False, num_workers=num_workers) if get_test else None
97 | return train_loader, val_loader, test_loader
98 |
99 | class SetDataset(Dataset):
100 | def __init__(self, U, S, params, is_train=False):
101 | self.data = U
102 | self.labels = S
103 | self.is_train = is_train
104 | self.neg_num = params.neg_num
105 | self.v_size = params.v_size
106 |
107 | def __getitem__(self, index):
108 | V_id = self.data[index]
109 | S = self.labels[index]
110 | V = torch.cat([ load_img(idx.item()) for idx in V_id ], dim=0)
111 |
112 | S_mask = torch.zeros([self.v_size])
113 | S_mask[S] = 1
114 | if self.is_train:
115 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0]
116 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]]
117 | neg_S_mask = torch.zeros([self.v_size])
118 | neg_S_mask[S] = 1
119 | neg_S_mask[neg_S] = 1
120 | return V, S_mask, neg_S_mask
121 |
122 | return V, S_mask
123 |
124 | def __len__(self):
125 | return len(self.data)
126 |
127 | def get_set_celeba_dataset(data, data_size, v_size):
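    # Build `data_size` set-anomaly instances: pick two attributes from a random
    # anchor image; the ground set U holds v_size images having both attributes,
    # of which s_size (2 or 3) are replaced by "anomalous" images having neither
    # attribute. S stores the positions of the anomalies in U, and label_list
    # stores the two chosen attribute indices.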
128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129 | data = torch.Tensor(data).to(device)
130 | img_nums = data.shape[0]
131 |
132 | V_list = []
133 | S_list = []
134 | label_list = []
135 | cur_size = 0
136 | pbar = tqdm(total=data_size)
137 | while True:
138 | if cur_size == data_size: break
139 |
140 | nor_id = np.random.randint(img_nums)
141 | nor_label = data[nor_id]
142 | if torch.sum(nor_label).item() < 2:
143 | continue
144 | nor_lable_idxs = torch.nonzero(nor_label).reshape(-1)
145 | perm = torch.randperm(nor_lable_idxs.size(0))
146 | nor_lable_idxs = nor_lable_idxs[perm[:2]]
147 | nor_label = torch.zeros(nor_label.shape).to(device)
148 | nor_label[nor_lable_idxs] = 1
149 | nor_label = nor_label.reshape(-1, 1)
150 |
151 | s_size = np.random.randint(2, 4)
152 | nor_res = (torch.nonzero((data @ nor_label).squeeze(-1) == 2)).reshape(-1)
153 | ano_res = (torch.nonzero((data @ nor_label).squeeze(-1) == 0)).reshape(-1)
154 | if (nor_res.shape[0] < v_size) or (ano_res.shape[0] < s_size):
155 | continue
156 |
157 | perm = torch.randperm(nor_res.size(0))
158 | U = nor_res[perm[:v_size]].cpu()
159 | perm = torch.randperm(ano_res.size(0))
160 | S = ano_res[perm[:s_size]].cpu()
161 |
162 | S_idx = np.random.choice(list(range(v_size)), s_size, replace=False)
163 | U[S_idx] = S
164 | S = torch.Tensor(S_idx).type(torch.int64)
165 | lable_idxs = nor_lable_idxs.cpu()
166 |
167 | V_list.append(U)
168 | S_list.append(S)
169 | label_list.append(lable_idxs)
170 |
171 | cur_size += 1
172 | pbar.update(1)
173 | pbar.close()
174 | return V_list, S_list, label_list
175 |
176 | def collate_train(data):
177 | V, S, neg_S = map(list, zip(*data))
178 | bs = len(V)
179 |
180 | V = torch.cat(V, dim=0)
181 | S = torch.cat(S, dim=0).reshape(bs, -1)
182 | neg_S = torch.cat(neg_S, dim=0).reshape(bs, -1)
183 | return V, S, neg_S
184 |
185 | def collate_val_and_test(data):
186 | V, S = map(list, zip(*data))
187 | bs = len(V)
188 |
189 | V = torch.cat(V, dim=0)
190 | S = torch.cat(S, dim=0).reshape(bs, -1)
191 | return V, S
192 |
193 | def load_img(img_id):
194 | img_id = str(img_id + 1) # 0 -> 1
195 | img_path = img_root_path + ( '0' * (6 - len(img_id)) ) + img_id + '.jpg'
196 | img = Image.open(img_path).convert('RGB')
197 | img = img_transform(img).unsqueeze(0)
198 | return img
199 |
--------------------------------------------------------------------------------
/data_loader/gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn import datasets
4 | from torch.utils.data import Dataset, DataLoader
5 |
6 | class Data:
7 | def __init__(self, params):
8 | self.params = params
9 | self.gen_datasets()
10 |
11 | def gen_datasets(self):
12 | raise NotImplementedError
13 |
14 | def get_loaders(self, batch_size, num_workers, shuffle_train=False,
15 | get_test=True):
16 | raise NotImplementedError
17 |
18 |
19 | class GaussianMixture(Data):
20 | def __init__(self, params):
21 | super().__init__(params)
22 |
23 | def gen_datasets(self):
24 | np.random.seed(1)
25 | V_size, S_size = self.params.v_size, self.params.s_size
26 |
27 | self.V_train, self.S_train = get_gaussian_mixture_dataset(V_size, S_size)
28 | self.V_val, self.S_val = get_gaussian_mixture_dataset(V_size, S_size)
29 | self.V_test, self.S_test = get_gaussian_mixture_dataset(V_size, S_size)
30 |
31 | self.fea_size = 2
32 | self.x_lim, self.y_lim = 2, 2
33 |
34 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True):
35 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True)
36 | val_dataset = SetDataset(self.V_val, self.S_val, self.params)
37 | test_dataset = SetDataset(self.V_test, self.S_test, self.params)
38 |
39 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
40 | shuffle=shuffle_train, num_workers=num_workers)
41 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
42 | shuffle=False, num_workers=num_workers)
43 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
44 | shuffle=False, num_workers=num_workers) if get_test else None
45 | return train_loader, val_loader, test_loader
46 |
47 | class SetDataset(Dataset):
48 | def __init__(self, V, S, params, is_train=False):
49 | self.data = V
50 | self.labels = S
51 | self.is_train = is_train
52 | self.neg_num = params.neg_num
53 | self.v_size = params.v_size
54 |
55 | def __getitem__(self, index):
56 | V = self.data[index]
57 | S = self.labels[index]
58 |
59 | S_mask = torch.zeros([self.v_size])
60 | S_mask[S] = 1
61 | if self.is_train:
62 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0]
63 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]]
64 | neg_S_mask = torch.zeros([self.v_size])
65 | neg_S_mask[S] = 1
66 | neg_S_mask[neg_S] = 1
67 | return V, S_mask, neg_S_mask
68 |
69 | return V, S_mask
70 |
71 | def __len__(self):
72 | return len(self.data)
73 |
74 | def gen_gaussian_mixture(batch_size):
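    # Draw batch_size points from a two-component Gaussian mixture. Points from
    # one randomly chosen component are returned separately as `noise` (these
    # later become the labelled subset S), the rest as `data`.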
75 | scale = 1
76 | centers = [ (1. / np.sqrt(2), 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. / np.sqrt(2))]
77 | centers = [(scale * x, scale * y) for x, y in centers]
78 |
79 | data, labels = [], []
80 | for i in range(batch_size):
81 | point = np.random.randn(2) * 0.5
82 | idx = np.random.randint(2)
83 | center = centers[idx]
84 | point[0] += center[0]
85 | point[1] += center[1]
86 | data.append(point)
87 | labels.append(idx)
88 | data = np.array(data, dtype="float32")
89 | labels = np.array(labels, dtype="int32")
90 |
91 | noise_label = np.random.randint(2)
92 | noise = data[labels == noise_label]
93 | data = data[labels == (1 - noise_label)]
94 | # dataset /= 1.414
95 | return data, noise
96 |
97 | def get_gaussian_mixture_dataset(V_size, S_size):
98 | V_list, S_list = [], []
99 | for _ in range(1000):
100 | data, noise = gen_gaussian_mixture(V_size*4)
101 |
102 | V = data[:V_size]
103 | S = np.random.choice(list(range(0,V_size)), S_size, replace=False)
104 | V[S, :] = noise[:S_size]
105 |
106 | V_list.append(V)
107 | S_list.append(S)
108 |
109 | V = torch.FloatTensor(V_list)
110 | S = torch.LongTensor(S_list)
111 |
112 | return V, S
113 |
--------------------------------------------------------------------------------
/data_loader/moons.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn import datasets
4 | from torch.utils.data import Dataset, DataLoader
5 |
6 | class Data:
7 | def __init__(self, params):
8 | self.params = params
9 | self.gen_datasets()
10 |
11 | def gen_datasets(self):
12 | raise NotImplementedError
13 |
14 | def get_loaders(self, batch_size, num_workers, shuffle_train=False,
15 | get_test=True):
16 | raise NotImplementedError
17 |
18 | class TwoMoons(Data):
19 | def __init__(self, params):
20 | super().__init__(params)
21 |
22 | def gen_datasets(self):
23 | np.random.seed(1) # fix dataset
24 | V_size, S_size = self.params.v_size, self.params.s_size
25 |
26 | self.V_train, self.S_train = get_two_moons_dataset(V_size, S_size, rand_seed=0)
27 | self.V_val, self.S_val = get_two_moons_dataset(V_size, S_size, rand_seed=1)
28 | self.V_test, self.S_test = get_two_moons_dataset(V_size, S_size, rand_seed=2)
29 |
30 | self.fea_size = 2
31 | self.x_lim, self.y_lim = 4, 2
32 |
33 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True):
34 | train_dataset = SetDataset(self.V_train, self.S_train, self.params, is_train=True)
35 | val_dataset = SetDataset(self.V_val, self.S_val, self.params)
36 | test_dataset = SetDataset(self.V_test, self.S_test, self.params)
37 |
38 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
39 | shuffle=shuffle_train, num_workers=num_workers)
40 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
41 | shuffle=False, num_workers=num_workers)
42 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
43 | shuffle=False, num_workers=num_workers) if get_test else None
44 | return train_loader, val_loader, test_loader
45 |
46 | class SetDataset(Dataset):
47 | def __init__(self, V, S, params, is_train=False):
48 | self.data = V
49 | self.labels = S
50 | self.is_train = is_train
51 | self.neg_num = params.neg_num
52 | self.v_size = params.v_size
53 |
54 | def __getitem__(self, index):
55 | V = self.data[index]
56 | S = self.labels[index]
57 |
58 | S_mask = torch.zeros([self.v_size])
59 | S_mask[S] = 1
60 | if self.is_train:
61 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0]
62 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]]
63 | neg_S_mask = torch.zeros([self.v_size])
64 | neg_S_mask[S] = 1
65 | neg_S_mask[neg_S] = 1
66 | return V, S_mask, neg_S_mask
67 |
68 | return V, S_mask
69 |
70 | def __len__(self):
71 | return len(self.data)
72 |
73 | def gen_moons(batch_size, rand_seed):
74 | data, Y = datasets.make_moons(n_samples=batch_size, noise=0.1, random_state=rand_seed)
75 | data = data.astype("float32")
76 | data = data * 2 + np.array([-1, -0.2])
77 |
78 | noise_label = np.random.randint(2)
79 | noise = data[Y == noise_label]
80 | data = data[Y == (1 - noise_label)]
81 | return data, noise
82 |
83 | def get_two_moons_dataset(V_size, S_size, rand_seed):
84 | V_list, S_list = [], []
85 | for idx in range(1000):
86 | data, noise = gen_moons(V_size*2, rand_seed*1000+idx)
87 |
88 | V = data[:V_size]
89 | S = np.random.choice(list(range(0,V_size)), S_size, replace=False)
90 | V[S, :] = noise[:S_size]
91 |
92 | V_list.append(V)
93 | S_list.append(S)
94 |
95 | V = torch.FloatTensor(np.array(V_list))
96 | S = torch.LongTensor(np.array(S_list))
97 |
98 | return V, S
99 |
--------------------------------------------------------------------------------
/data_loader/pdbbind.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 | # SPDX-License-Identifier: Apache-2.0
5 | #
6 | # PDBBind dataset processed by moleculenet.
7 |
8 | import dgl.backend as F
9 | import numpy as np
10 | import multiprocessing
11 | import os
12 | import glob
13 | from functools import partial
14 | import pandas as pd
15 |
16 | from dgl.data.utils import get_download_dir, download, _get_dgl_url, extract_archive
17 |
18 | from dgllife.utils import multiprocess_load_molecules, ACNN_graph_construction_and_featurization
19 |
20 | __all__ = ['PDBBind']
21 |
22 | class PDBBind(object):
23 | """PDBbind dataset processed by moleculenet.
24 |
25 | The description below is mainly based on
26 | [1].
27 | The PDBBind database consists of experimentally measured binding affinities for
28 | bio-molecular complexes [2],
29 | [3]. It provides detailed
30 | 3D Cartesian coordinates of both ligands and their target proteins derived from experimental
31 | (e.g., X-ray crystallography) measurements. The availability of coordinates of the
32 | protein-ligand complexes permits structure-based featurization that is aware of the
33 | protein-ligand binding geometry. The authors of
34 | [1] use the
35 | "refined" and "core" subsets of the database
36 | [4], more carefully
37 | processed for data artifacts, as additional benchmarking targets.
38 |
39 | References:
40 |
41 | * [1] moleculenet: a benchmark for molecular machine learning
42 | * [2] The PDBbind database: collection of binding affinities for protein-ligand complexes
43 | with known three-dimensional structures
44 | * [3] The PDBbind database: methodologies and updates
45 | * [4] PDB-wide collection of binding data: current status of the PDBbind database
46 |
47 | Parameters
48 | ----------
49 | subset : str
50 | In moleculenet, we can use either the "refined" subset or the "core" subset. We can
51 | retrieve them by setting ``subset`` to be ``'refined'`` or ``'core'``. The size
52 | of the ``'core'`` set is 195 and the size of the ``'refined'`` set is 3706.
53 | pdb_version : str
54 | The version of PDBBind dataset. Currently implemented: ``'v2007'``, ``'v2015'``.
55 | Default to ``'v2015'``. User should not specify the version if using local PDBBind data.
56 | load_binding_pocket : bool
57 | Whether to load binding pockets or full proteins. Default to True.
58 | remove_coreset_from_refinedset: bool
59 | Whether to remove core set from refined set when training with refined set and test with core set.
60 | Default to True.
61 | sanitize : bool
62 | Whether sanitization is performed in initializing RDKit molecule instances. See
63 | https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
64 | Default to False.
65 | calc_charges : bool
66 | Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
67 | ``sanitize`` to be True. Default to False.
68 | remove_hs : bool
69 | Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
70 | slow for large molecules. Default to False.
71 | use_conformation : bool
72 | Whether we need to extract molecular conformation from proteins and ligands.
73 | Default to True.
74 | construct_graph_and_featurize : callable
75 | Construct a DGLGraph for the use of GNNs. Mapping ``self.ligand_mols[i]``,
76 | ``self.protein_mols[i]``, ``self.ligand_coordinates[i]`` and
77 | ``self.protein_coordinates[i]`` to a DGLGraph.
78 | Default to :func:`dgllife.utils.ACNN_graph_construction_and_featurization`.
79 | zero_padding : bool
80 | Whether to perform zero padding. While DGL does not necessarily require zero padding,
81 | pooling operations for variable length inputs can introduce stochastic behaviour, which
82 | is not desired for sensitive scenarios. Default to True.
83 | num_processes : int or None
84 | Number of worker processes to use. If None,
85 | then we will use the number of CPUs in the system. Default None.
86 | local_path : str or None
87 | Local path of existing PDBBind dataset.
88 | Default None, and PDBBind dataset will be downloaded from DGL database.
89 | Specify this argument to a local path of customized dataset, which should follow the structure and the naming format of PDBBind v2015.
90 | """
91 | def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False,
92 | calc_charges=False, remove_hs=False, use_conformation=True,
93 | construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
94 | zero_padding=True, num_processes=None, local_path=None):
95 | self.task_names = ['-logKd/Ki']
96 | self.n_tasks = len(self.task_names)
97 | self._read_data_files(pdb_version, subset, load_binding_pocket, remove_coreset_from_refinedset, local_path)
98 | self._preprocess(load_binding_pocket,
99 | sanitize, calc_charges, remove_hs, use_conformation,
100 | construct_graph_and_featurize, zero_padding, num_processes)
101 | # Prepare for Refined, Agglomerative Sequence Split and Agglomerative Structure Split
102 | if pdb_version == 'v2007' and not local_path:
103 | merged_df = self.df.merge(self.agg_split, on='PDB_code')
104 | self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index)
105 | for target_set in ['train', 'valid', 'test']]
106 | self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index)
107 | for target_set in ['train', 'valid', 'test']]
108 |
109 | def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_coreset_from_refinedset, local_path):
110 | """Download and extract pdbbind data files specified by the version"""
111 | # root_dir_path = get_download_dir()
112 | root_dir_path = '/root/dataset/.dgl'
113 | if local_path:
114 | if local_path[-1] != '/':
115 | local_path += '/'
116 | index_label_file = glob.glob(local_path + '*' + subset + '*data*')[0]
117 | elif pdb_version == 'v2015':
118 | self._url = 'dataset/pdbbind_v2015.tar.gz'
119 | data_path = root_dir_path + '/pdbbind_v2015.tar.gz'
120 | extracted_data_path = root_dir_path + '/pdbbind_v2015'
121 | download(_get_dgl_url(self._url), path=data_path, overwrite=False)
122 | extract_archive(data_path, extracted_data_path)
123 |
124 | if subset == 'core':
125 | index_label_file = extracted_data_path + '/v2015/INDEX_core_data.2013'
126 | elif subset == 'refined':
127 | index_label_file = extracted_data_path + '/v2015/INDEX_refined_data.2015'
128 | else:
129 | raise ValueError('Expect the subset_choice to be either core or refined, got {}'.format(subset))
130 | elif pdb_version == 'v2007':
131 | self._url = 'dataset/pdbbind_v2007.tar.gz'
132 | data_path = root_dir_path + '/pdbbind_v2007.tar.gz'
133 | extracted_data_path = root_dir_path + '/pdbbind_v2007'
134 | download(_get_dgl_url(self._url), path=data_path, overwrite=False)
135 | extract_archive(data_path, extracted_data_path, overwrite=False)
136 | extracted_data_path += '/home/ubuntu' # extra layer
137 |
138 | # DataFrame containing the pdbbind_2007_agglomerative_split.txt
139 | self.agg_split = pd.read_csv(extracted_data_path + '/v2007/pdbbind_2007_agglomerative_split.txt')
140 | self.agg_split.rename(columns={'PDB ID':'PDB_code', 'Sequence-based assignment':'sequence', 'Structure-based assignment':'structure'}, inplace=True)
141 | self.agg_split.loc[self.agg_split['PDB_code']=='1.00E+66', 'PDB_code'] = '1e66' # fix typo
142 | if subset == 'core':
143 | index_label_file = extracted_data_path + '/v2007/INDEX.2007.core.data'
144 | elif subset == 'refined':
145 | index_label_file = extracted_data_path + '/v2007/INDEX.2007.refined.data'
146 | else:
147 | raise ValueError('Expect the subset_choice to be either core or refined, got {}'.format(subset))
148 |
149 | contents = []
150 | with open(index_label_file, 'r') as f:
151 | for line in f.readlines():
152 | if line[0] != "#":
153 | splitted_elements = line.split()
154 | if pdb_version == 'v2015':
155 | if len(splitted_elements) == 8:
156 | # Ignore "//"
157 | contents.append(splitted_elements[:5] + splitted_elements[6:])
158 | else:
159 | print('Incorrect data format.')
160 | print(splitted_elements)
161 | elif pdb_version == 'v2007':
162 | if len(splitted_elements) == 6:
163 | contents.append(splitted_elements)
164 | else:
165 | contents.append(splitted_elements[:5] + [' '.join(splitted_elements[5:])])
166 |
167 | if pdb_version == 'v2015':
168 | self.df = pd.DataFrame(contents, columns=(
169 | 'PDB_code', 'resolution', 'release_year',
170 | '-logKd/Ki', 'Kd/Ki', 'reference', 'ligand_name'))
171 | elif pdb_version == 'v2007':
172 | self.df = pd.DataFrame(contents, columns=(
173 | 'PDB_code', 'resolution', 'release_year',
174 | '-logKd/Ki', 'Kd/Ki', 'cluster_ID'))
175 |
176 | pdbs = self.df['PDB_code'].tolist()
177 |
178 | # remove core set from refined set if using refined
179 | if remove_coreset_from_refinedset and subset == 'refined':
180 | if local_path:
181 | core_path = glob.glob(local_path + '*core*data*')[0]
182 | elif pdb_version == 'v2015':
183 | core_path = extracted_data_path + '/v2015/INDEX_core_data.2013'
184 | elif pdb_version == 'v2007':
185 | core_path = extracted_data_path + '/v2007/INDEX.2007.core.data'
186 |
187 | with open(core_path,'r') as f:
188 | for line in f:
189 | fields = line.strip().split()
190 | if fields[0] != "#" and fields[0] in pdbs:
191 | pdbs.remove(fields[0])
192 |
193 | if local_path:
194 | pdb_path = local_path
195 | else:
196 | pdb_path = os.path.join(extracted_data_path, pdb_version)
197 | print('Loading PDBBind data from', pdb_path)
198 | self.ligand_files = [os.path.join(pdb_path, pdb, '{}_ligand.sdf'.format(pdb)) for pdb in pdbs]
199 | if load_binding_pocket:
200 | self.protein_files = [os.path.join(pdb_path, pdb, '{}_pocket.pdb'.format(pdb)) for pdb in pdbs]
201 | else:
202 | self.protein_files = [os.path.join(pdb_path, pdb, '{}_protein.pdb'.format(pdb)) for pdb in pdbs]
203 |
204 | def _filter_out_invalid(self, ligands_loaded, proteins_loaded, use_conformation):
205 | """Filter out invalid ligand-protein pairs.
206 |
207 | Parameters
208 | ----------
209 | ligands_loaded : list
210 | Each element is a 2-tuple of the RDKit molecule instance and its associated atom
211 | coordinates. None is used to represent invalid/non-existing molecule or coordinates.
212 | proteins_loaded : list
213 | Each element is a 2-tuple of the RDKit molecule instance and its associated atom
214 | coordinates. None is used to represent invalid/non-existing molecule or coordinates.
215 | use_conformation : bool
216 | Whether we need conformation information (atom coordinates) and filter out molecules
217 | without valid conformation.
218 | """
219 | num_pairs = len(proteins_loaded)
220 | self.indices, self.ligand_mols, self.protein_mols = [], [], []
221 | if use_conformation:
222 | self.ligand_coordinates, self.protein_coordinates = [], []
223 | else:
224 | # Use None for placeholders.
225 | self.ligand_coordinates = [None for _ in range(num_pairs)]
226 | self.protein_coordinates = [None for _ in range(num_pairs)]
227 |
228 | for i in range(num_pairs):
229 | ligand_mol, ligand_coordinates = ligands_loaded[i]
230 | protein_mol, protein_coordinates = proteins_loaded[i]
231 | if (not use_conformation) and all(v is not None for v in [protein_mol, ligand_mol]):
232 | self.indices.append(i)
233 | self.ligand_mols.append(ligand_mol)
234 | self.protein_mols.append(protein_mol)
235 | elif all(v is not None for v in [
236 | protein_mol, protein_coordinates, ligand_mol, ligand_coordinates]):
237 | self.indices.append(i)
238 | self.ligand_mols.append(ligand_mol)
239 | self.ligand_coordinates.append(ligand_coordinates)
240 | self.protein_mols.append(protein_mol)
241 | self.protein_coordinates.append(protein_coordinates)
242 |
243 | def _preprocess(self, load_binding_pocket,
244 | sanitize, calc_charges, remove_hs, use_conformation,
245 | construct_graph_and_featurize, zero_padding, num_processes):
246 | """Preprocess the dataset.
247 |
248 | The pre-processing proceeds as follows:
249 |
250 | 1. Load the dataset
251 | 2. Clean the dataset and filter out invalid pairs
252 | 3. Construct graphs
253 | 4. Prepare node and edge features
254 |
255 | Parameters
256 | ----------
257 | load_binding_pocket : bool
258 | Whether to load binding pockets or full proteins.
259 | sanitize : bool
260 | Whether sanitization is performed in initializing RDKit molecule instances. See
261 | https://www.rdkit.org/docs/RDKit_Book.html for details of the sanitization.
262 | calc_charges : bool
263 | Whether to add Gasteiger charges via RDKit. Setting this to be True will enforce
264 | ``sanitize`` to be True.
265 | remove_hs : bool
266 | Whether to remove hydrogens via RDKit. Note that removing hydrogens can be quite
267 | slow for large molecules.
268 | use_conformation : bool
269 | Whether we need to extract molecular conformation from proteins and ligands.
270 | construct_graph_and_featurize : callable
271 | Construct a DGLHeteroGraph for the use of GNNs. Mapping self.ligand_mols[i],
272 | self.protein_mols[i], self.ligand_coordinates[i] and self.protein_coordinates[i]
273 | to a DGLHeteroGraph. Default to :func:`ACNN_graph_construction_and_featurization`.
274 | zero_padding : bool
275 | Whether to perform zero padding. While DGL does not necessarily require zero padding,
276 | pooling operations for variable length inputs can introduce stochastic behaviour, which
277 | is not desired for sensitive scenarios.
278 | num_processes : int or None
279 | Number of worker processes to use. If None,
280 | then we will use the number of CPUs in the system.
281 | """
282 | if num_processes is None:
283 | num_processes = multiprocessing.cpu_count() - 1
284 | num_processes = min(num_processes, len(self.df))
285 |
286 | print('Loading ligands...')
287 | ligands_loaded = multiprocess_load_molecules(self.ligand_files,
288 | sanitize=sanitize,
289 | calc_charges=calc_charges,
290 | remove_hs=remove_hs,
291 | use_conformation=use_conformation,
292 | num_processes=num_processes)
293 |
294 | print('Loading proteins...')
295 | proteins_loaded = multiprocess_load_molecules(self.protein_files,
296 | sanitize=sanitize,
297 | calc_charges=calc_charges,
298 | remove_hs=remove_hs,
299 | use_conformation=use_conformation,
300 | num_processes=num_processes)
301 |
302 | self._filter_out_invalid(ligands_loaded, proteins_loaded, use_conformation)
303 | self.df = self.df.iloc[self.indices]
304 | self.labels = F.zerocopy_from_numpy(self.df[self.task_names].values.astype(np.float32))
305 | print('Finished cleaning the dataset, '
306 | 'got {:d}/{:d} valid pairs'.format(len(self), len(self.ligand_files))) # account for the ones use_conformation failed
307 |
308 | # Prepare zero padding
309 | if zero_padding:
310 | max_num_ligand_atoms = 0
311 | max_num_protein_atoms = 0
312 | for i in range(len(self)):
313 | max_num_ligand_atoms = max(
314 | max_num_ligand_atoms, self.ligand_mols[i].GetNumAtoms())
315 | max_num_protein_atoms = max(
316 | max_num_protein_atoms, self.protein_mols[i].GetNumAtoms())
317 | else:
318 | max_num_ligand_atoms = None
319 | max_num_protein_atoms = None
320 |
321 | construct_graph_and_featurize = partial(construct_graph_and_featurize,
322 | max_num_ligand_atoms=max_num_ligand_atoms,
323 | max_num_protein_atoms=max_num_protein_atoms)
324 |
325 | print('Start constructing graphs and featurizing them.')
326 | num_mols = len(self)
327 |
328 | # construct graphs with multiprocessing
329 | # from signal import signal, SIGPIPE, SIG_DFL
330 | # signal(SIGPIPE,SIG_DFL)
331 | pool = multiprocessing.Pool(processes=num_processes)
332 | self.graphs = pool.starmap(construct_graph_and_featurize,
333 | zip(self.ligand_mols, self.protein_mols,
334 | self.ligand_coordinates, self.protein_coordinates))
335 | print(f'Done constructing {len(self.graphs)} graphs.')
336 |
337 |
338 | def __len__(self):
339 | """Get the size of the dataset.
340 |
341 | Returns
342 | -------
343 | int
344 | Number of valid ligand-protein pairs in the dataset.
345 | """
346 | return len(self.indices)
347 |
348 | def __getitem__(self, item):
349 | """Get the datapoint associated with the index.
350 |
351 | Parameters
352 | ----------
353 | item : int
354 | Index for the datapoint.
355 |
356 | Returns
357 | -------
358 | int
359 | Index for the datapoint.
360 | rdkit.Chem.rdchem.Mol
361 | RDKit molecule instance for the ligand molecule.
362 | rdkit.Chem.rdchem.Mol
363 | RDKit molecule instance for the protein molecule.
364 | DGLGraph or tuple of DGLGraphs
365 | Pre-processed DGLGraph with features extracted.
366 | For ACNN, a single DGLGraph;
367 | For PotentialNet, a tuple of DGLGraphs that consists of a molecular graph and a KNN graph of the complex.
368 | Float32 tensor
369 | Label for the datapoint.
370 | """
371 | return item, self.ligand_mols[item], self.protein_mols[item], \
372 | self.graphs[item], self.labels[item]
373 |
--------------------------------------------------------------------------------
/data_loader/set_pdbbind.py:
--------------------------------------------------------------------------------
1 | import os
2 | import dgl
3 | import torch
4 | import pickle
5 | import numpy as np
6 | from tqdm import tqdm
7 | from rdkit import Chem
8 | from sklearn.cluster import AffinityPropagation
9 | from rdkit.DataStructs import FingerprintSimilarity
10 | from torch.utils.data import Dataset, DataLoader
11 |
12 | from .pdbbind import PDBBind
13 |
14 | class SetPDBBind(object):
15 | def __init__(self, params):
16 | super().__init__()
17 | self.params = params
18 | self.gen_datasets()
19 |
20 | def gen_datasets(self):
21 | np.random.seed(1) # fix dataset
22 | V_size, S_size = self.params.v_size, self.params.s_size
23 | self.dataset = load_pdbbind(self.params)
24 |
25 | data_root = '/root/dataset/pdbbind'
26 | data_path = os.path.join(data_root, 'pdbbind_set_data.pkl')
27 | if os.path.exists(data_path):
28 | print(f'load data from {data_path}')
29 | trainData, valData, testData = pickle.load(open(data_path, "rb"))
30 | self.V_train, self.S_train = trainData['V_train'], trainData['S_train']
31 | self.V_val, self.S_val = valData['V_train'], valData['S_train']
32 | self.V_test, self.S_test = testData['V_train'], testData['S_train']
33 | else:
34 | self.V_train, self.S_train = get_set_pdbbind_dataset_activate(self.dataset, V_size, S_size, size=1000)
35 | self.V_val, self.S_val = get_set_pdbbind_dataset_activate(self.dataset, V_size, S_size, size=100)
36 | self.V_test, self.S_test = get_set_pdbbind_dataset_activate(self.dataset, V_size, S_size, size=100)
37 |
38 | trainData = {'V_train': self.V_train, 'S_train': self.S_train}
39 | valData = {'V_train': self.V_val, 'S_train': self.S_val}
40 | testData = {'V_train': self.V_test, 'S_train': self.S_test}
41 | if not os.path.exists(data_root):
42 | os.makedirs(data_root)
43 | pickle.dump((trainData, valData, testData), open(data_path, "wb"))
44 |
45 | def get_loaders(self, batch_size, num_workers, shuffle_train=False, get_test=True):
46 | train_dataset = SetDataset(self.dataset, self.V_train, self.S_train, self.params, is_train=True)
47 | val_dataset = SetDataset(self.dataset, self.V_val, self.S_val, self.params)
48 | test_dataset = SetDataset(self.dataset, self.V_test, self.S_test, self.params)
49 |
50 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
51 | collate_fn=collate_train, pin_memory=True, shuffle=shuffle_train, num_workers=num_workers)
52 | val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
53 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers)
54 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
55 | collate_fn=collate_val_and_test, pin_memory=True, shuffle=False, num_workers=num_workers) if get_test else None
56 |
57 | return train_loader, val_loader, test_loader
58 |
59 | def collate_train(data):
60 | U, S, neg_S = map(list, zip(*data))
61 | bs, vs = len(U), U[0].shape[0]
62 |
63 | bg = dgl.batch([ gs[idx] for gs in U for idx in range(vs)])
64 | for nty in bg.ntypes:
65 | bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty)
66 | for ety in bg.canonical_etypes:
67 | bg.set_e_initializer(dgl.init.zero_initializer, etype=ety)
68 |
69 | S = torch.cat(S, dim=0).reshape(bs, -1)
70 | neg_S = torch.cat(neg_S, dim=0).reshape(bs, -1)
71 | return bg, S, neg_S
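# Shape sketch (illustrative only, assuming v_size = 3): collate_train flattens the vs
# per-sample graphs of each of the bs samples into a single dgl.batch of bs * vs graphs,
# and reshapes S / neg_S into (bs, v_size) 0/1 masks, e.g.
#   data = [dataset[i] for i in range(4)]      # items from SetDataset with is_train=True
#   bg, S, neg_S = collate_train(data)
#   # bg batches 4 * 3 graphs; S.shape == neg_S.shape == (4, 3)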
72 |
73 | def collate_val_and_test(data):
74 | U, S = map(list, zip(*data))
75 | bs, vs = len(U), U[0].shape[0]
76 |
77 | bg = dgl.batch([ gs[idx] for gs in U for idx in range(vs)])
78 | for nty in bg.ntypes:
79 | bg.set_n_initializer(dgl.init.zero_initializer, ntype=nty)
80 | for ety in bg.canonical_etypes:
81 | bg.set_e_initializer(dgl.init.zero_initializer, etype=ety)
82 |
83 | S = torch.cat(S, dim=0).reshape(bs, -1)
84 | return bg, S
85 |
86 | class SetDataset(Dataset):
87 | def __init__(self, dataset, V_idxs, S_idxs, params, is_train=False):
88 | _, self.graphs, _ = dataset
89 | self.V_idxs, self.S_idxs = V_idxs, S_idxs
90 | self.is_train = is_train
91 | self.neg_num = params.neg_num
92 | self.v_size = params.v_size
93 |
94 | def __getitem__(self, index):
95 | V_idxs, S = np.array(self.V_idxs[index]), np.array(self.S_idxs[index])
96 | V_graphs = np.array([self.graphs[item] for item in V_idxs])
97 |
98 | S_mask = torch.zeros([self.v_size])
99 | S_mask[S] = 1
100 | if self.is_train:
101 | idxs = (S_mask == 0).nonzero(as_tuple=True)[0]
102 | neg_S = idxs[torch.randperm(idxs.shape[0])[:S.shape[0] * self.neg_num]]
103 | neg_S_mask = torch.zeros([self.v_size])
104 | neg_S_mask[S] = 1
105 | neg_S_mask[neg_S] = 1
106 | return V_graphs, S_mask, neg_S_mask
107 |
108 | return V_graphs, S_mask
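    # In training mode, __getitem__ additionally returns neg_S_mask, which marks the optimal
    # subset S together with `neg_num * |S|` randomly drawn non-members; SetFuction.cross_entropy
    # multiplies the per-element loss by this mask, so only S and the sampled negatives
    # contribute to the training signal.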
109 |
110 | def __len__(self):
111 | return len(self.V_idxs)
112 |
113 | def load_pdbbind(params):
114 | params.subset = 'core'
115 | dataset = PDBBind(subset=params.subset)
116 |
117 | # decompose dataset
118 | _, ligand_mols, protein_mols, graphs, labels = map(list, zip(*dataset))
119 | ligand_mols = np.array(ligand_mols)
120 | graphs = np.array(graphs)
121 | labels = torch.stack(labels, dim=0)
122 | return (ligand_mols, graphs, labels)
123 |
124 | def get_set_pdbbind_dataset_activate(dataset, V_size, S_size, size=1000):
125 | """
126 | Generate dataset for compound selection with only the bioactivity filter
127 | """
128 | mols, _, labels = dataset
129 | data_size = len(mols)
130 |
131 | V_list, S_list = [], []
132 | for _ in tqdm(range(size)):
133 | V_idxs = np.random.permutation(data_size)[:V_size]
134 | sub_labels = labels[V_idxs].squeeze(dim=-1)
135 | _, idxs = torch.topk(sub_labels, S_size)
136 |
137 | V_list.append(V_idxs)
138 | S_list.append(idxs)
139 | return np.array(V_list), np.array(S_list)
140 |
141 | def get_set_pdbbind_dataset(dataset, V_size, S_size, size=1000):
142 | """
143 | Generate dataset for compound selection with bioactivity and diversity filters
144 | """
145 | mols, _, labels = dataset
146 | data_size = len(mols)
147 |
148 | V_list, S_list = [], []
149 | pbar = tqdm(total=size)
150 | num = 0
151 | while True:
152 | if num == size: break
153 |
154 | V_idxs = np.random.permutation(data_size)[:V_size]
155 | sub_labels = labels[V_idxs].squeeze(dim=-1)
156 | _, idxs = torch.topk(sub_labels, V_size // 3)
157 | filter_idxs = V_idxs[idxs]
158 | S_idxs = get_os_oracle(mols, filter_idxs)
159 |
160 | if len(S_idxs) == 0: continue
161 |
162 | V_list.append(V_idxs)
163 | S_list.append([np.where(V_idxs == item)[0][0] for item in S_idxs])
164 |
165 | num += 1
166 | pbar.update(1)
167 | return np.array(V_list), np.array(S_list)
168 |
169 | def get_os_oracle(mols, filter_idxs):
170 | # reference: https://nyxflower.com/2020/07/17/molecule-clustering/
171 | mol_list = mols[filter_idxs]
172 | n = len(mol_list)
173 |
174 | sm = np.zeros((n, n))
175 | for i in range(n):
176 | for j in range(i, n):
177 | m1, m2 = mol_list[i], mol_list[j]
178 | sm[i, j] = FingerprintSimilarity(Chem.RDKFingerprint(m1), Chem.RDKFingerprint(m2))
179 | sm = sm + sm.T - np.eye(n)
180 | af = AffinityPropagation().fit(sm)
181 | cluster_centers_indices = af.cluster_centers_indices_
182 | return filter_idxs[cluster_centers_indices]
183 |
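# get_os_oracle builds an n x n Tanimoto-similarity matrix over RDKit fingerprints of the
# bioactivity-filtered molecules, clusters it with Affinity Propagation, and returns the
# cluster centres as the diverse optimal subset. A minimal sketch of the same idea on a few
# hypothetical SMILES strings (illustrative only):
#   mols = np.array([Chem.MolFromSmiles(s) for s in ('CCO', 'CCN', 'c1ccccc1')])
#   centre_idxs = get_os_oracle(mols, np.arange(len(mols)))   # indices of cluster centres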
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from model.EquiVSet import EquiVSet
2 | from utils.config import MOONS_CONFIG, GAUSSIAN_CONFIG, BINDINGDB_CONFIG, AMAZON_CONFIG, CELEBA_CONFIG, PDBBIND_CONFIG
3 |
4 | if __name__ == "__main__":
5 | argparser = EquiVSet.get_model_specific_argparser()
6 | hparams = argparser.parse_args()
7 |
8 | data_name = hparams.data_name
9 | if data_name == 'moons':
10 | hparams.__dict__.update(MOONS_CONFIG)
11 | elif data_name == 'gaussian':
12 | hparams.__dict__.update(GAUSSIAN_CONFIG)
13 | elif data_name == 'amazon':
14 | hparams.__dict__.update(AMAZON_CONFIG)
15 | elif data_name == 'celeba':
16 | hparams.__dict__.update(CELEBA_CONFIG)
17 | elif data_name == 'pdbbind':
18 | hparams.__dict__.update(PDBBIND_CONFIG)
19 | elif data_name == 'bindingdb':
20 | hparams.__dict__.update(BINDINGDB_CONFIG)
21 | else:
22 | raise ValueError('invalid dataset...')
23 |
24 | model = EquiVSet(hparams)
25 |
26 | if hparams.train:
27 | model.run_training_sessions()
28 | else:
29 | model.load()
30 | print('Loaded model with: %s' % model.flag_hparams())
31 |
32 | val_perf, test_perf = model.run_test()
33 | print('Val: {:8.2f}'.format(val_perf))
34 | print('Test: {:8.2f}'.format(test_perf))
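# Example invocation (hypothetical run name; the flags are defined in
# Base_Model.get_general_argparser and EquiVSet.get_model_specific_argparser):
#   python main.py equivset_moons --data_name moons --mode copula --train --cuda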
--------------------------------------------------------------------------------
/model/EquiVSet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from model.base_model import Base_Model
4 | from model.modules import SetFuction, RecNet
5 |
6 |
7 | class EquiVSet(Base_Model):
8 | def __init__(self, hparams):
9 | super().__init__(hparams=hparams)
10 |
11 | def define_parameters(self):
12 | self.set_func = SetFuction(params=self.hparams)
13 | self.rec_net = RecNet(params=self.hparams) if self.hparams.mode != 'diffMF' else None
14 |
15 | def configure_optimizers(self):
16 | optim_energy = torch.optim.Adam(self.set_func.parameters(), lr=self.hparams.lr,
17 | weight_decay=self.hparams.weight_decay)
18 | optim_var = torch.optim.Adam(self.rec_net.parameters(), lr=self.hparams.lr,
19 | weight_decay=self.hparams.weight_decay) if self.hparams.mode != 'diffMF' else None
20 | return optim_energy, optim_var
21 |
22 | def configure_gradient_clippers(self):
23 | return [(self.parameters(), self.hparams.clip)]
24 |
25 | def inference(self, V, bs):
26 | if self.hparams.mode == 'diffMF':
27 | bs, vs = V.shape[:2]
28 | q = .5 * torch.ones(bs, vs).to(V.device)
29 | else:
30 | # mode == 'ind' or 'copula'
31 | q = self.rec_net.get_vardist(V, bs)
32 |
33 | for i in range(self.hparams.RNN_steps):
34 | sample_matrix_1, sample_matrix_0 = self.set_func.MC_sampling(q, self.hparams.num_samples)
35 | q = self.set_func.mean_field_iteration(V, sample_matrix_1, sample_matrix_0)
36 |
37 | return q
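    # Inference sketch: q is initialised either uniformly at 0.5 (diffMF) or by the amortised
    # recognition network, then refined by RNN_steps mean-field fixed-point updates. At
    # evaluation time (utils/evaluation.py) the |S| largest entries of q are reported as the
    # predicted subset.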
38 |
39 | def get_hparams_grid(self):
40 | grid = Base_Model.get_general_hparams_grid()
41 | grid.update({
42 | 'RNN_steps': [1],
43 | 'num_samples': [1, 5, 10, 15, 20],
44 | 'rank': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
45 | 'tau': [0.01, 0.03, 0.05, 0.07, 0.09, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 5, 10],
46 | })
47 | return grid
48 |
49 | @staticmethod
50 | def get_model_specific_argparser():
51 | parser = Base_Model.get_general_argparser()
52 |
53 | parser.add_argument('--mode', type=str, default='copula',
54 | choices=['diffMF', 'ind', 'copula'],
55 | help='name of the variant model [%(default)s]')
56 |
57 | parser.add_argument('--RNN_steps', type=int, default=1,
58 | help='num of RNN steps [%(default)d]')
59 | parser.add_argument('--num_samples', type=int, default=5,
60 | help='num of Monte Carlo samples [%(default)d]')
61 | parser.add_argument('--rank', type=int, default=5,
62 | help='rank of the perturbation matrix [%(default)d]')
63 | parser.add_argument('--tau', type=float, default=0.1,
64 | help='temperature of the relaxed multivariate bernoulli [%(default)g]')
65 | parser.add_argument('--neg_num', type=int, default=1,
66 |                         help='num of negative items [%(default)d]')
67 |
68 | return parser
69 |
--------------------------------------------------------------------------------
/model/acnn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 | # SPDX-License-Identifier: Apache-2.0
5 | #
6 | # Atomic Convolutional Networks for Predicting Protein-Ligand Binding Affinity
7 | # pylint: disable=C0103, C0123, W0221, E1101, R1721
8 |
9 | import itertools
10 | import dgl
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from dgl.nn.pytorch import AtomicConv
16 |
17 | __all__ = ['ACNN']
18 |
19 | def truncated_normal_(tensor, mean=0., std=1.):
20 | """Fills the given tensor in-place with elements sampled from the truncated normal
21 | distribution parameterized by mean and std.
22 |
23 | The generated values follow a normal distribution with specified mean and
24 | standard deviation, except that values whose magnitude is more than 2 std
25 | from the mean are dropped.
26 |
27 | We credit to Ruotian Luo for this implementation:
28 | https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15.
29 |
30 | Parameters
31 | ----------
32 | tensor : Float32 tensor of arbitrary shape
33 | Tensor to be filled.
34 | mean : float
35 | Mean of the truncated normal distribution.
36 | std : float
37 | Standard deviation of the truncated normal distribution.
38 | """
39 | shape = tensor.shape
40 | tmp = tensor.new_empty(shape + (4,)).normal_()
41 | valid = (tmp < 2) & (tmp > -2)
42 | ind = valid.max(-1, keepdim=True)[1]
43 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
44 | tensor.data.mul_(std).add_(mean)
45 |
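# Usage sketch (illustrative): fill a weight tensor in-place with truncated-normal values,
# as is done for the linear layers in ACNNPredictor below.
#   w = torch.empty(16, 8)
#   truncated_normal_(w, mean=0., std=0.1)   # magnitudes stay within 2 std of the mean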
46 | class MLP(nn.Module):
47 | """MLP with linear output"""
48 |
49 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
50 | """MLP layers construction
51 |
52 |         Parameters
53 |         ----------
54 | num_layers: int
55 | The number of linear layers
56 | input_dim: int
57 | The dimensionality of input features
58 | hidden_dim: int
59 | The dimensionality of hidden units at ALL layers
60 | output_dim: int
61 | The number of classes for prediction
62 |
63 | """
64 | super(MLP, self).__init__()
65 | self.linear_or_not = True # default is linear model
66 | self.num_layers = num_layers
67 | self.output_dim = output_dim
68 |
69 | if num_layers < 1:
70 | raise ValueError("number of layers should be positive!")
71 | elif num_layers == 1:
72 | # Linear model
73 | self.linear = nn.Linear(input_dim, output_dim)
74 | else:
75 | # Multi-layer model
76 | self.linear_or_not = False
77 | self.linears = torch.nn.ModuleList()
78 | self.batch_norms = torch.nn.ModuleList()
79 |
80 | self.linears.append(nn.Linear(input_dim, hidden_dim))
81 | for layer in range(num_layers - 2):
82 | self.linears.append(nn.Linear(hidden_dim, hidden_dim))
83 | self.linears.append(nn.Linear(hidden_dim, output_dim))
84 |
85 | for layer in range(num_layers - 1):
86 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
87 |
88 | def forward(self, x):
89 | if self.linear_or_not:
90 | # If linear model
91 | return self.linear(x)
92 | else:
93 | # If MLP
94 | h = x
95 | for i in range(self.num_layers - 1):
96 | h = F.relu(self.batch_norms[i](self.linears[i](h)))
97 | return self.linears[-1](h)
98 |
99 | class ACNNPredictor(nn.Module):
100 | """Predictor for ACNN.
101 |
102 | Parameters
103 | ----------
104 | in_size : int
105 | Number of radial filters used.
106 | hidden_sizes : list of int
107 | Specifying the hidden sizes for all layers in the predictor.
108 | weight_init_stddevs : list of float
109 | Specifying the standard deviations to use for truncated normal
110 |         distributions in initializing weights for the predictor.
111 | dropouts : list of float
112 | Specifying the dropouts to use for all layers in the predictor.
113 | features_to_use : None or float tensor of shape (T)
114 | In the original paper, these are atomic numbers to consider, representing the types
115 | of atoms. T for the number of types of atomic numbers. Default to None.
116 | num_tasks : int
117 | Output size.
118 | """
119 | def __init__(self, in_size, hidden_sizes, weight_init_stddevs,
120 | dropouts, features_to_use, num_tasks):
121 | super(ACNNPredictor, self).__init__()
122 |
123 |         if features_to_use is not None:
124 | in_size *= len(features_to_use)
125 |
126 | modules = []
127 | for i, h in enumerate(hidden_sizes):
128 | linear_layer = nn.Linear(in_size, h)
129 | truncated_normal_(linear_layer.weight, std=weight_init_stddevs[i])
130 | modules.append(linear_layer)
131 | modules.append(nn.ReLU())
132 | modules.append(nn.Dropout(dropouts[i]))
133 | in_size = h
134 | linear_layer = nn.Linear(in_size, num_tasks)
135 | truncated_normal_(linear_layer.weight, std=weight_init_stddevs[-1])
136 | modules.append(linear_layer)
137 | self.project = nn.Sequential(*modules)
138 | self.fea_layer = MLP(num_layers=1, input_dim=1922, hidden_dim=2048, output_dim=256)
139 |
140 | def forward(self, batch_size, frag1_node_indices_in_complex, frag2_node_indices_in_complex,
141 | ligand_conv_out, protein_conv_out, complex_conv_out):
142 | """Perform the prediction.
143 |
144 | Parameters
145 | ----------
146 | batch_size : int
147 | Number of datapoints in a batch.
148 | frag1_node_indices_in_complex : Int64 tensor of shape (V1)
149 | Indices for atoms in the first fragment (protein) in the batched complex.
150 | frag2_node_indices_in_complex : list of int of length V2
151 | Indices for atoms in the second fragment (ligand) in the batched complex.
152 | ligand_conv_out : Float32 tensor of shape (V2, K * T)
153 | Updated ligand node representations. V2 for the number of atoms in the
154 | ligand, K for the number of radial filters, and T for the number of types
155 | of atomic numbers.
156 | protein_conv_out : Float32 tensor of shape (V1, K * T)
157 | Updated protein node representations. V1 for the number of
158 | atoms in the protein, K for the number of radial filters,
159 | and T for the number of types of atomic numbers.
160 | complex_conv_out : Float32 tensor of shape (V1 + V2, K * T)
161 | Updated complex node representations. V1 and V2 separately
162 | for the number of atoms in the ligand and protein, K for
163 | the number of radial filters, and T for the number of
164 | types of atomic numbers.
165 |
166 | Returns
167 | -------
168 | Float32 tensor of shape (B, O)
169 | Predicted protein-ligand binding affinity. B for the number
170 | of protein-ligand pairs in the batch and O for the number of tasks.
171 | """
172 |
173 | ligand_feats = self.project(ligand_conv_out) # (V1, O)
174 | protein_feats = self.project(protein_conv_out) # (V2, O)
175 | complex_feats = self.project(complex_conv_out) # (V1+V2, O)
176 |
177 | ligand_energy = ligand_feats.reshape(batch_size, -1) # (B, O)
178 | protein_energy = protein_feats.reshape(batch_size, -1) # (B, O)
179 |
180 | complex_ligand_energy = complex_feats[frag1_node_indices_in_complex].reshape(
181 | batch_size, -1)
182 | complex_protein_energy = complex_feats[frag2_node_indices_in_complex].reshape(
183 | batch_size, -1)
184 | complex_energy = torch.cat([complex_ligand_energy, complex_protein_energy], dim=-1)
185 |
186 | fea = torch.cat([ligand_energy, protein_energy, complex_energy], dim=-1)
187 | return self.fea_layer(fea)
188 |
189 |
190 | class ACNN(nn.Module):
191 | """Atomic Convolutional Networks.
192 |
193 | The model was proposed in `Atomic Convolutional Networks for
194 |     Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.
195 |
196 | The prediction proceeds as follows:
197 |
198 | 1. Perform message passing to update atom representations for the
199 | ligand, protein and protein-ligand complex.
200 | 2. Predict the energy of atoms from their representations with an MLP.
201 | 3. Take the sum of predicted energy of atoms within each molecule for
202 | predicted energy of the ligand, protein and protein-ligand complex.
203 | 4. Make the final prediction by subtracting the predicted ligand and protein
204 | energy from the predicted complex energy.
205 |
206 | Parameters
207 | ----------
208 | hidden_sizes : list of int
209 | ``hidden_sizes[i]`` gives the size of hidden representations in the i-th
210 | hidden layer of the MLP. By Default, ``[32, 32, 16]`` will be used.
211 | weight_init_stddevs : list of float
212 | ``weight_init_stddevs[i]`` gives the std to initialize parameters in the
213 | i-th layer of the MLP. Note that ``len(weight_init_stddevs) == len(hidden_sizes) + 1``
214 | due to the output layer. By default, we use ``1 / sqrt(hidden_sizes[i])`` for hidden
215 | layers and 0.01 for the output layer.
216 | dropouts : list of float
217 | ``dropouts[i]`` gives the dropout in the i-th hidden layer of the MLP. By default,
218 | no dropout is used.
219 | features_to_use : None or float tensor of shape (T)
220 | In the original paper, these are atomic numbers to consider, representing the types
221 | of atoms. T for the number of types of atomic numbers. If None, we use same parameters
222 | for all atoms regardless of their type. Default to None.
223 | radial : list
224 | The list consists of 3 sublists of floats, separately for the
225 | options of interaction cutoff, the options of rbf kernel mean and the
226 | options of rbf kernel scaling. By default,
227 | ``[[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]`` will be used.
228 | num_tasks : int
229 | Number of output tasks. Default to 1.
230 | """
231 | def __init__(self, hidden_sizes=None, weight_init_stddevs=None, dropouts=None,
232 | features_to_use=None, radial=None, num_tasks=1):
233 | super(ACNN, self).__init__()
234 |
235 | if hidden_sizes is None:
236 | hidden_sizes = [32, 32, 16]
237 |
238 | if weight_init_stddevs is None:
239 | weight_init_stddevs = [1. / float(np.sqrt(hidden_sizes[i]))
240 | for i in range(len(hidden_sizes))]
241 | weight_init_stddevs.append(0.01)
242 |
243 | if dropouts is None:
244 | dropouts = [0. for _ in range(len(hidden_sizes))]
245 |
246 | if radial is None:
247 | radial = [[12.0], [0.0, 2.0, 4.0, 6.0, 8.0], [4.0]]
248 | # Take the product of sets of options and get a list of 3-tuples.
249 | radial_params = [x for x in itertools.product(*radial)]
250 | radial_params = torch.stack(list(map(torch.tensor, zip(*radial_params))), dim=1)
251 |
252 | interaction_cutoffs = radial_params[:, 0]
253 | rbf_kernel_means = radial_params[:, 1]
254 | rbf_kernel_scaling = radial_params[:, 2]
255 |
256 | self.ligand_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
257 | rbf_kernel_scaling, features_to_use)
258 | self.protein_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
259 | rbf_kernel_scaling, features_to_use)
260 | self.complex_conv = AtomicConv(interaction_cutoffs, rbf_kernel_means,
261 | rbf_kernel_scaling, features_to_use)
262 | self.predictor = ACNNPredictor(radial_params.shape[0], hidden_sizes,
263 | weight_init_stddevs, dropouts, features_to_use, num_tasks)
264 |
265 | def forward(self, graph):
266 | """Apply the model for prediction.
267 |
268 | Parameters
269 | ----------
270 | graph : DGLHeteroGraph
271 | DGLHeteroGraph consisting of the ligand graph, the protein graph
272 | and the complex graph, along with preprocessed features. For a batch of
273 | protein-ligand pairs, we assume zero padding is performed so that the
274 | number of ligand and protein atoms is the same in all pairs.
275 |
276 | Returns
277 | -------
278 | Float32 tensor of shape (B, O)
279 | Predicted protein-ligand binding affinity. B for the number
280 | of protein-ligand pairs in the batch and O for the number of tasks.
281 | """
282 | ligand_graph = graph[('ligand_atom', 'ligand', 'ligand_atom')]
283 | ligand_graph_node_feats = ligand_graph.ndata['atomic_number']
284 | assert ligand_graph_node_feats.shape[-1] == 1
285 | ligand_graph_distances = ligand_graph.edata['distance']
286 | ligand_conv_out = self.ligand_conv(ligand_graph,
287 | ligand_graph_node_feats,
288 | ligand_graph_distances)
289 |
290 | protein_graph = graph[('protein_atom', 'protein', 'protein_atom')]
291 | protein_graph_node_feats = protein_graph.ndata['atomic_number']
292 | assert protein_graph_node_feats.shape[-1] == 1
293 | protein_graph_distances = protein_graph.edata['distance']
294 | protein_conv_out = self.protein_conv(protein_graph,
295 | protein_graph_node_feats,
296 | protein_graph_distances)
297 |
298 | complex_graph = dgl.edge_type_subgraph(graph,
299 | [('ligand_atom', 'complex', 'ligand_atom'),
300 | ('ligand_atom', 'complex', 'protein_atom'),
301 | ('protein_atom', 'complex', 'ligand_atom'),
302 | ('protein_atom', 'complex', 'protein_atom')])
303 | complex_graph = dgl.to_homogeneous(
304 | complex_graph, ndata=['atomic_number'], edata=['distance'])
305 | complex_graph_node_feats = complex_graph.ndata['atomic_number']
306 | assert complex_graph_node_feats.shape[-1] == 1
307 | complex_graph_distances = complex_graph.edata['distance']
308 | complex_conv_out = self.complex_conv(complex_graph,
309 | complex_graph_node_feats,
310 | complex_graph_distances)
311 |
312 | frag1_node_indices_in_complex = torch.where(complex_graph.ndata['_TYPE'] == 0)[0]
313 | frag2_node_indices_in_complex = list(set(range(complex_graph.num_nodes())) -
314 | set(frag1_node_indices_in_complex.tolist()))
315 |
316 | return self.predictor(
317 | graph.batch_size,
318 | frag1_node_indices_in_complex,
319 | frag2_node_indices_in_complex,
320 | ligand_conv_out, protein_conv_out, complex_conv_out)
321 |
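# Construction sketch (illustrative): in model/modules.py the set function instantiates this
# network with the hyperparameters from utils.config.ACNN_CONFIG, roughly
#   model = ACNN(hidden_sizes=ACNN_CONFIG['hidden_sizes'],
#                weight_init_stddevs=ACNN_CONFIG['weight_init_stddevs'],
#                dropouts=ACNN_CONFIG['dropouts'],
#                features_to_use=ACNN_CONFIG['atomic_numbers_considered'],
#                radial=ACNN_CONFIG['radial'])
# and applies it to the batched ligand/protein/complex heterographs built in data_loader/pdbbind.py.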
--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import random
4 | import argparse
5 | import datetime
6 | import numpy as np
7 | import torch.nn as nn
8 | from copy import deepcopy
9 | from datetime import timedelta
10 | from collections import OrderedDict
11 | from collections import defaultdict
12 | from timeit import default_timer as timer
13 |
14 | from utils.logger import Logger
15 | from utils.evaluation import compute_metrics
16 | from utils.pytorch_helper import move_to_device
17 | from data_loader import TwoMoons, GaussianMixture, Amazon, CelebA
18 | from data_loader import SetPDBBind, SetBindingDB
19 |
20 |
21 | class Base_Model(nn.Module):
22 | def __init__(self, hparams):
23 | super().__init__()
24 | self.hparams = hparams
25 | self.hparams.save_path = self.hparams.root_path + self.hparams.model_name
26 | self.load_data()
27 |
28 | def load_data(self):
29 | data_name = self.hparams.data_name
30 | if data_name == 'moons':
31 | self.data = TwoMoons(self.hparams)
32 | elif data_name == 'gaussian':
33 | self.data = GaussianMixture(self.hparams)
34 | elif data_name == 'amazon':
35 | self.data = Amazon(self.hparams)
36 | elif data_name == 'celeba':
37 | self.data = CelebA(self.hparams)
38 | elif data_name == 'pdbbind':
39 | self.data = SetPDBBind(self.hparams)
40 | elif data_name == 'bindingdb':
41 | self.data = SetBindingDB(self.hparams)
42 | else:
43 | raise ValueError("invalid dataset...")
44 |
45 | def configure_optimizers(self):
46 | raise NotImplementedError
47 |
48 | def configure_gradient_clippers(self):
49 | raise NotImplementedError
50 |
51 | def run_training_sessions(self):
52 | logger = Logger(self.hparams.save_path + '.log', on=True)
53 | val_perfs = []
54 | test_perfs = []
55 | best_val_perf = float('-inf')
56 | start = timer()
57 | random.seed(self.hparams.seed) # For reproducible random runs
58 |
59 | for run_num in range(1, self.hparams.num_runs + 1):
60 | state_dict, val_perf, test_perf = self.run_training_session(run_num, logger)
61 | val_perfs.append(val_perf)
62 | test_perfs.append(test_perf)
63 |
64 | if val_perf > best_val_perf:
65 | best_val_perf = val_perf
66 | logger.log('----New best {:8.2f}, saving'.format(val_perf))
67 | torch.save({'hparams': self.hparams,
68 | 'state_dict': state_dict}, self.hparams.save_path)
69 |
70 | logger.log('Time: %s' % str(timedelta(seconds=round(timer() - start))))
71 | self.load()
72 | if self.hparams.num_runs > 1:
73 | logger.log_perfs(val_perfs)
74 | logger.log('best hparams: ' + self.flag_hparams())
75 | if self.hparams.auto_repar == False:
76 | logger.log_test_perfs(test_perfs, self.hparams)
77 |
78 | val_perf, test_perf = self.run_test()
79 | logger.log('Val: {:8.2f}'.format(val_perf))
80 | logger.log('Test: {:8.2f}'.format(test_perf))
81 |
82 | def run_training_session(self, run_num, logger):
83 | self.train()
84 |
85 | # Scramble hyperparameters if number of runs is greater than 1.
86 | if self.hparams.num_runs > 1:
87 | logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs))
88 |             if self.hparams.auto_repar:
89 | for hparam, values in self.get_hparams_grid().items():
90 | assert hasattr(self.hparams, hparam)
91 | self.hparams.__dict__[hparam] = random.choice(values)
92 | else:
93 | self.hparams.seed = np.random.randint(100000)
94 |
95 | np.random.seed(self.hparams.seed)
96 | random.seed(self.hparams.seed)
97 | torch.manual_seed(self.hparams.seed)
98 | torch.cuda.manual_seed_all(self.hparams.seed)
99 |
100 | self.define_parameters()
101 | logger.log(str(self))
102 | logger.log('#params: %d' % sum([p.numel() for p in self.parameters()]))
103 | logger.log('hparams: %s' % self.flag_hparams())
104 |
105 | device = torch.device('cuda' if self.hparams.cuda else 'cpu')
106 | self.to(device)
107 |
108 | optim_energy, optim_var = self.configure_optimizers()
109 | gradient_clippers = self.configure_gradient_clippers()
110 | train_loader, val_loader, test_loader = self.data.get_loaders(
111 | self.hparams.batch_size, self.hparams.num_workers,
112 | shuffle_train=True, get_test=True)
113 | best_val_perf = float('-inf')
114 | best_state_dict = None
115 | forward_sum = defaultdict(float)
116 | num_steps = 0
117 | bad_epochs = 0
118 |
119 | times = []
120 | try:
121 | for epoch in range(1, self.hparams.epochs + 1):
122 | starttime = datetime.datetime.now()
123 |
124 | for batch_num, batch in enumerate(train_loader):
125 | V_set, S_set, neg_S_set = move_to_device(batch, device)
126 |
127 | if self.hparams.mode != 'diffMF':
128 | # optimize variational distribution (the q distribution)
129 | optim_var.zero_grad()
130 | neg_elbo = self.rec_net(V_set, self.set_func, bs=S_set.size(0))
131 | neg_elbo.backward()
132 | for params, clip in gradient_clippers:
133 | nn.utils.clip_grad_norm_(params, clip)
134 | optim_var.step()
135 |
136 | if math.isnan(neg_elbo):
137 | logger.log('Stopping epoch because loss is NaN')
138 | break
139 |
140 | # optimize energy function (the p distribution)
141 | optim_energy.zero_grad()
142 | entropy_loss = self.set_func(V_set, S_set, neg_S_set, self.rec_net)
143 | entropy_loss.backward()
144 | for params, clip in gradient_clippers:
145 | nn.utils.clip_grad_norm_(params, clip)
146 | optim_energy.step()
147 |
148 | num_steps += 1
149 | forward_sum['neg_elbo'] += neg_elbo.item() if self.hparams.mode != 'diffMF' else 0
150 | forward_sum['entropy'] += entropy_loss.item()
151 | if math.isnan(entropy_loss):
152 | logger.log('Stopping epoch because loss is NaN')
153 | break
154 |
155 | endtime = datetime.datetime.now()
156 | times.append(endtime - starttime)
157 |
158 | if math.isnan(forward_sum['neg_elbo']) or math.isnan(forward_sum['entropy']):
159 | logger.log('Stopping training session because loss is NaN')
160 | break
161 |
162 | val_perf = self.evaluate(val_loader, device)
163 | logger.log('End of epoch {:3d}'.format(epoch), False)
164 | logger.log(' '.join([' | {:s} {:8.2f}'.format(
165 | key, forward_sum[key] / num_steps)
166 | for key in forward_sum]), False)
167 | logger.log(' | val perf {:8.2f}'.format(val_perf), False)
168 |
169 | if val_perf > best_val_perf:
170 | best_val_perf = val_perf
171 | bad_epochs = 0
172 | logger.log('\t\t*Best model so far, deep copying*')
173 | best_state_dict = deepcopy(self.state_dict())
174 | test_perf = self.evaluate(test_loader, device)
175 | else:
176 | bad_epochs += 1
177 | logger.log('\t\tBad epoch %d' % bad_epochs)
178 |
179 | if bad_epochs > self.hparams.num_bad_epochs:
180 | break
181 |
182 | except KeyboardInterrupt:
183 | logger.log('-' * 89)
184 | logger.log('Exiting from training early')
185 |
186 | logger.log("time per training epoch: " + str(np.mean(times)))
187 | return best_state_dict, best_val_perf, test_perf
188 |
189 | def evaluate(self, eval_loader, device):
190 | self.eval()
191 | with torch.no_grad():
192 | perf = compute_metrics(eval_loader, self.inference, self.hparams.v_size, device)
193 | self.train()
194 | return perf
195 |
196 | def run_test(self):
197 | device = torch.device('cuda' if self.hparams.cuda else 'cpu')
198 | _, val_loader, test_loader = self.data.get_loaders(self.hparams.batch_size,
199 | self.hparams.num_workers, shuffle_train=True, get_test=True)
200 | val_perf = self.evaluate(val_loader, device)
201 | test_perf = self.evaluate(test_loader, device)
202 | return val_perf, test_perf
203 |
204 | def load(self):
205 | device = torch.device('cuda' if self.hparams.cuda else 'cpu')
206 | checkpoint = torch.load(self.hparams.save_path) if self.hparams.cuda \
207 | else torch.load(self.hparams.save_path,
208 | map_location=torch.device('cpu'))
209 | if checkpoint['hparams'].cuda and not self.hparams.cuda:
210 | checkpoint['hparams'].cuda = False
211 | self.hparams = checkpoint['hparams']
212 | self.define_parameters()
213 | self.load_state_dict(checkpoint['state_dict'])
214 | self.to(device)
215 |
216 | def flag_hparams(self):
217 | flags = '%s' % (self.hparams.model_name)
218 | for hparam in vars(self.hparams):
219 | val = getattr(self.hparams, hparam)
220 | if str(val) == 'False':
221 | continue
222 | elif str(val) == 'True':
223 | flags += ' --%s' % (hparam)
224 | elif str(hparam) in {'model_name', 'data_path', 'num_runs',
225 | 'auto_repar', 'save_path'}:
226 | continue
227 | else:
228 | flags += ' --%s %s' % (hparam, val)
229 | return flags
230 |
231 | @staticmethod
232 | def get_general_hparams_grid():
233 | grid = OrderedDict({
234 | 'seed': list(range(100000)),
235 | 'lr': [0.003, 0.001, 0.0005, 0.0001],
236 | 'clip': [1, 5, 10],
237 | 'batch_size': [32, 64, 128],
238 | 'init': [0, 0.5, 0.1, 0.05, 0.01],
239 | })
240 | return grid
241 |
242 | @staticmethod
243 | def get_general_argparser():
244 | parser = argparse.ArgumentParser()
245 |
246 | parser.add_argument('model_name', type=str)
247 | parser.add_argument('--data_name', type=str, default='moons',
248 | choices=['moons', 'gaussian', 'amazon', 'celeba', 'pdbbind', 'bindingdb'],
249 |                         help='name of dataset [%(default)s]')
250 | parser.add_argument('--amazon_cat', type=str, default='toys',
251 | choices=['toys', 'furniture', 'gear', 'carseats', 'bath', 'health', 'diaper', 'bedding',
252 | 'safety', 'feeding', 'apparel', 'media'],
253 |                         help='category of amazon baby registry dataset [%(default)s]')
254 | parser.add_argument('--root_path', type=str,
255 | default='./')
256 | parser.add_argument('--train', action='store_true',
257 | help='train a model?')
258 | parser.add_argument('--auto_repar', action='store_true',
259 | help='use auto parameterization?')
260 |
261 | parser.add_argument('--v_size', type=int, default=30,
262 | help='size of ground set [%(default)d]')
263 | parser.add_argument('--s_size', type=int, default=10,
264 | help='size of subset [%(default)d]')
265 | parser.add_argument('--num_layers', type=int, default=2,
266 | help='num layers [%(default)d]')
267 |
268 | parser.add_argument('--batch_size', type=int, default=4,
269 | help='batch size [%(default)d]')
270 | parser.add_argument('--lr', type=float, default=0.0001,
271 | help='initial learning rate [%(default)g]')
272 | parser.add_argument("--weight_decay", type=float, default=1e-5,
273 | help='weight decay rate [%(default)g]')
274 | parser.add_argument('--init', type=float, default=0.05,
275 | help='unif init range (default if 0) [%(default)g]')
276 | parser.add_argument('--clip', type=float, default=10,
277 | help='gradient clipping [%(default)g]')
278 | parser.add_argument('--epochs', type=int, default=100,
279 | help='max number of epochs [%(default)d]')
280 | parser.add_argument('--num_runs', type=int, default=1,
281 | help='num random runs (not random if 1) '
282 | '[%(default)d]')
283 |
284 | parser.add_argument('--num_bad_epochs', type=int, default=6,
285 | help='num indulged bad epochs [%(default)d]')
286 | parser.add_argument('--num_workers', type=int, default=2,
287 | help='num dataloader workers [%(default)d]')
288 | parser.add_argument('--cuda', action='store_true',
289 | help='use CUDA?')
290 | parser.add_argument('--seed', type=int, default=50971,
291 | help='random seed [%(default)d]')
292 |
293 | return parser
294 |
--------------------------------------------------------------------------------
/model/celebaCNN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class celebaCNN(nn.Sequential):
5 | def __init__(self):
6 | super(celebaCNN, self).__init__()
7 |
8 | in_ch = [3] + [32,64,128]
9 | kernels = [3,4,5]
10 | strides = [2,2,2]
11 | layer_size = 3
12 | self.conv = nn.ModuleList([nn.Conv2d(in_channels = in_ch[i],
13 | out_channels = in_ch[i+1],
14 | kernel_size = kernels[i],
15 | stride = strides[i]) for i in range(layer_size)])
16 | self.conv = self.conv.double()
17 | self.fc1 = nn.Linear(128, 256)
18 |
19 | def _forward_features(self, x):
20 | for l in self.conv:
21 | x = F.relu(l(x))
22 | x = F.adaptive_max_pool2d(x, output_size=1)
23 | return x
24 |
25 | def forward(self, v):
26 | v = self._forward_features(v.double())
27 | v = v.view(v.size(0), -1)
28 | v = self.fc1(v.float())
29 | return v
--------------------------------------------------------------------------------
/model/deepDTA.py:
--------------------------------------------------------------------------------
1 | # reference: https://github.com/mims-harvard/TDC/blob/main/examples/multi_pred/dti_dg/domainbed/networks.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | class Identity(nn.Module):
8 | """An identity layer"""
9 | def __init__(self):
10 | super(Identity, self).__init__()
11 |
12 | def forward(self, x):
13 | return x
14 |
15 | class MLP(nn.Module):
16 | """Just an MLP"""
17 | def __init__(self, n_inputs, n_outputs, hparams):
18 | super(MLP, self).__init__()
19 | self.input = nn.Linear(n_inputs, hparams['mlp_width'])
20 | self.dropout = nn.Dropout(hparams['mlp_dropout'])
21 | self.hiddens = nn.ModuleList([
22 | nn.Linear(hparams['mlp_width'], hparams['mlp_width'])
23 | for _ in range(hparams['mlp_depth']-2)])
24 | self.output = nn.Linear(hparams['mlp_width'], n_outputs)
25 | self.n_outputs = n_outputs
26 |
27 | def forward(self, x):
28 | x = self.input(x)
29 | x = self.dropout(x)
30 | x = F.relu(x)
31 | for hidden in self.hiddens:
32 | x = hidden(x)
33 | x = self.dropout(x)
34 | x = F.relu(x)
35 | x = self.output(x)
36 | return x
37 |
38 | class CNN(nn.Sequential):
39 | def __init__(self, encoding):
40 | super(CNN, self).__init__()
41 | if encoding == 'drug':
42 | in_ch = [41] + [32,64,96]
43 | kernels = [4,6,8]
44 | layer_size = 3
45 | self.conv = nn.ModuleList([nn.Conv1d(in_channels = in_ch[i],
46 | out_channels = in_ch[i+1],
47 | kernel_size = kernels[i]) for i in range(layer_size)])
48 | self.conv = self.conv.double()
49 | n_size_d = self._get_conv_output((41, 100))
50 | self.fc1 = nn.Linear(n_size_d, 256)
51 | elif encoding == 'protein':
52 | in_ch = [20] + [32,64,96]
53 | kernels = [4,8,12]
54 | layer_size = 3
55 | self.conv = nn.ModuleList([nn.Conv1d(in_channels = in_ch[i],
56 | out_channels = in_ch[i+1],
57 | kernel_size = kernels[i]) for i in range(layer_size)])
58 | self.conv = self.conv.double()
59 | n_size_p = self._get_conv_output((20, 1000))
60 | self.fc1 = nn.Linear(n_size_p, 256)
61 |
62 | def _get_conv_output(self, shape):
63 | bs = 1
64 | input = Variable(torch.rand(bs, *shape))
65 | output_feat = self._forward_features(input.double())
66 | n_size = output_feat.data.view(bs, -1).size(1)
67 | return n_size
68 |
69 | def _forward_features(self, x):
70 | for l in self.conv:
71 | x = F.relu(l(x))
72 | x = F.adaptive_max_pool1d(x, output_size=1)
73 | return x
74 |
75 | def forward(self, v):
76 | v = self._forward_features(v.double())
77 | v = v.view(v.size(0), -1)
78 | v = self.fc1(v.float())
79 | return v
80 |
81 | class DeepDTA_Encoder(nn.Sequential):
82 | def __init__(self):
83 | super(DeepDTA_Encoder, self).__init__()
84 | self.input_dim_drug = 256
85 | self.input_dim_protein = 256
86 |
87 | self.model_drug = CNN('drug')
88 | self.model_protein = CNN('protein')
89 | self.predictor = nn.Linear(self.input_dim_drug + self.input_dim_protein, 256)
90 |
91 | def forward(self, V):
92 | v_D, v_P = V
93 | # each encoding
94 | v_D = self.model_drug(v_D)
95 | v_P = self.model_protein(v_P)
96 | # concatenate and output feature
97 | v_f = torch.cat((v_D, v_P), 1)
98 | v_f = self.predictor(v_f)
99 | return v_f
100 |
--------------------------------------------------------------------------------
/model/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datetime
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from model.acnn import ACNN
7 | from utils.config import ACNN_CONFIG
8 | from model.celebaCNN import celebaCNN
9 | from model.deepDTA import DeepDTA_Encoder
10 | from utils.pytorch_helper import FF, normal_cdf
11 |
12 |
13 | class SetFuction(nn.Module):
14 | def __init__(self, params):
15 | super(SetFuction, self).__init__()
16 | self.params = params
17 | self.dim_feature = 256
18 |
19 | self.init_layer = self.define_init_layer()
20 | self.ff = FF(self.dim_feature, 500, 1, self.params.num_layers)
21 |
22 | def define_init_layer(self):
23 | data_name = self.params.data_name
24 | if data_name == 'moons':
25 | return nn.Linear(2, self.dim_feature)
26 | elif data_name == 'gaussian':
27 | return nn.Linear(2, self.dim_feature)
28 | elif data_name == 'amazon':
29 | return nn.Linear(768, self.dim_feature)
30 | elif data_name == 'celeba':
31 | return celebaCNN()
32 | elif data_name == 'pdbbind':
33 | return ACNN(hidden_sizes=ACNN_CONFIG['hidden_sizes'],
34 | weight_init_stddevs=ACNN_CONFIG['weight_init_stddevs'],
35 | dropouts=ACNN_CONFIG['dropouts'],
36 | features_to_use=ACNN_CONFIG['atomic_numbers_considered'],
37 | radial=ACNN_CONFIG['radial'])
38 | elif data_name == 'bindingdb':
39 | return DeepDTA_Encoder()
40 | else:
41 | raise ValueError("invalid dataset...")
42 |
43 | def MC_sampling(self, q, M):
44 | """
45 | Bernoulli sampling using q as parameters.
46 | Args:
47 |             q: parameter of the Bernoulli distribution
48 | M: number of samples
49 |
50 | Returns:
51 | Sampled subsets F(S+i), F(S)
52 |
53 | """
54 | bs, vs = q.shape
55 | q = q.reshape(bs, 1, 1, vs).expand(bs, M, vs, vs)
56 | sample_matrix = torch.bernoulli(q)
57 |
58 | mask = torch.cat([torch.eye(vs, vs).unsqueeze(0) for _ in range(M)], dim=0).unsqueeze(0).to(q.device)
59 | matrix_0 = sample_matrix * (1 - mask)
60 | matrix_1 = matrix_0 + mask
61 | return matrix_1, matrix_0
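    # Shape note: matrix_1 and matrix_0 have shape (bs, M, vs, vs). For each Monte Carlo sample,
    # row i is the sampled subset with element i forced in (matrix_1) or forced out (matrix_0),
    # so F(S + i) and F(S - i) can be evaluated for every i in a single batched call.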
62 |
63 | def mean_field_iteration(self, V, subset_i, subset_not_i):
64 | F_1 = self.F_S(V, subset_i, fpi=True).squeeze(-1)
65 | F_0 = self.F_S(V, subset_not_i, fpi=True).squeeze(-1)
66 | q = torch.sigmoid((F_1 - F_0).mean(1))
67 | return q
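    # One mean-field fixed-point step: for every element i,
    #   q_i <- sigmoid( E_{S ~ q}[ F(S + i) - F(S - i) ] ),
    # where the expectation is the Monte Carlo average over the M subsets drawn in MC_sampling.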
68 |
69 | def cross_entropy(self, q, S, neg_S):
70 | loss = - torch.sum((S * torch.log(q + 1e-12) + (1 - S) * torch.log(1 - q + 1e-12)) * neg_S, dim=-1)
71 | return loss.mean()
72 |
73 | def forward(self, V, S, neg_S, rec_net): # return cross-entropy loss
74 | if self.params.mode == 'diffMF':
75 | bs, vs = V.shape[:2]
76 | q = .5 * torch.ones(bs, vs).to(V.device)
77 | else:
78 | # mode == 'ind' or 'copula'
79 | q = rec_net.get_vardist(V, S.shape[0]).detach() # notice the detach here
80 |
81 | for i in range(self.params.RNN_steps):
82 | sample_matrix_1, sample_matrix_0 = self.MC_sampling(q, self.params.num_samples)
83 | q = self.mean_field_iteration(V, sample_matrix_1, sample_matrix_0)
84 |
85 | loss = self.cross_entropy(q, S, neg_S)
86 | return loss
87 |
88 | def F_S(self, V, subset_mat, fpi=False):
89 | if fpi:
90 |             # for fixed-point iteration (aka mean-field iteration)
91 | fea = self.init_layer(V).reshape(subset_mat.shape[0], 1, -1, self.dim_feature)
92 | else:
93 | # to encode variational dist
94 | fea = self.init_layer(V).reshape(subset_mat.shape[0], -1, self.dim_feature)
95 | fea = subset_mat @ fea
96 | fea = self.ff(fea)
97 | return fea
98 |
99 |
100 | class RecNet(nn.Module):
101 | def __init__(self, params):
102 | super(RecNet, self).__init__()
103 | self.params = params
104 | self.dim_feature = 256
105 | num_layers = self.params.num_layers
106 |
107 | self.init_layer = self.define_init_layer()
108 | self.ff = FF(self.dim_feature, 500, 500, num_layers - 1 if num_layers > 0 else 0)
109 | self.h_to_mu = nn.Linear(500, 1)
110 | if self.params.mode == 'copula':
111 | self.h_to_std = nn.Linear(500, 1)
112 | self.h_to_U = nn.ModuleList([nn.Linear(500, 1) for i in range(self.params.rank)])
113 |
114 | def define_init_layer(self):
115 | data_name = self.params.data_name
116 | if data_name == 'moons':
117 | return nn.Linear(2, self.dim_feature)
118 | elif data_name == 'gaussian':
119 | return nn.Linear(2, self.dim_feature)
120 | elif data_name == 'amazon':
121 | return nn.Linear(768, self.dim_feature)
122 | elif data_name == 'celeba':
123 | return celebaCNN()
124 | elif data_name == 'pdbbind':
125 | return ACNN(hidden_sizes=ACNN_CONFIG['hidden_sizes'],
126 | weight_init_stddevs=ACNN_CONFIG['weight_init_stddevs'],
127 | dropouts=ACNN_CONFIG['dropouts'],
128 | features_to_use=ACNN_CONFIG['atomic_numbers_considered'],
129 | radial=ACNN_CONFIG['radial'])
130 | elif data_name == 'bindingdb':
131 | return DeepDTA_Encoder()
132 | else:
133 | raise ValueError("invalid dataset...")
134 |
135 | def encode(self, V, bs):
136 | """
137 |
138 | Args:
139 | V: the ground set. [batch_size, v_size, fea_dim]
140 | bs: batch_size
141 |
142 | Returns:
143 | ber: predicted probabilities. [batch_size, v_size]
144 | std: the diagonal matrix D [batch_size, v_size]
145 | u_perturbation: the low rank perturbation matrix [batch_size, v_size, rank]
146 |
147 | """
148 | fea = self.init_layer(V).reshape(bs, -1, self.dim_feature)
149 | h = torch.relu(self.ff(fea))
150 | ber = torch.sigmoid(self.h_to_mu(h)).squeeze(-1) # [batch_size, v_size]
151 |
152 | if self.params.mode == 'copula':
153 | std = F.softplus(self.h_to_std(h)).squeeze(-1) # [batch_size, v_size]
154 | rs = []
155 | for i in range(self.params.rank):
156 | rs.append(torch.tanh(self.h_to_U[i](h)))
157 | u_perturbation = torch.cat(rs, -1) # [batch_size, v_size, rank]
158 |
159 | return ber, std, u_perturbation
160 | return ber, None, None
161 |
162 | def MC_sampling(self, ber, std, u_pert, M):
163 | """
164 | Sampling using CopulaBernoulli
165 |
166 | Args:
167 | ber: location parameter (0, 1) [batch_size, v_size]
168 | std: standard deviation (0, +infinity) [batch_size, v_size]
169 |             u_pert: low-rank perturbation (-1, 1) [batch_size, v_size, rank]
170 |             M: number of Monte Carlo samples
171 |
172 | Returns:
173 | Sampled subsets
174 | """
175 | bs, vs = ber.shape
176 |
177 | if self.params.mode == 'copula':
178 | eps = torch.randn((bs, M, vs)).to(ber.device)
179 | eps_corr = torch.randn((bs, M, self.params.rank, 1)).to(ber.device)
180 | g = eps * std.unsqueeze(1) + torch.matmul(u_pert.unsqueeze(1), eps_corr).squeeze(-1)
181 | u = normal_cdf(g, 0, 1)
182 | else:
183 | # mode == 'ind'
184 | u = torch.rand((bs, M, vs)).to(ber.device)
185 |
186 | ber = ber.unsqueeze(1)
187 | l = torch.log(ber + 1e-12) - torch.log(1 - ber + 1e-12) + \
188 | torch.log(u + 1e-12) - torch.log(1 - u + 1e-12)
189 |
190 | prob = torch.sigmoid(l / self.params.tau)
191 | r = torch.bernoulli(prob) # binary vector
192 | s = prob + (r - prob).detach() # straight through estimator
193 | return s
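    # Copula sampling sketch: g = std * eps + U @ eps_corr is Gaussian with covariance
    # D^2 + U U^T (D = diag(std), U the low-rank perturbation); u = Phi(g) yields correlated
    # uniforms, and the logistic transform with temperature tau plus the straight-through
    # estimator turns them into (approximately) Bernoulli(ber) samples that stay differentiable
    # w.r.t. ber, std and U.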
194 |
195 | def cal_elbo(self, V, sample_mat, set_func, q):
196 | f_mt = set_func.F_S(V, sample_mat).squeeze(-1).mean(-1)
197 | entropy = - torch.sum(q * torch.log(q + 1e-12) + (1 - q) * torch.log(1 - q + 1e-12), dim=-1)
198 | elbo = f_mt + entropy
199 | return elbo.mean()
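    # ELBO(q) = E_{S ~ q}[ F(S; theta) ] + H(q), estimated with the Monte Carlo subsets drawn
    # above; forward() returns the negative ELBO so that the optimiser can minimise it.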
200 |
201 | def forward(self, V, set_func, bs): # return negative ELBO
202 | ber, std, u_perturbation = self.encode(V, bs)
203 | sample_mat = self.MC_sampling(ber, std, u_perturbation, self.params.num_samples)
204 | elbo = self.cal_elbo(V, sample_mat, set_func, ber)
205 | return -elbo
206 |
207 | def get_vardist(self, V, bs):
208 | fea = self.init_layer(V).reshape(bs, -1, self.dim_feature)
209 | h = torch.relu(self.ff(fea))
210 | ber = torch.sigmoid(self.h_to_mu(h)).squeeze(-1)
211 | return ber
212 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | MOONS_CONFIG = {
5 | 'data_name': 'moons',
6 | 'v_size': 100,
7 | 's_size': 10,
8 | 'batch_size': 128
9 | }
10 |
11 | GAUSSIAN_CONFIG = {
12 | 'data_name': 'gaussian',
13 | 'v_size': 100,
14 | 's_size': 10,
15 | 'batch_size': 128
16 | }
17 |
18 | AMAZON_CONFIG = {
19 | 'data_name': 'amazon',
20 | 'v_size': 30,
21 | 'batch_size': 128
22 | }
23 |
24 | CELEBA_CONFIG = {
25 | 'data_name': 'celeba',
26 | 'v_size': 8,
27 | 'batch_size': 128
28 | }
29 |
30 | PDBBIND_CONFIG = {
31 | 'data_name': 'pdbbind',
32 | 'v_size': 30,
33 | 's_size': 5,
34 | 'batch_size': 32
35 | }
36 |
37 | BINDINGDB_CONFIG = {
38 | 'data_name': 'bindingdb',
39 | 'v_size': 300,
40 | 's_size': 15,
41 | 'batch_size': 4
42 | }
43 |
44 | ACNN_CONFIG = {
45 | 'hidden_sizes': [32, 32, 16],
46 | 'weight_init_stddevs': [1. / float(np.sqrt(32)), 1. / float(np.sqrt(32)),
47 | 1. / float(np.sqrt(16)), 0.01],
48 | 'dropouts': [0., 0., 0.],
49 | 'atomic_numbers_considered': torch.tensor([
50 | 1., 6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30., 35., 53.]),
51 | 'radial': [[12.0], [0.0, 4.0, 8.0], [4.0]],
52 | }
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils.pytorch_helper import set_value_according_index, move_to_device
4 |
5 | def compute_metrics(loader, infer_func, v_size, device):
6 | jc_list = []
7 | for batch_num, batch in enumerate(loader):
8 | V_set, S_set = move_to_device(batch, device)
9 |
10 | q = infer_func(V_set, S_set.shape[0])
11 | _, idx = torch.topk(q, S_set.shape[-1], dim=1, largest=True)
12 |
13 | pre_list = []
14 | for i in range(len(idx)):
15 | pre_mask = torch.zeros([S_set.shape[-1]]).to(device)
16 | ids = idx[i][:int(torch.sum(S_set[i]).item())]
17 | pre_mask[ids] = 1
18 | pre_list.append(pre_mask.unsqueeze(0))
19 | pre_mask = torch.cat(pre_list, dim=0)
20 | true_mask = S_set
21 |
22 | intersection = true_mask * pre_mask
23 | union = true_mask + pre_mask - intersection
24 | jc = intersection.sum(dim=-1) / union.sum(dim=-1)
25 | jc_list.append(jc)
26 |
27 | jca = torch.cat(jc_list, 0).mean(0).item()
28 | return jca * 100
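# The metric is the mean Jaccard index (in percent): for each instance the |S| elements with
# the largest predicted probability form the predicted subset, and
#   JC = |pred ∩ true| / |pred ∪ true|.
# Worked example (illustrative): true = {0, 1}, pred = {1, 2}  ->  JC = 1 / 3.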
29 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import statistics as stat
4 | import sys
5 |
6 | class Logger(object):
7 |
8 | def __init__(self, log_path, on=True):
9 | self.log_path = log_path
10 | self.on = on
11 | # self.on = False
12 |
13 | if self.on:
14 | while os.path.isfile(self.log_path):
15 | self.log_path += '+'
16 |
17 | def log(self, string, newline=True):
18 | if self.on:
19 | with open(self.log_path, 'a') as logf:
20 | logf.write(string)
21 | if newline: logf.write('\n')
22 |
23 | sys.stdout.write(string)
24 | if newline: sys.stdout.write('\n')
25 | sys.stdout.flush()
26 |
27 | def log_perfs(self, perfs):
28 | valid_perfs = [perf for perf in perfs if not math.isinf(perf)]
29 | best_perf = max(valid_perfs)
30 | self.log('-' * 89)
31 | self.log('%d perfs: %s' % (len(perfs), str(perfs)))
32 | self.log('perf max: %g' % best_perf)
33 | self.log('perf min: %g' % min(valid_perfs))
34 | self.log('perf avg: %g' % stat.mean(valid_perfs))
35 | self.log('perf std: %g' % (stat.stdev(valid_perfs)
36 | if len(valid_perfs) > 1 else 0.0))
37 | self.log('(excluded %d out of %d runs that produced -inf)' %
38 | (len(perfs) - len(valid_perfs), len(perfs)))
39 | self.log('-' * 89)
40 |
41 | def log_test_perfs(self, perfs, params):
42 | valid_perfs = [perf for perf in perfs if not math.isinf(perf)]
43 | mean = stat.mean(valid_perfs)
44 |         std = stat.stdev(valid_perfs) if len(valid_perfs) > 1 else 0.0
45 |
46 | if self.on:
47 | name = params.root_path + 'EquiVSet.log'
48 | with open(name, 'a') as logf:
49 | logf.write( params.model_name + ': ' + '\n')
50 | logf.write(str(perfs) + '\n')
51 | string = f'avg: {mean} ' + f'std: {std}\n'
52 | logf.write(string)
53 | self.log('-' * 89)
54 | self.log('%d perfs: %s' % (len(perfs), str(perfs)))
55 | self.log('perf avg: %g' % mean)
56 | self.log('perf std: %g' % std)
57 | self.log('-' * 89)
--------------------------------------------------------------------------------
/utils/pytorch_helper.py:
--------------------------------------------------------------------------------
1 | import dgl
2 | import math
3 | import torch
4 | import pickle
5 | import torch.nn as nn
6 |
7 |
8 | def set_value_according_index(tensor, idx, value):
9 | mask_val = torch.ones(idx.shape).to(tensor.device) * value
10 | tensor.scatter_(1, idx, mask_val) # fill the values along dimension 1
11 | return tensor
12 |
13 | def normal_cdf(value, loc, scale):
14 | return 0.5 * (1 + torch.erf( (value - loc) / (scale * math.sqrt(2)) ))
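# Gaussian CDF via the error function:
#   Phi((x - loc) / scale) = 0.5 * (1 + erf((x - loc) / (scale * sqrt(2))))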
15 |
16 | def get_init_function(init_value):
17 | def init_function(m):
18 | if init_value > 0.:
19 | if hasattr(m, 'weight'):
20 | m.weight.data.uniform_(-init_value, init_value)
21 | if hasattr(m, 'bias'):
22 | m.bias.data.fill_(0.)
23 |
24 | return init_function
25 |
26 |
27 | def move_to_device(obj, device):
28 | if torch.is_tensor(obj):
29 | return obj.to(device)
30 | elif isinstance(obj, dgl.DGLGraph):
31 | return obj.to(device)
32 | elif isinstance(obj, dict):
33 | res = {}
34 | for k, v in obj.items():
35 | res[k] = move_to_device(v, device)
36 | return res
37 | elif isinstance(obj, list):
38 | res = []
39 | for v in obj:
40 | res.append(move_to_device(v, device))
41 | return res
42 | elif isinstance(obj, tuple):
43 | res = ()
44 | for v in obj:
45 | res += (move_to_device(v, device),)
46 | return res
47 | else:
48 | raise TypeError("Invalid type for move_to_device")
49 |
50 | class FF(nn.Module):
51 | def __init__(self, dim_input, dim_hidden, dim_output, num_layers,
52 | activation='relu', dropout_rate=0, layer_norm=False,
53 | residual_connection=False):
54 | super().__init__()
55 |
56 | assert num_layers >= 0 # 0 = Linear
57 | if num_layers > 0:
58 | assert dim_hidden > 0
59 | if residual_connection:
60 | assert dim_hidden == dim_input
61 |
62 | self.residual_connection = residual_connection
63 | self.stack = nn.ModuleList()
64 | for l in range(num_layers):
65 | layer = []
66 |
67 | if layer_norm:
68 | layer.append(nn.LayerNorm(dim_input if l == 0 else dim_hidden))
69 |
70 | layer.append(nn.Linear(dim_input if l == 0 else dim_hidden,
71 | dim_hidden))
72 | layer.append({'tanh': nn.Tanh(), 'relu': nn.ReLU()}[activation])
73 |
74 | if dropout_rate > 0:
75 | layer.append(nn.Dropout(dropout_rate))
76 |
77 | self.stack.append(nn.Sequential(*layer))
78 |
79 | self.out = nn.Linear(dim_input if num_layers < 1 else dim_hidden,
80 | dim_output)
81 |
82 | def forward(self, x):
83 | for layer in self.stack:
84 | x = x + layer(x) if self.residual_connection else layer(x)
85 | return self.out(x)
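    # Usage sketch (matching how model/modules.py builds it): FF(256, 500, 1, num_layers=2)
    # stacks two Linear+ReLU hidden layers of width 500 followed by a linear output of size 1;
    # with num_layers=0 it degenerates to a single nn.Linear(256, 1).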
86 |
87 | def read_from_pickle(filename):
88 | filename = filename + '.pickle'
89 | with open(filename, 'rb') as f:
90 | x = pickle.load(f)
91 | return x
92 |
93 | def find_not_in_set(U,S):
94 | Ind = torch.ones(U.shape[0],dtype=bool)
95 | Ind[S] = False
96 | return U[Ind]
97 |
--------------------------------------------------------------------------------