├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── delaney_solubility
│   │   ├── 19.png
│   │   ├── 3.png
│   │   └── 54.png
│   ├── tox21_pyridine
│   │   ├── 13_bayes.png
│   │   ├── 21_bayes.png
│   │   └── 6_bayes.png
│   └── tox21_srmmp
│       ├── 27_bayes.png
│       ├── 2_bayes.png
│       └── 3_bayes.png
├── experiments
│   ├── delaney
│   │   ├── plot.py
│   │   └── train.py
│   └── tox21
│       ├── calc_prcauc_with_seeds.py
│       ├── calc_prcauc_with_seeds.sh
│       ├── data.py
│       ├── plot_precision_recall.py
│       ├── train_few_with_seeds.sh
│       ├── train_tox21.py
│       ├── utils.py
│       ├── visualize-saliency-pyrigine.ipynb
│       └── visualize-saliency-tox21.ipynb
├── models
│   ├── __init__.py
│   ├── ggnn_drop.py
│   ├── nfp_drop.py
│   └── predictor.py
├── requirements.txt
└── saliency
    ├── __init__.py
    ├── calculator
    │   ├── __init__.py
    │   ├── base_calculator.py
    │   ├── gradient_calculator.py
    │   ├── integrated_gradients_calculator.py
    │   └── occlusion_calculator.py
    └── visualizer
        ├── __init__.py
        ├── base_visualizer.py
        └── smiles_visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 | .idea/
103 | experiments/tox21/results/
104 | experiments/delaney/results/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Preferred Networks, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bayesgrad
2 | BayesGrad: Explaining Predictions of Graph Convolutional Networks
3 |
4 | The paper is available on arXiv, [https://arxiv.org/abs/1807.01985](https://arxiv.org/abs/1807.01985).
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | From left: tox21 pyridine (C5H5N), tox21 SR-MMP, delaney solubility visualization.
13 |
14 |
15 | ## Citation
16 | If you find our work useful in your research, please consider citing:
17 |
18 | ```
19 | @article{akita2018bayesgrad,
20 | title={BayesGrad: Explaining Predictions of Graph Convolutional Networks},
21 | author={Akita, Hirotaka and Nakago, Kosuke and Komatsu, Tomoki and Sugawara, Yohei and Maeda, Shin-ichi and Baba, Yukino and Kashima, Hisashi},
22 | journal={arXiv preprint arXiv:1807.01985},
23 | year={2018}
24 | }
25 | ```
26 |
27 | ## Setup
28 |
29 | [Chainer Chemistry](https://github.com/pfnet-research/chainer-chemistry) [1] is used in our code.
30 | It is an extension library for the deep learning framework [Chainer](https://github.com/chainer/chainer) [2],
31 | and it supports several graph convolutional neural networks together with chemical dataset management.
32 |
33 | The experiments were executed under the following environment:
34 |
35 | - OS: Linux
36 | - python: 3.6.1
37 | - conda version: 4.4.4
38 |
39 | ```bash
40 | conda create -n bayesgrad python=3.6.1
41 | source activate bayesgrad
42 | pip install chainer==4.2.0
43 | pip install chainer-chemistry==0.4.0
44 | conda install -c rdkit rdkit==2017.09.3.0
45 | pip install matplotlib==2.2.2
46 | pip install future==0.16.0
47 | pip install cairosvg==2.1.3
48 | pip install ipython==5.1.0
49 | ```
50 |
51 | [Note]
52 | Please install the specified python and rdkit versions.
53 | The latest python and rdkit versions may not work well, as discussed [here](https://github.com/pfnet-research/chainer-chemistry/issues/138).
54 | If you face an error, try
55 | ```bash
56 | conda install libgcc
57 | ```
58 |
59 | If you want to use GPU, please install `cupy` as well.
60 | ```bash
61 | # XX should be CUDA version (80, 90 or 91)
62 | pip install cupy-cudaXX==4.2.0
63 | ```
64 |
65 | ## Experiments
66 |
67 | Each experiment can be executed as follows.
68 |
69 | ### Tox21 Pyridine experiment
70 | Experiments described in Section 4.1 in the paper. Tox21 [3] dataset is used.
71 |
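In this artificial task, each Tox21 molecule is relabeled by whether it contains a pyridine substructure (see `experiments/tox21/data.py`). A minimal sketch of that labeling rule, using the same SMARTS pattern as `data.py` (the helper name here is illustrative; `data.py` calls it `hassubst`):

```python
from rdkit import Chem

# Same pattern as PYRIDINE_SMILES in experiments/tox21/data.py
PYRIDINE_SMARTS = 'c1ccncc1'

def has_pyridine(smiles):
    """Return 1 if the molecule contains a pyridine ring, else 0."""
    mol = Chem.MolFromSmiles(smiles)
    return int(mol.HasSubstructMatch(Chem.MolFromSmarts(PYRIDINE_SMARTS)))
```
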
72 | ```bash
73 | cd experiments/tox21
74 | ```
75 |
76 | #### Training with all train data, plot precision-recall curve
77 |
78 | Set `-g -1` to use CPU, `-g 0` to use GPU.
79 | ```bash
80 | python train_tox21.py --iterator-type=balanced --label=pyridine --method=ggnndrop --epoch=50 --unit-num=16 --n-layers=1 -b 32 --conv-layers=4 --num-train=-1 --dropout-ratio=0.25 --out=results/ggnndrop_pyridine -g 0
81 | python plot_precision_recall.py --dirpath=results/ggnndrop_pyridine
82 | ```
83 |
84 | #### Visualization with trained model
85 | See `visualize-saliency-pyrigine.ipynb`.
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 | Our method successfully focuses on pyridine (C5H5N) substructures.
94 |
95 | #### Training 30 different models with few training data, calculating PRC-AUC score
96 | Pass `-1` as the argument to use CPU, `0` to use GPU.
97 |
98 | Note that this experiment takes time (around 2.5 hours with a GPU in our environment),
99 | since it trains 30 different models.
100 |
101 | ```bash
102 | bash -x ./train_few_with_seeds.sh 0
103 | bash -x ./calc_prcauc_with_seeds.sh 0
104 | ```
105 |
106 | Then see `results/ggnndrop_pyridine_numtrain1000_seed0-29/prcauc_stats_absolute_0.15.csv` for the results.
107 |
108 | ### Tox21 SR-MMP experiment
109 | Experiments described in Section 4.2 in the paper. Tox21 [3] dataset is used.
110 |
111 | ```bash
112 | cd experiments/tox21
113 | ```
114 |
115 | #### Training the model
116 | Set `-g -1` to use CPU, `-g 0` to use GPU.
117 | ```bash
118 | python train_tox21.py --iterator-type=balanced --label=SR-MMP --method=nfpdrop --epoch=200 --unit-num=16 --n-layers=1 -b 32 --conv-layers=4 --num-train=-1 --dropout-ratio=0.25 --out=results/nfpdrop_srmmp -g 0
119 | ```
120 |
121 | #### Visualization of tox21 data & Tyrphostin 9 with trained model
122 | See `visualize-saliency-tox21.ipynb`.
123 |
124 | Jupyter notebook interactive visualization:
125 |
126 |
127 |
128 |
129 | Several selected images:
130 |
131 |
132 |
133 |
134 |
135 |
136 | Toxicity mechanisms are still an active research topic, and it is difficult to quantitatively analyze these results.
137 | We hope these visualizations help to analyze and establish further knowledge about toxicity.
138 |
139 | ### Solubility experiment
140 | Experiments described in Section 4.3 in the paper. The ESOL [4] dataset (provided by MoleculeNet [5]) is used.
141 |
142 | ```bash
143 | cd experiments/delaney
144 | ```
145 |
146 | #### Training the model
147 | Set `-g -1` to use CPU, `-g 0` to use GPU.
148 |
149 | ```bash
150 | python train.py -e 100 -n 3 --method=nfpdrop -g 0
151 | ```
152 |
153 | #### Visualization with trained model
154 | ```bash
155 | python plot.py --dirpath=./results/nfpdrop_M30_conv3_unit32_b32
156 | ```
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 | Red indicates atoms estimated to be hydrophilic, and blue indicates atoms estimated to be hydrophobic.
165 | The figures above are consistent with fundamental physicochemical knowledge, as explained in the paper.
166 |
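The coloring follows `experiments/delaney/plot.py`: each atom's saliency is scaled to the range [-1, 1] and mapped to a red tint for positive values and a blue tint for negative values, roughly as follows:

```python
def color(x):
    """Map a scaled saliency value x in [-1, 1] to an RGB tuple."""
    if x > 0:
        return 1., 1. - x, 1. - x  # red tint: positive saliency (hydrophilic)
    else:
        x *= -1
        return 1. - x, 1. - x, 1.  # blue tint: negative saliency (hydrophobic)
```
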
167 | ## Saliency Calculation
168 |
169 | Although only results of the gradient method [6, 7, 8] are reported in the paper,
170 | this repository also contains saliency calculation code for several other algorithms.
171 |
172 | SmoothGrad [8] and/or BayesGrad (ours) can be applied on top of the following algorithms; a usage sketch is given below.
173 |
174 | - Vanilla Gradients [6, 7]
175 | - Integrated Gradients [9]
176 | - Occlusion [10]
177 |
178 | The code design is inspired by [PAIR-code/saliency](https://github.com/PAIR-code/saliency).
179 |
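The experiment scripts under `experiments/` drive the calculators in the same way; below is a minimal usage sketch distilled from `experiments/tox21/plot_precision_recall.py`. Here `model` is assumed to be a predictor built by `models.predictor.build_predictor` with trained weights loaded, and `dataset` a preprocessed `NumpyTupleDataset`; `IntegratedGradientsCalculator` and `OcclusionCalculator` expose the same interface.

```python
from chainer import functions as F
from chainer import links as L
from chainer_chemistry.dataset.converters import concat_mols

from saliency.calculator.gradient_calculator import GradientCalculator

# Assumed to exist: `model` (a trained predictor from models.predictor.build_predictor)
# and `dataset` (a preprocessed NumpyTupleDataset), as in experiments/tox21/.
classifier = L.Classifier(model, lossfun=F.sigmoid_cross_entropy,
                          accfun=F.binary_accuracy)

def preprocess_fun(atom, adj, t):
    # Embed the discrete atom IDs so that gradients are taken w.r.t. continuous features.
    return classifier.predictor.graph_conv.embed(atom), adj, t

def eval_fun(atom_embed, adj, t):
    # Reduce the prediction to a scalar that can be differentiated.
    return F.sum(classifier.predictor(atom_embed, adj))

calculator = GradientCalculator(classifier, eval_fun=eval_fun, target_key=0)

# VanillaGrad [6, 7]: a single gradient per sample.
arrays = calculator.compute_vanilla(
    dataset, converter=concat_mols, preprocess_fn=preprocess_fun)
saliency = calculator.transform(arrays, ch_axis=3, method='square')

# SmoothGrad [8]: average gradients over M noise-perturbed copies of each input.
arrays = calculator.compute_smooth(
    dataset, converter=concat_mols, preprocess_fn=preprocess_fun,
    M=30, mode='absolute', scale=0.15)
saliency = calculator.transform(arrays, ch_axis=3, method='square')

# BayesGrad (ours): keep dropout enabled (train=True) and average over M samples.
arrays = calculator.compute_vanilla(
    dataset, converter=concat_mols, preprocess_fn=preprocess_fun, M=30, train=True)
saliency = calculator.transform(arrays, ch_axis=3, method='square', lam=0)
```
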
180 | ## License
181 |
182 | Our code is released under the MIT License (see the [LICENSE](https://github.com/pfnet-research/bayesgrad/blob/master/LICENSE) file for details).
183 |
184 | ## Reference
185 |
186 | [1] pfnet research. chainer-chemistry https://github.com/pfnet-research/chainer-chemistry
187 |
188 | [2] Seiya Tokui, Kenta Oono, Shohei Hido, and Justin Clayton. Chainer: a next-generation open source framework for deep learning. In *Proceedings of Workshop on Machine Learning Systems (LearningSys) in Advances in Neural Information Processing System (NIPS) 28*, 2015.
189 |
190 | [3] Ruili Huang, Menghang Xia, Dac-Trung Nguyen, Tongan Zhao, Srilatha Sakamuru, Jinghua Zhao, Sampada A Shahane, Anna Rossoshek, and Anton Simeonov. Tox21 challenge to build predictive models of nuclear receptor and stress response pathways as mediated by exposure to environmental chemicals and drugs. Frontiers in Environmental Science, 3:85, 2016.
191 |
192 | [4] John S. Delaney. ESOL: Estimating aqueous solubility directly from molecular structure. Journal of Chemical Information and Computer Sciences, 44(3):1000–1005, 2004. PMID: 15154768.
193 |
194 | [5] Zhenqin Wu, Bharath Ramsundar, Evan N. Feinberg, Joseph Gomes, Caleb Geniesse, Aneesh S. Pappu, Karl Leswing, and Vijay Pande. MoleculeNet: A benchmark for molecular machine learning. arXiv preprint arXiv:1703.00564, 2017.
195 |
196 | [6] Dumitru Erhan, Yoshua Bengio, Aaron Courville, Pascal Vincent. Visualizing Higher-Layer Features of a Deep Network. 2009.
197 |
198 | [7] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034, 2013.
199 |
200 | [8] Daniel Smilkov, Nikhil Thorat, Been Kim, Fernanda Viegas, and Martin Wattenberg. SmoothGrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825, 2017.
201 |
202 | [9] Mukund Sundararajan, Ankur Taly, and Qiqi Yan. Axiomatic attribution for deep networks. In Doina Precup and Yee Whye Teh (eds.),
203 | Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pp. 3319–3328, International Convention Centre, Sydney, Australia, 06–11 Aug 2017. PMLR.
204 | URL http://proceedings.mlr.press/v70/sundararajan17a.html.
205 |
206 | [10] Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In
207 | European conference on computer vision, pp. 818–833. Springer, 2014.
208 |
--------------------------------------------------------------------------------
/assets/delaney_solubility/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/delaney_solubility/19.png
--------------------------------------------------------------------------------
/assets/delaney_solubility/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/delaney_solubility/3.png
--------------------------------------------------------------------------------
/assets/delaney_solubility/54.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/delaney_solubility/54.png
--------------------------------------------------------------------------------
/assets/tox21_pyridine/13_bayes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_pyridine/13_bayes.png
--------------------------------------------------------------------------------
/assets/tox21_pyridine/21_bayes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_pyridine/21_bayes.png
--------------------------------------------------------------------------------
/assets/tox21_pyridine/6_bayes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_pyridine/6_bayes.png
--------------------------------------------------------------------------------
/assets/tox21_srmmp/27_bayes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_srmmp/27_bayes.png
--------------------------------------------------------------------------------
/assets/tox21_srmmp/2_bayes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_srmmp/2_bayes.png
--------------------------------------------------------------------------------
/assets/tox21_srmmp/3_bayes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_srmmp/3_bayes.png
--------------------------------------------------------------------------------
/experiments/delaney/plot.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import os
5 | import sys
6 |
7 | import matplotlib
8 | matplotlib.use('agg')
9 | import matplotlib.pyplot as plt
10 |
11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
12 | from saliency.visualizer.smiles_visualizer import SmilesVisualizer
13 |
14 |
15 | def visualize(dir_path):
16 | parent_dir = os.path.dirname(dir_path)
17 | saliency_vanilla = np.load(os.path.join(dir_path, "saliency_vanilla.npy"))
18 | saliency_smooth = np.load(os.path.join(dir_path, "saliency_smooth.npy"))
19 | saliency_bayes = np.load(os.path.join(dir_path, "saliency_bayes.npy"))
20 |
21 | visualizer = SmilesVisualizer()
22 | os.makedirs(os.path.join(parent_dir, "result_vanilla"), exist_ok=True)
23 | os.makedirs(os.path.join(parent_dir, "result_smooth"), exist_ok=True)
24 | os.makedirs(os.path.join(parent_dir, "result_bayes"), exist_ok=True)
25 |
26 | test_idx = np.load(os.path.join(dir_path, "test_idx.npy"))
27 | answer = np.load(os.path.join(dir_path, "answer.npy"))
28 | output = np.load(os.path.join(dir_path, "output.npy"))
29 |
30 | smiles_all = np.load(os.path.join(parent_dir, "smiles.npy"))
31 |
32 | def calc_range(saliency):
33 | vmax = float('-inf')
34 | vmin = float('inf')
35 | for v in saliency:
36 | vmax = max(vmax, np.max(v))
37 | vmin = min(vmin, np.min(v))
38 | return vmin, vmax
39 |
40 | v_range_vanilla = calc_range(saliency_vanilla)
41 | v_range_smooth = calc_range(saliency_smooth)
42 | v_range_bayes = calc_range(saliency_bayes)
43 |
44 | def get_scaler(v_range):
45 | def scaler(saliency_):
46 | saliency = np.copy(saliency_)
47 | minv, maxv = v_range
48 | if maxv == minv:
49 | saliency = np.zeros_like(saliency)
50 | else:
51 | pos = saliency >= 0.0
52 | saliency[pos] = saliency[pos]/maxv
53 | nega = saliency < 0.0
54 | saliency[nega] = saliency[nega]/(np.abs(minv))
55 | return saliency
56 | return scaler
57 |
58 | scaler_vanilla = get_scaler(v_range_vanilla)
59 | scaler_smooth = get_scaler(v_range_smooth)
60 | scaler_bayes = get_scaler(v_range_bayes)
61 |
62 | def color(x):
63 | if x > 0:
64 | # Red for positive value
65 | return 1., 1. - x, 1. - x
66 | else:
67 | # Blue for negative value
68 | x *= -1
69 | return 1. - x, 1. - x, 1.
70 |
71 | for i, id in enumerate(test_idx):
72 | smiles = smiles_all[id]
73 | out = output[i]
74 | ans = answer[i]
75 | # legend = "t:{}, p:{}".format(ans, out)
76 | legend = ''
77 | ext = '.png' # '.svg'
78 | # visualizer.visualize(
79 | # saliency_vanilla[id], smiles, save_filepath=os.path.join(parent_dir, "result_vanilla", str(id) + ext),
80 | # visualize_ratio=1.0, legend=legend, scaler=scaler_vanilla, color_fn=color)
81 | # visualizer.visualize(
82 | # saliency_smooth[id], smiles, save_filepath=os.path.join(parent_dir, "result_smooth", str(id) + ext),
83 | # visualize_ratio=1.0, legend=legend, scaler=scaler_smooth, color_fn=color)
84 | visualizer.visualize(
85 | saliency_bayes[id], smiles, save_filepath=os.path.join(parent_dir, "result_bayes", str(id) + ext),
86 | visualize_ratio=1.0, legend=legend, scaler=scaler_bayes, color_fn=color)
87 |
88 |
89 | def plot_result(prediction, answer, save_filepath='result.png'):
90 | plt.scatter(prediction, answer, marker='.')
91 | plt.plot([-100, 100], [-100, 100], c='r')
92 | max_v = max(np.max(prediction), np.max(answer))
93 | min_v = min(np.min(prediction), np.min(answer))
94 | plt.xlim([min_v-0.1, max_v+0.1])
95 | plt.xlabel("prediction")
96 | plt.ylim([min_v-0.1, max_v+0.1])
97 | plt.ylabel("ground truth")
98 | plt.savefig(save_filepath)
99 | plt.close()
100 |
101 |
102 | def main():
103 | parser = argparse.ArgumentParser(
104 | description='Regression with own dataset.')
105 | parser.add_argument('--dirpath', '-d', type=str, default='./results/M_30_3_32_32')
106 | args = parser.parse_args()
107 | path = args.dirpath
108 | n_split = 5
109 | output = []
110 | answer = []
111 | for i in range(n_split):
112 | suffix = str(i) + "-" + str(n_split)
113 | output.append(np.load(os.path.join(path, suffix, "output.npy")))
114 | answer.append(np.load(os.path.join(path, suffix, "answer.npy")))
115 | output = np.concatenate(output)
116 | answer = np.concatenate(answer)
117 |
118 | plot_result(output, answer, save_filepath=os.path.join(path, "result.png"))
119 | for i in range(n_split):
120 | suffix = str(i) + "-" + str(n_split)
121 | print(suffix)
122 | visualize(os.path.join(path, suffix))
123 |
124 |
125 | if __name__ == '__main__':
126 | main()
127 |
128 |
--------------------------------------------------------------------------------
/experiments/delaney/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import sys
5 |
6 | import numpy as np
7 | from sklearn.preprocessing import StandardScaler
8 |
9 | import matplotlib as mpl
10 | mpl.use('Agg')
11 |
12 | from chainer import functions as F, cuda, Variable
13 | from chainer import iterators as I
14 | from chainer import optimizers as O
15 | from chainer import training
16 | from chainer.training import extensions as E
17 | from chainer.datasets import SubDataset
18 | from chainer import serializers
19 |
20 | from chainer_chemistry.dataset.converters import concat_mols
21 | from chainer_chemistry.dataset.preprocessors import preprocess_method_dict
22 | from chainer_chemistry.datasets import NumpyTupleDataset
23 | from chainer_chemistry.datasets.molnet import get_molnet_dataset
24 | from chainer_chemistry.models.prediction import Regressor
25 |
26 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
27 | from models import predictor
28 | from saliency.calculator.gradient_calculator import GradientCalculator
29 | from plot import plot_result
30 |
31 |
32 | def save_result(dataset, model, dir_path, M):
33 | regressor = Regressor(model, lossfun=F.mean_squared_error)
34 | # model.to_cpu()
35 |
36 | def preprocess_fun(*inputs):
37 | atom, adj, t = inputs
38 | # HACK: embed the discrete atom IDs here so that saliency is computed w.r.t. the continuous atom embeddings
39 | atom_embed = regressor.predictor.graph_conv.embed(atom)
40 | return atom_embed, adj, t
41 |
42 | def eval_fun(*inputs):
43 | atom_embed, adj, t = inputs
44 | prob = regressor.predictor(atom_embed, adj)
45 | out = F.sum(prob)
46 | return out
47 |
48 | gradient_calculator = GradientCalculator(
49 | regressor, eval_fun=eval_fun,
50 | target_key=0, multiply_target=True
51 | )
52 |
53 | def clip_original_size(saliency_, num_atoms_):
54 | """The `saliency` array is zero-padded; this method clips each entry back to
55 | the original molecule's length.
56 | """
57 | assert len(saliency_) == len(num_atoms_)
58 | saliency_list = []
59 | for i in range(len(saliency_)):
60 | saliency_list.append(saliency_[i, :num_atoms_[i]])
61 | return saliency_list
62 |
63 | atoms = dataset.features[:, 0]
64 | num_atoms = [len(a) for a in atoms]
65 |
66 | print('calculating saliency... M={}'.format(M))
67 | # --- VanillaGrad ---
68 | saliency_arrays = gradient_calculator.compute_vanilla(
69 | dataset, converter=concat_mols, preprocess_fn=preprocess_fun)
70 | saliency = gradient_calculator.transform(
71 | saliency_arrays, ch_axis=3, method='raw')
72 | saliency_vanilla = clip_original_size(saliency, num_atoms)
73 | np.save(os.path.join(dir_path, "saliency_vanilla"), saliency_vanilla)
74 |
75 | # --- SmoothGrad ---
76 | saliency_arrays = gradient_calculator.compute_smooth(
77 | dataset, converter=concat_mols, preprocess_fn=preprocess_fun, M=M)
78 | saliency = gradient_calculator.transform(
79 | saliency_arrays, ch_axis=3, method='raw')
80 | saliency_smooth = clip_original_size(saliency, num_atoms)
81 | np.save(os.path.join(dir_path, "saliency_smooth"), saliency_smooth)
82 |
83 | # --- BayesGrad ---
84 | # train=True corresponds to BayesGrad
85 | saliency_arrays = gradient_calculator.compute_vanilla(
86 | dataset, converter=concat_mols, preprocess_fn=preprocess_fun, M=M,
87 | train=True)
88 | saliency = gradient_calculator.transform(
89 | saliency_arrays, ch_axis=3, method='raw', lam=0)
90 | saliency_bayes = clip_original_size(saliency, num_atoms)
91 | np.save(os.path.join(dir_path, "saliency_bayes"), saliency_bayes)
92 |
93 |
94 | def get_dir_path(batchsize, n_unit, conv_layers, M, method):
95 | dir_path = "results/{}_M{}_conv{}_unit{}_b{}".format(method, M, conv_layers, n_unit, batchsize)
96 | dir_path = os.path.join("./", dir_path)
97 | return dir_path
98 |
99 |
100 | def train(gpu, method, epoch, batchsize, n_unit, conv_layers, dataset, smiles, M, n_split, split_idx, order):
101 | n = len(dataset)
102 | assert len(order) == n
103 | left_idx = (n // n_split) * split_idx
104 | is_right_most_split = (n_split == split_idx + 1)
105 | if is_right_most_split:
106 | test_order = order[left_idx:]
107 | train_order = order[:left_idx]
108 | else:
109 | right_idx = (n // n_split) * (split_idx + 1)
110 | test_order = order[left_idx:right_idx]
111 | train_order = np.concatenate([order[:left_idx], order[right_idx:]])
112 |
113 | new_order = np.concatenate([train_order, test_order])
114 | n_train = len(train_order)
115 |
116 | # Standard Scaler for labels
117 | ss = StandardScaler()
118 | labels = dataset.get_datasets()[-1]
119 | train_label = labels[new_order[:n_train]]
120 | ss = ss.fit(train_label) # fit only by train
121 | labels = ss.transform(dataset.get_datasets()[-1])
122 | dataset = NumpyTupleDataset(*(dataset.get_datasets()[:-1] + (labels,)))
123 |
124 | dataset_train = SubDataset(dataset, 0, n_train, new_order)
125 | dataset_test = SubDataset(dataset, n_train, n, new_order)
126 |
127 | # Network
128 | model = predictor.build_predictor(
129 | method, n_unit, conv_layers, 1, dropout_ratio=0.25, n_layers=1)
130 |
131 | train_iter = I.SerialIterator(dataset_train, batchsize)
132 | val_iter = I.SerialIterator(dataset_test, batchsize, repeat=False, shuffle=False)
133 |
134 | def scaled_abs_error(x0, x1):
135 | if isinstance(x0, Variable):
136 | x0 = cuda.to_cpu(x0.data)
137 | if isinstance(x1, Variable):
138 | x1 = cuda.to_cpu(x1.data)
139 | scaled_x0 = ss.inverse_transform(cuda.to_cpu(x0))
140 | scaled_x1 = ss.inverse_transform(cuda.to_cpu(x1))
141 | diff = scaled_x0 - scaled_x1
142 | return np.mean(np.absolute(diff), axis=0)[0]
143 |
144 | regressor = Regressor(
145 | model, lossfun=F.mean_squared_error,
146 | metrics_fun={'abs_error': scaled_abs_error}, device=gpu)
147 |
148 | optimizer = O.Adam(alpha=0.0005)
149 | optimizer.setup(regressor)
150 |
151 | updater = training.StandardUpdater(train_iter, optimizer, device=gpu, converter=concat_mols)
152 |
153 | dir_path = get_dir_path(batchsize, n_unit, conv_layers, M, method)
154 | dir_path = os.path.join(dir_path, str(split_idx) + "-" + str(n_split))
155 | os.makedirs(dir_path, exist_ok=True)
156 | print('creating ', dir_path)
157 | np.save(os.path.join(dir_path, "test_idx"), np.array(test_order))
158 |
159 | trainer = training.Trainer(updater, (epoch, 'epoch'), out=dir_path)
160 | trainer.extend(E.Evaluator(val_iter, regressor, device=gpu,
161 | converter=concat_mols))
162 | trainer.extend(E.LogReport())
163 | trainer.extend(E.PrintReport(['epoch', 'main/loss', 'main/abs_error',
164 | 'validation/main/loss',
165 | 'validation/main/abs_error',
166 | 'elapsed_time']))
167 | trainer.extend(E.ProgressBar())
168 | trainer.run()
169 |
170 | # --- Plot regression evaluation result ---
171 | dataset_test = SubDataset(dataset, n_train, n, new_order)
172 | batch_all = concat_mols(dataset_test, device=gpu)
173 | serializers.save_npz(os.path.join(dir_path, "model.npz"), model)
174 | result = model(batch_all[0], batch_all[1])
175 | result = ss.inverse_transform(cuda.to_cpu(result.data))
176 | answer = ss.inverse_transform(cuda.to_cpu(batch_all[2]))
177 | plot_result(result, answer, save_filepath=os.path.join(dir_path, "result.png"))
178 |
179 | # --- Plot regression evaluation result end ---
180 | np.save(os.path.join(dir_path, "output.npy"), result)
181 | np.save(os.path.join(dir_path, "answer.npy"), answer)
182 | smiles_part = np.array(smiles)[test_order]
183 | np.save(os.path.join(dir_path, "smiles.npy"), smiles_part)
184 |
185 | # calculate saliency and save it.
186 | save_result(dataset, model, dir_path, M)
187 |
188 |
189 | def main():
190 | # Supported preprocessing/network list
191 | parser = argparse.ArgumentParser(
192 | description='Regression with own dataset.')
193 | parser.add_argument('--gpu', '-g', type=int, default=-1)
194 | parser.add_argument('--method', type=str, default='nfpdrop',
195 | choices=['nfpdrop', 'ggnndrop', 'nfp', 'ggnn'])
196 | parser.add_argument('--epoch', '-e', type=int, default=20)
197 | parser.add_argument('--seed', '-s', type=int, default=777)
198 | parser.add_argument('--layer', '-n', type=int, default=3)
199 | parser.add_argument('--batchsize', '-b', type=int, default=32)
200 | parser.add_argument('--m', '-m', type=int, default=30)
201 | args = parser.parse_args()
202 |
203 | dataset_name = 'delaney'
204 | # labels = "measured log solubility in mols per litre"
205 | labels = None
206 |
207 | # Dataset preparation
208 | print('Preprocessing dataset...')
209 | method = args.method
210 | if 'nfp' in method:
211 | preprocess_method = 'nfp'
212 | elif 'ggnn' in method:
213 | preprocess_method = 'ggnn'
214 | else:
215 | raise ValueError('Unexpected method', method)
216 | preprocessor = preprocess_method_dict[preprocess_method]()
217 | data = get_molnet_dataset(
218 | dataset_name, preprocessor, labels=labels, return_smiles=True,
219 | frac_train=1.0, frac_valid=0.0, frac_test=0.0)
220 | dataset = data['dataset'][0]
221 | smiles = data['smiles'][0]
222 |
223 | epoch = args.epoch
224 | gpu = args.gpu
225 |
226 | n_unit_list = [32]
227 | random_state = np.random.RandomState(args.seed)
228 | n = len(dataset)
229 |
230 | M = args.m
231 | order = np.arange(n)
232 | random_state.shuffle(order)
233 | batchsize = args.batchsize
234 | for n_unit in n_unit_list:
235 | n_layer = args.layer
236 | n_split = 5
237 | for idx in range(n_split):
238 | print('Start training: ', idx+1, "/", n_split)
239 | dir_path = get_dir_path(batchsize, n_unit, n_layer, M, method)
240 | os.makedirs(dir_path, exist_ok=True)
241 | np.save(os.path.join(dir_path, "smiles.npy"), np.array(smiles))
242 | train(gpu, method, epoch, batchsize, n_unit, n_layer, dataset, smiles, M, n_split, idx, order)
243 |
244 |
245 | if __name__ == '__main__':
246 | main()
247 |
--------------------------------------------------------------------------------
/experiments/tox21/calc_prcauc_with_seeds.py:
--------------------------------------------------------------------------------
1 | """
2 | Calculate statistics by seeds.
3 | """
4 | import argparse
5 |
6 | import matplotlib
7 | import pandas
8 |
9 | matplotlib.use('agg')
10 | import matplotlib.pyplot as plt
11 |
12 | from chainer import functions as F
13 | from chainer import links as L
14 | from chainer import serializers
15 | from chainer_chemistry.dataset.converters import concat_mols
16 | from chainer_chemistry.datasets import NumpyTupleDataset
17 | import numpy as np
18 | from rdkit import RDLogger, Chem
19 | from sklearn.metrics import auc
20 | from tqdm import tqdm
21 |
22 | import sys
23 | import os
24 | import logging
25 |
26 |
27 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
28 | from models import predictor
29 | from saliency.calculator.gradient_calculator import GradientCalculator
30 |
31 | import data
32 | from data import PYRIDINE_SMILES
33 | from plot_precision_recall import calc_recall_precision
34 |
35 |
36 | def calc_prc_auc(saliency, rates, haspiindex):
37 | recall, precision = calc_recall_precision(saliency, rates, haspiindex)
38 | print('recall', recall)
39 | print('precision', precision)
40 | prcauc = auc(recall, precision)
41 | print('prcauc', prcauc)
42 | return recall, precision, prcauc
43 |
44 |
45 | def parse():
46 | parser = argparse.ArgumentParser(
47 | description='Multitask Learning with Tox21.')
48 | parser.add_argument('--method', '-m', type=str,
49 | default='nfp', help='graph convolution model to use '
50 | 'as a predictor.')
51 | parser.add_argument('--label', '-l', type=str,
52 | default='pyridine', help='target label for logistic '
53 | 'regression. Use all labels if this option '
54 | 'is not specified.')
55 | parser.add_argument('--conv-layers', '-c', type=int, default=4,
56 | help='number of convolution layers')
57 | parser.add_argument('--n-layers', type=int, default=1,
58 | help='number of mlp layers')
59 | parser.add_argument('--batchsize', '-b', type=int, default=128,
60 | help='batch size')
61 | parser.add_argument('--gpu', '-g', type=int, default=-1,
62 | help='GPU ID to use. Negative value indicates '
63 | 'not to use GPU and to run the code in CPU.')
64 | parser.add_argument('--out', '-o', type=str, default='result',
65 | help='path to output directory')
66 | parser.add_argument('--epoch', '-e', type=int, default=10,
67 | help='number of epochs')
68 | parser.add_argument('--unit-num', '-u', type=int, default=16,
69 | help='number of units in one layer of the model')
70 | parser.add_argument('--dropout-ratio', '-d', type=float, default=0.25,
71 | help='dropout_ratio')
72 | parser.add_argument('--seeds', type=int, default=0,
73 | help='number of seeds to use for calculation')
74 | parser.add_argument('--num-train', type=int, default=-1,
75 | help='number of training data to be used, '
76 | 'negative value indicates use all train data')
77 | parser.add_argument('--scale', type=float, default=0.15, help='scale for smoothgrad')
78 | parser.add_argument('--mode', type=str, default='absolute', help='mode for smoothgrad')
79 | args = parser.parse_args()
80 | return args
81 |
82 |
83 | def main(method, labels, unit_num, conv_layers, class_num, n_layers,
84 | dropout_ratio, model_path_list, save_dir_path, scale=0.15, mode='relative', M=5):
85 | # Dataset preparation
86 | train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels)
87 |
88 | # --- model preparation ---
89 | model = predictor.build_predictor(
90 | method, unit_num, conv_layers, class_num, dropout_ratio, n_layers)
91 |
92 | classifier = L.Classifier(model,
93 | lossfun=F.sigmoid_cross_entropy,
94 | accfun=F.binary_accuracy)
95 |
96 | target_dataset = val
97 | target_smiles = val_smiles
98 |
99 | val_mols = [Chem.MolFromSmiles(smi) for smi in tqdm(val_smiles)]
100 |
101 | pi = Chem.MolFromSmarts(PYRIDINE_SMILES)
102 | piindex = np.where(np.array([mol.HasSubstructMatch(pi) for mol in val_mols]) == True)
103 | haspi = np.array(val_mols)[piindex]
104 |
105 | # GetSubstructMatch only extracts one substructure, which is not the expected behavior, so GetSubstructMatches is used instead
106 | # haspiindex = [set(mol.GetSubstructMatch(pi)) for mol in haspi]
107 | def flatten_tuple(x):
108 | return [element for tupl in x for element in tupl]
109 | haspiindex = [flatten_tuple(mol.GetSubstructMatches(pi)) for mol in haspi]
110 | print('piindex', piindex)
111 | print('haspi', haspi.shape)
112 | print('haspiindex', haspiindex)
113 | print('haspiindex length', [len(k) for k in haspiindex])
114 |
115 | pyrigine_dataset = NumpyTupleDataset(*target_dataset.features[piindex, :])
116 | pyrigine_smiles = target_smiles[piindex]
117 |
118 | atoms = pyrigine_dataset.features[:, 0]
119 | num_atoms = [len(a) for a in atoms]
120 |
121 | def clip_original_size(saliency, num_atoms):
122 | """The `saliency` array is zero-padded; this method clips each entry back to
123 | the original molecule's length.
124 | """
125 | assert len(saliency) == len(num_atoms)
126 | saliency_list = []
127 | for i in range(len(saliency)):
128 | saliency_list.append(saliency[i, :num_atoms[i]])
129 | return saliency_list
130 |
131 | def preprocess_fun(*inputs):
132 | atom, adj, t = inputs
133 | # HACK: embed the discrete atom IDs here so that saliency is computed w.r.t. the continuous atom embeddings
134 | # classifier.predictor.pick = True
135 | # result = classifier.predictor(atom, adj)
136 | atom_embed = classifier.predictor.graph_conv.embed(atom)
137 | return atom_embed, adj, t
138 |
139 | def eval_fun(*inputs):
140 | atom_embed, adj, t = inputs
141 | prob = classifier.predictor(atom_embed, adj)
142 | # print('embed', atom_embed.shape, 'prob', prob.shape)
143 | out = F.sum(prob)
144 | # return {'embed': atom_embed, 'out': out}
145 | return out
146 |
147 | gradient_calculator = GradientCalculator(
148 | classifier, eval_fun=eval_fun,
149 | # target_key='embed', eval_key='out',
150 | target_key=0,
151 | )
152 |
153 | print('M', M)
154 | # rates = np.array(list(range(1, 11))) * 0.1
155 | num = 20
156 | # rates = np.linspace(0, 1, num=num+1)[1:]
157 | rates = np.linspace(0.1, 1, num=num)
158 | print('rates', len(rates), rates)
159 |
160 | fig = plt.figure(figsize=(7, 5), dpi=200)
161 |
162 | precisions_vanilla = []
163 | precisions_smooth = []
164 | precisions_bayes = []
165 | precisions_bayes_smooth = []
166 | prcauc_vanilla = []
167 | prcauc_smooth = []
168 | prcauc_bayes = []
169 | prcauc_bayes_smooth = []
170 |
171 | prcauc_diff_smooth_vanilla = []
172 | prcauc_diff_bayes_vanilla = []
173 | for model_path in model_path_list:
174 | serializers.load_npz(model_path, model)
175 |
176 | # --- VanillaGrad ---
177 | saliency_arrays = gradient_calculator.compute_vanilla(
178 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun)
179 | saliency = gradient_calculator.transform(
180 | saliency_arrays, ch_axis=3, method='square')
181 | # saliency_arrays (1, 28, 43, 64) -> M, batch_size, max_atom, ch_dim
182 | print('saliency_arrays', saliency_arrays.shape)
183 | # saliency (28, 43) -> batch_size, max_atom
184 | print('saliency', saliency.shape)
185 | saliency_vanilla = clip_original_size(saliency, num_atoms)
186 |
187 | # recall & precision
188 | print('vanilla')
189 | naiverecall, naiveprecision, naiveprcauc = calc_prc_auc(saliency_vanilla, rates, haspiindex)
190 | precisions_vanilla.append(naiveprecision)
191 | prcauc_vanilla.append(naiveprcauc)
192 |
193 | # --- SmoothGrad ---
194 | saliency_arrays = gradient_calculator.compute_smooth(
195 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,
196 | M=M, scale=scale, mode=mode)
197 | saliency = gradient_calculator.transform(
198 | saliency_arrays, ch_axis=3, method='square')
199 |
200 | saliency_smooth = clip_original_size(saliency, num_atoms)
201 |
202 | # recall & precision
203 | print('smooth')
204 | smoothrecall, smoothprecision, smoothprcauc = calc_prc_auc(saliency_smooth, rates, haspiindex)
205 | precisions_smooth.append(smoothprecision)
206 | prcauc_smooth.append(smoothprcauc)
207 |
208 | # --- BayesGrad ---
209 | saliency_arrays = gradient_calculator.compute_vanilla(
210 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun, train=True, M=M)
211 | saliency = gradient_calculator.transform(
212 | saliency_arrays, ch_axis=3, method='square', lam=0)
213 | saliency_bayes = clip_original_size(saliency, num_atoms)
214 |
215 | bgrecall0, bgprecision0, bayesprcauc = calc_prc_auc(saliency_bayes, rates, haspiindex)
216 | precisions_bayes.append(bgprecision0)
217 | prcauc_bayes.append(bayesprcauc)
218 | prcauc_diff_smooth_vanilla.append(smoothprcauc - naiveprcauc)
219 | prcauc_diff_bayes_vanilla.append(bayesprcauc - naiveprcauc)
220 |
221 | # --- BayesSmoothGrad ---
222 | saliency_arrays = gradient_calculator.compute_smooth(
223 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,
224 | M=M, scale=scale, mode=mode, train=True)
225 | saliency = gradient_calculator.transform(
226 | saliency_arrays, ch_axis=3, method='square')
227 | saliency_bayes_smooth = clip_original_size(saliency, num_atoms)
228 | # recall & precision
229 | print('bayes smooth')
230 | bayes_smoothrecall, bayes_smoothprecision, bayes_smoothprcauc = calc_prc_auc(saliency_bayes_smooth, rates, haspiindex)
231 | precisions_bayes_smooth.append(bayes_smoothprecision)
232 | prcauc_bayes_smooth.append(bayes_smoothprcauc)
233 |
234 | precisions_vanilla = np.array(precisions_vanilla)
235 | precisions_smooth = np.array(precisions_smooth)
236 | precisions_bayes = np.array(precisions_bayes)
237 | precisions_bayes_smooth = np.array(precisions_bayes_smooth)
238 |
239 | df = pandas.DataFrame({
240 | 'model_path': model_path_list,
241 | 'prcauc_vanilla': prcauc_vanilla,
242 | 'prcauc_smooth': prcauc_smooth,
243 | 'prcauc_bayes': prcauc_bayes,
244 | 'prcauc_bayes_smooth': prcauc_bayes_smooth,
245 | 'prcauc_diff_smooth_vanilla': prcauc_diff_smooth_vanilla,
246 | 'prcauc_diff_bayes_vanilla': prcauc_diff_bayes_vanilla
247 | })
248 | save_csv_path = save_dir_path + '/prcauc_{}_{}.csv'.format(mode, scale)
249 | print('save to ', save_csv_path)
250 | df.to_csv(save_csv_path)
251 |
252 | prcauc_vanilla = np.array(prcauc_vanilla)
253 | prcauc_smooth = np.array(prcauc_smooth)
254 | prcauc_bayes = np.array(prcauc_bayes)
255 | prcauc_bayes_smooth = np.array(prcauc_bayes_smooth)
256 | prcauc_diff_smooth_vanilla = np.array(prcauc_diff_smooth_vanilla)
257 | prcauc_diff_bayes_vanilla = np.array(prcauc_diff_bayes_vanilla)
258 |
259 | def show_avg_std(array, tag=''):
260 | print('{}: mean {:8.03}, std {:8.03}'
261 | .format(tag, np.mean(array, axis=0), np.std(array, axis=0)))
262 | return {'method': tag, 'mean': np.mean(array, axis=0), 'std': np.std(array, axis=0)}
263 |
264 | df = pandas.DataFrame([
265 | show_avg_std(prcauc_vanilla, tag='vanilla'),
266 | show_avg_std(prcauc_smooth, tag='smooth'),
267 | show_avg_std(prcauc_bayes, tag='bayes'),
268 | show_avg_std(prcauc_bayes_smooth, tag='bayes_smooth'),
269 | show_avg_std(prcauc_diff_smooth_vanilla, tag='diff_smooth_vanilla'),
270 | show_avg_std(prcauc_diff_bayes_vanilla, tag='diff_bayes_vanilla'),
271 | ])
272 | save_csv_path = save_dir_path + '/prcauc_stats_{}_{}.csv'.format(mode, scale)
273 | print('save to ', save_csv_path)
274 | df.to_csv(save_csv_path)
275 | # import IPython; IPython.embed()
276 |
277 | def _plot_with_errorbar(x, precisions, color='blue', alpha=None, label=None):
278 | y = np.mean(precisions, axis=0)
279 | plt.errorbar(x, y, yerr=np.std(precisions, axis=0), fmt='ro', ecolor=color) # fmt=''
280 | plt.plot(x, y, 'k-', color=color, label=label, alpha=alpha)
281 |
282 | alpha = 0.5
283 | _plot_with_errorbar(rates, precisions_vanilla, color='blue', alpha=alpha, label='VanillaGrad')
284 | _plot_with_errorbar(rates, precisions_smooth, color='green', alpha=alpha, label='SmoothGrad')
285 | _plot_with_errorbar(rates, precisions_bayes, color='yellow', alpha=alpha, label='BayesGrad')
286 | _plot_with_errorbar(rates, precisions_bayes_smooth, color='orange', alpha=alpha, label='BayesSmoothGrad')
287 | plt.legend()
288 | plt.xlabel("recall")
289 | plt.ylabel("precision")
290 | save_path = os.path.join(save_dir_path, 'artificial_pr.png')
291 | print('saved to ', save_path)
292 | plt.savefig(save_path)
293 |
294 |
295 | if __name__ == '__main__':
296 | # Disable errors by RDKit occurred in preprocessing Tox21 dataset.
297 | lg = RDLogger.logger()
298 | lg.setLevel(RDLogger.CRITICAL)
299 | # show INFO level log from chainer chemistry
300 | logging.basicConfig(level=logging.INFO)
301 |
302 | args = parse()
303 | # --- config ---
304 | method = args.method
305 | labels = args.label
306 | unit_num = args.unit_num
307 | conv_layers = args.conv_layers
308 | class_num = 1
309 | n_layers = args.n_layers
310 | dropout_ratio = args.dropout_ratio
311 | num_train = args.num_train
312 | seeds = args.seeds
313 |
314 | root = '.'
315 | model_path_list = []
316 | for i in range(seeds):
317 | dir_path = '{}/results/{}_{}_numtrain{}_seed{}'.format(
318 | root, method, labels, num_train, i)
319 | model_path = os.path.join(dir_path, 'predictor.npz')
320 | model_path_list.append(model_path)
321 |
322 | save_dir_path = '{}/results/{}_{}_numtrain{}_seed0-{}'.format(
323 | root, method, labels, num_train, seeds-1)
324 | if not os.path.exists(save_dir_path):
325 | os.mkdir(save_dir_path)
326 |
327 | # --- config end ---
328 |
329 | main(method, labels, unit_num, conv_layers, class_num, n_layers,
330 | dropout_ratio, model_path_list, save_dir_path, args.scale, args.mode, M=100)
331 |
--------------------------------------------------------------------------------
/experiments/tox21/calc_prcauc_with_seeds.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -eu
4 |
5 | gpu=${1:--1}
6 | label=pyridine
7 | seeds=30
8 | num_train=1000
9 | unit_num=16
10 |
11 | # mode=relative
12 | mode=absolute
13 | method=ggnndrop
14 | ratio=0.25
15 |
16 | for scale in 0.05 0.10 0.15 0.20; do
17 | python calc_prcauc_with_seeds.py -g ${gpu} --label=${label} --unit-num=${unit_num} --n-layers=1 --dropout-ratio=${ratio} --num-train=${num_train} --seeds=${seeds} --method=${method} --scale=${scale} --mode=${mode}
18 | done
19 |
--------------------------------------------------------------------------------
/experiments/tox21/data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy
4 | from chainer_chemistry.dataset.preprocessors import preprocess_method_dict
5 | from chainer_chemistry import datasets as D
6 | from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset
7 | from rdkit import Chem
8 | from tqdm import tqdm
9 |
10 | import utils
11 |
12 |
13 | class _CacheNamePolicy(object):
14 |
15 | train_file_name = 'train.npz'
16 | val_file_name = 'val.npz'
17 | test_file_name = 'test.npz'
18 | smiles_file_name = 'smiles.npz'
19 |
20 | def _get_cache_directory_path(self, method, labels, prefix):
21 | if labels:
22 | return os.path.join(prefix, '{}_{}'.format(method, labels))
23 | else:
24 | return os.path.join(prefix, '{}_all'.format(method))
25 |
26 | def __init__(self, method, labels, prefix='input'):
27 | self.method = method
28 | self.labels = labels
29 | self.prefix = prefix
30 | self.cache_dir = self._get_cache_directory_path(method, labels, prefix)
31 |
32 | def get_train_file_path(self):
33 | return os.path.join(self.cache_dir, self.train_file_name)
34 |
35 | def get_val_file_path(self):
36 | return os.path.join(self.cache_dir, self.val_file_name)
37 |
38 | def get_test_file_path(self):
39 | return os.path.join(self.cache_dir, self.test_file_name)
40 |
41 | def get_smiles_path(self):
42 | return os.path.join(self.cache_dir, self.smiles_file_name)
43 |
44 | def create_cache_directory(self):
45 | try:
46 | os.makedirs(self.cache_dir)
47 | except OSError:
48 | if not os.path.isdir(self.cache_dir):
49 | raise
50 |
51 |
52 | PYRIDINE_SMILES = 'c1ccncc1'
53 |
54 |
55 | def hassubst(mol, smart=PYRIDINE_SMILES):
56 | return numpy.array(int(mol.HasSubstructMatch(Chem.MolFromSmarts(smart)))).astype('int32')
57 |
58 |
59 | def load_dataset(method, labels, prefix='input'):
60 | method = 'nfp' if 'nfp' in method else method # to deal with nfpdrop
61 | method = 'ggnn' if 'ggnn' in method else method # to deal with ggnndrop
62 | policy = _CacheNamePolicy(method, labels, prefix)
63 | train_path = policy.get_train_file_path()
64 | val_path = policy.get_val_file_path()
65 | test_path = policy.get_test_file_path()
66 | smiles_path = policy.get_smiles_path()
67 |
68 | train, val, test = None, None, None
69 | train_smiles, val_smiles, test_smiles = None, None, None
70 | print()
71 | if os.path.exists(policy.cache_dir):
72 | print('load from cache {}'.format(policy.cache_dir))
73 | train = NumpyTupleDataset.load(train_path)
74 | val = NumpyTupleDataset.load(val_path)
75 | test = NumpyTupleDataset.load(test_path)
76 | train_smiles, val_smiles, test_smiles = utils.load_npz(smiles_path)
77 | if train is None or val is None or test is None:
78 | print('preprocessing dataset...')
79 | preprocessor = preprocess_method_dict[method]()
80 | if labels == 'pyridine':
81 | train, val, test, train_smiles, val_smiles, test_smiles = D.get_tox21(
82 | preprocessor, labels=None, return_smiles=True)
83 | print('converting label into pyridine...')
84 | # --- Pyridine = 1 ---
85 | train_pyridine_label = [
86 | hassubst(Chem.MolFromSmiles(smi), smart=PYRIDINE_SMILES) for smi in tqdm(train_smiles)]
87 | val_pyridine_label = [
88 | hassubst(Chem.MolFromSmiles(smi), smart=PYRIDINE_SMILES) for smi in tqdm(val_smiles)]
89 | test_pyridine_label = [
90 | hassubst(Chem.MolFromSmiles(smi), smart=PYRIDINE_SMILES) for smi in tqdm(test_smiles)]
91 |
92 | train_pyridine_label = numpy.array(train_pyridine_label)[:, None]
93 | val_pyridine_label = numpy.array(val_pyridine_label)[:, None]
94 | test_pyridine_label = numpy.array(test_pyridine_label)[:, None]
95 | print('train positive/negative', numpy.sum(train_pyridine_label == 1), numpy.sum(train_pyridine_label == 0))
96 | train = NumpyTupleDataset(*train.features[:, :-1], train_pyridine_label)
97 | val = NumpyTupleDataset(*val.features[:, :-1], val_pyridine_label)
98 | test = NumpyTupleDataset(*test.features[:, :-1], test_pyridine_label)
99 | else:
100 | train, val, test, train_smiles, val_smiles, test_smiles = D.get_tox21(
101 | preprocessor, labels=labels, return_smiles=True)
102 |
103 | # Cache dataset
104 | policy.create_cache_directory()
105 | NumpyTupleDataset.save(train_path, train)
106 | NumpyTupleDataset.save(val_path, val)
107 | NumpyTupleDataset.save(test_path, test)
108 | train_smiles = numpy.array(train_smiles)
109 | val_smiles = numpy.array(val_smiles)
110 | test_smiles = numpy.array(test_smiles)
111 | utils.save_npz(smiles_path, (train_smiles, val_smiles, test_smiles))
112 | return train, val, test, train_smiles, val_smiles, test_smiles
113 |
--------------------------------------------------------------------------------
/experiments/tox21/plot_precision_recall.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import logging
4 | import argparse
5 | import json
6 |
7 | try:
8 | import matplotlib
9 | matplotlib.use('agg')
10 | except:
11 | pass
12 | import matplotlib.pyplot as plt
13 |
14 | from chainer import functions as F
15 | from chainer import links as L
16 | from tqdm import tqdm
17 | from chainer import serializers
18 | import numpy as np
19 | from rdkit import RDLogger, Chem
20 |
21 | from chainer_chemistry.dataset.converters import concat_mols
22 | from chainer_chemistry.datasets import NumpyTupleDataset
23 |
24 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
25 | from models import predictor
26 | from saliency.calculator.gradient_calculator import GradientCalculator
27 | from saliency.calculator.integrated_gradients_calculator import IntegratedGradientsCalculator
28 | from saliency.calculator.occlusion_calculator import OcclusionCalculator
29 |
30 | import data
31 | from data import PYRIDINE_SMILES, hassubst
32 |
33 |
34 | def percentile_index(ar, num):
35 | """
36 | ar (numpy.ndarray): array
37 | num (float): rate
38 |
39 | Return the indices of the largest `num` fraction of values in `ar`.
40 | """
41 | threshold = int(len(ar) * num)
42 | idx = np.argsort(ar)
43 | return idx[-threshold:]
44 |
45 |
46 | def calc_recall_precision_for_rate(grads, rate, haspiindex):
47 | recall_list = []
48 | hit_rate_list = []
49 | for i in range(len(grads)):
50 | largest_index = percentile_index(grads[i], float(rate))
51 | set_largest_index = set(largest_index)
52 | hit_index = set_largest_index.intersection(haspiindex[i])
53 | hit_num = len(hit_index)
54 | hit_rate = float(hit_num) / float(len(set_largest_index))
55 |
56 | recall_list.append(float(hit_num) / len(haspiindex[i]))
57 | hit_rate_list.append(hit_rate)
58 | recall = np.mean(np.array(recall_list))
59 | precision = np.mean(np.array(hit_rate_list))
60 | return recall, precision
61 |
62 |
63 | def calc_recall_precision(grads, rates, haspiindex):
64 | r_list = []
65 | p_list = []
66 | for rate in rates:
67 | r, p = calc_recall_precision_for_rate(grads, rate, haspiindex)
68 | r_list.append(r)
69 | p_list.append(p)
70 | return r_list, p_list
71 |
72 |
73 | def parse():
74 | parser = argparse.ArgumentParser(
75 | description='Multitask Learning with Tox21.')
76 | parser.add_argument('--batchsize', '-b', type=int, default=128,
77 | help='batch size')
78 | parser.add_argument('--gpu', '-g', type=int, default=-1,
79 | help='GPU ID to use. Negative value indicates '
80 | 'not to use GPU and to run the code in CPU.')
81 | parser.add_argument('--dirpath', '-d', type=str, default='results',
82 | help='path to train results directory')
83 | parser.add_argument('--calculator', type=str, default='gradient')
84 | args = parser.parse_args()
85 | return args
86 |
87 |
88 | def main(method, labels, unit_num, conv_layers, class_num, n_layers,
89 | dropout_ratio, model_path, save_path):
90 | # Dataset preparation
91 | train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels)
92 |
93 | # --- model preparation ---
94 | model = predictor.build_predictor(
95 | method, unit_num, conv_layers, class_num, dropout_ratio, n_layers)
96 |
97 | classifier = L.Classifier(model,
98 | lossfun=F.sigmoid_cross_entropy,
99 | accfun=F.binary_accuracy)
100 |
101 | print('Loading model parameter from ', model_path)
102 | serializers.load_npz(model_path, model)
103 |
104 | target_dataset = val
105 | target_smiles = val_smiles
106 |
107 | val_mols = [Chem.MolFromSmiles(smi) for smi in tqdm(val_smiles)]
108 |
109 | pyridine_mol = Chem.MolFromSmarts(PYRIDINE_SMILES)
110 | pyridine_index = np.where(np.array([mol.HasSubstructMatch(pyridine_mol) for mol in val_mols]) == True)
111 | val_pyridine_mols = np.array(val_mols)[pyridine_index]
112 |
113 | # GetSubstructMatch only extracts one substructure, which is not the expected behavior, so GetSubstructMatches is used instead
114 | # val_pyridine_pos = [set(mol.GetSubstructMatch(pi)) for mol in val_pyridine_mols]
115 | def flatten_tuple(x):
116 | return [element for tupl in x for element in tupl]
117 |
118 | val_pyridine_pos = [flatten_tuple(mol.GetSubstructMatches(pyridine_mol)) for mol in val_pyridine_mols]
119 |
120 | # print('pyridine_index', pyridine_index)
121 | # print('val_pyridine_mols', val_pyridine_mols.shape)
122 | # print('val_pyridine_pos', val_pyridine_pos)
123 | # print('val_pyridine_pos length', [len(k) for k in val_pyridine_pos])
124 |
125 | pyrigine_dataset = NumpyTupleDataset(*target_dataset.features[pyridine_index, :])
126 | pyrigine_smiles = target_smiles[pyridine_index]
127 | print('pyrigine_dataset', len(pyrigine_dataset), len(pyrigine_smiles))
128 |
129 | atoms = pyrigine_dataset.features[:, 0]
130 | num_atoms = [len(a) for a in atoms]
131 |
132 | def clip_original_size(saliency, num_atoms):
133 | """The `saliency` array is zero-padded; this method clips each entry back to
134 | the original molecule's length.
135 | """
136 | assert len(saliency) == len(num_atoms)
137 | saliency_list = []
138 | for i in range(len(saliency)):
139 | saliency_list.append(saliency[i, :num_atoms[i]])
140 | return saliency_list
141 |
142 | def preprocess_fun(*inputs):
143 | atom, adj, t = inputs
144 | # HACK: embed the discrete atom IDs here so that saliency is computed w.r.t. the continuous atom embeddings
145 | atom_embed = classifier.predictor.graph_conv.embed(atom)
146 | return atom_embed, adj, t
147 |
148 | def eval_fun(*inputs):
149 | atom_embed, adj, t = inputs
150 | prob = classifier.predictor(atom_embed, adj)
151 | out = F.sum(prob)
152 | return out
153 |
154 | calculator_method = args.calculator
155 | print('calculator method', calculator_method)
156 | if calculator_method == 'gradient':
157 | # option1: Gradient
158 | calculator = GradientCalculator(
159 | classifier, eval_fun=eval_fun,
160 | # target_key='embed', eval_key='out',
161 | target_key=0,
162 | # multiply_target=True # this will calculate grad * input
163 | )
164 | elif calculator_method == 'integrated_gradients':
165 | # option2: IntegratedGradients
166 | calculator = IntegratedGradientsCalculator(
167 | classifier, eval_fun=eval_fun,
168 | # target_key='embed', eval_key='out',
169 | target_key=0, steps=10
170 | )
171 | elif calculator_method == 'occlusion':
172 | # option3: Occlusion
173 | def eval_fun_occlusion(*inputs):
174 | atom_embed, adj, t = inputs
175 | prob = classifier.predictor(atom_embed, adj)
176 | # Do not take sum, instead return batch-wise score
177 | out = F.sigmoid(prob)
178 | return out
179 | calculator = OcclusionCalculator(
180 | classifier, eval_fun=eval_fun_occlusion,
181 | # target_key='embed', eval_key='out',
182 | target_key=0, slide_axis=1
183 | )
184 | else:
185 | raise ValueError("[ERROR] Unexpected value calculator_method={}".format(calculator_method))
186 |
187 | M = 100
188 | num = 20
189 | rates = np.linspace(0.1, 1, num=num)
190 | print('M', M)
191 |
192 | # --- VanillaGrad ---
193 | saliency_arrays = calculator.compute_vanilla(
194 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun)
195 | saliency = calculator.transform(
196 | saliency_arrays, ch_axis=3, method='square')
197 | # saliency_arrays -> M, batch_size, max_atom, ch_dim
198 | # print('saliency_arrays', saliency_arrays.shape)
199 | # saliency -> batch_size, max_atom
200 | # print('saliency', saliency.shape)
201 | saliency_vanilla = clip_original_size(saliency, num_atoms)
202 |
203 | # recall & precision
204 | vanilla_recall, vanilla_precision = calc_recall_precision(saliency_vanilla, rates, val_pyridine_pos)
205 | print('vanilla_recall', vanilla_recall)
206 | print('vanilla_precision', vanilla_precision)
207 |
208 | # --- SmoothGrad ---
209 | saliency_arrays = calculator.compute_smooth(
210 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,
211 | M=M,
212 | mode='absolute', scale=0.15 # previous implementation
213 | # mode='relative', scale=0.05
214 | )
215 | saliency = calculator.transform(
216 | saliency_arrays, ch_axis=3, method='square')
217 |
218 | saliency_smooth = clip_original_size(saliency, num_atoms)
219 |
220 | # recall & precision
221 | smooth_recall, smooth_precision = calc_recall_precision(saliency_smooth, rates, val_pyridine_pos)
222 | print('smooth_recall', smooth_recall)
223 | print('smooth_precision', smooth_precision)
224 |
225 | # --- BayesGrad ---
226 |     # BayesGrad is computed by compute_vanilla with train=True (dropout enabled)
227 | saliency_arrays = calculator.compute_vanilla(
228 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,
229 | M=M, train=True)
230 | saliency = calculator.transform(
231 | saliency_arrays, ch_axis=3, method='square', lam=0)
232 | saliency_bayes = clip_original_size(saliency, num_atoms)
233 |
234 | bayes_recall, bayes_precision = calc_recall_precision(saliency_bayes, rates, val_pyridine_pos)
235 | print('bayes_recall', bayes_recall)
236 | print('bayes_precision', bayes_precision)
237 |
238 | plt.figure(figsize=(7, 5), dpi=200)
239 |     plt.plot(vanilla_recall, vanilla_precision, '-', color='blue', label='VanillaGrad')
240 |     plt.plot(smooth_recall, smooth_precision, '-', color='green', label='SmoothGrad')
241 |     plt.plot(bayes_recall, bayes_precision, '-', color='red', label='BayesGrad(Ours)')
242 | plt.axhline(y=vanilla_precision[-1], color='gray', linestyle='--')
243 | plt.legend()
244 | plt.xlabel("recall")
245 | plt.ylabel("precision")
246 | if save_path:
247 | print('saved to ', save_path)
248 | plt.savefig(save_path)
249 | # plt.savefig('artificial_pr.eps')
250 | else:
251 | plt.show()
252 |
253 |
254 | if __name__ == '__main__':
255 |     # Disable RDKit errors raised while preprocessing the Tox21 dataset.
256 | lg = RDLogger.logger()
257 | lg.setLevel(RDLogger.CRITICAL)
258 | # show INFO level log from chainer chemistry
259 | logging.basicConfig(level=logging.INFO)
260 |
261 | args = parse()
262 | # --- extracting configs ---
263 | dirpath = args.dirpath
264 | json_path = os.path.join(dirpath, 'args.json')
265 | if not os.path.exists(json_path):
266 | raise ValueError(
267 |             'json_path {} not found! Execute train_tox21.py beforehand.'.format(json_path))
268 | with open(json_path, 'r') as f:
269 | train_args = json.load(f)
270 |
271 | method = train_args['method']
272 | labels = train_args['label'] # 'pyridine'
273 |
274 | unit_num = train_args['unit_num']
275 | conv_layers = train_args['conv_layers']
276 | class_num = 1
277 | n_layers = train_args['n_layers']
278 | dropout_ratio = train_args['dropout_ratio']
279 | num_train = train_args['num_train']
280 | # seed = train_args['seed']
281 | # --- extracting configs end ---
282 |
283 | model_path = os.path.join(dirpath, 'predictor.npz')
284 | save_path = os.path.join(
285 | dirpath, 'precision_recall_{}.png'.format(args.calculator))
286 | # --- config end ---
287 |
288 | main(method, labels, unit_num, conv_layers, class_num, n_layers,
289 | dropout_ratio, model_path, save_path)
290 |
--------------------------------------------------------------------------------
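
The precision-recall sweep above ranks atoms by saliency over a range of `rates` and scores them against the annotated pyridine atom positions via `calc_recall_precision`, which is defined earlier in this file. As a rough sketch of what such a computation can look like (the helper below and the names `saliency_list`/`true_pos_list` are illustrative only and may differ from the repository's own implementation):

import numpy as np

def recall_precision_sketch(saliency_list, true_pos_list, rates):
    """Illustrative sketch only; not the repository's calc_recall_precision.

    saliency_list: list of 1-d per-atom saliency arrays, one per molecule.
    true_pos_list: list of index collections of annotated (ground-truth) atoms.
    rates: fractions of atoms to select, e.g. np.linspace(0.1, 1, num=20).
    """
    recalls, precisions = [], []
    for rate in rates:
        tp = fp = fn = 0
        for saliency, true_pos in zip(saliency_list, true_pos_list):
            k = max(1, int(round(len(saliency) * rate)))
            selected = set(np.argsort(saliency)[::-1][:k])  # top-k atoms by saliency
            true_set = set(true_pos)
            tp += len(selected & true_set)
            fp += len(selected - true_set)
            fn += len(true_set - selected)
        recalls.append(tp / float(tp + fn))      # assumes at least one annotated atom overall
        precisions.append(tp / float(tp + fp))
    return np.array(recalls), np.array(precisions)

At rate 1.0 every atom is selected, so precision reduces to the fraction of annotated atoms; that value appears to be the dashed baseline drawn with plt.axhline above.
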
/experiments/tox21/train_few_with_seeds.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -eu
4 |
5 | gpu=${1:--1}
6 | # seeds from $start (inclusive) to $end (exclusive)
7 | start=0
8 | end=30
9 | num_train=1000
10 | unit_num=16
11 | epoch=100
12 | label=pyridine
13 | method=ggnndrop
14 | ratio=0.25
15 |
16 | for ((seed=$start; seed < $end; seed++)); do
17 | python train_tox21.py -g ${gpu} --iterator-type=balanced --label=${label} --method=${method} --epoch=${epoch} --unit-num=${unit_num} --n-layers=1 -b 32 --conv-layers=4 --num-train=${num_train} --seed=${seed} --dropout-ratio=${ratio} --out=results/${method}_${label}_numtrain${num_train}_seed${seed}
18 | done
19 |
--------------------------------------------------------------------------------
/experiments/tox21/train_tox21.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Tox21 data training
4 | Additionally, artificial pyridine dataset training is supported.
5 | """
6 |
7 | from __future__ import print_function
8 |
9 | import logging
10 | import sys
11 | import os
12 |
13 | try:
14 | import matplotlib
15 | matplotlib.use('Agg')
16 | except ImportError:
17 | pass
18 |
19 | import numpy as np
20 | import argparse
21 | import chainer
22 | from chainer import functions as F
23 | from chainer import iterators as I
24 | from chainer import links as L
25 | from chainer import optimizers as O
26 | from chainer import training
27 | from chainer.training import extensions as E
28 | import json
29 | from rdkit import RDLogger, Chem
30 |
31 | from chainer_chemistry.datasets import NumpyTupleDataset
32 | from chainer_chemistry.dataset.converters import concat_mols
33 | from chainer_chemistry import datasets as D
34 | from chainer_chemistry.iterators.balanced_serial_iterator import BalancedSerialIterator # NOQA
35 | from chainer_chemistry.training.extensions.roc_auc_evaluator import ROCAUCEvaluator
36 |
37 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
38 | from models import predictor
39 |
40 | import data
41 |
42 |
43 | def main():
44 | # Supported preprocessing/network list
45 | method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'nfpdrop', 'ggnndrop']
46 | label_names = D.get_tox21_label_names() + ['pyridine']
47 | iterator_type = ['serial', 'balanced']
48 |
49 | parser = argparse.ArgumentParser(
50 | description='Multitask Learning with Tox21.')
51 | parser.add_argument('--method', '-m', type=str, choices=method_list,
52 | default='nfp', help='graph convolution model to use '
53 | 'as a predictor.')
54 | parser.add_argument('--label', '-l', type=str, choices=label_names,
55 | default='', help='target label for logistic '
56 | 'regression. Use all labels if this option '
57 | 'is not specified.')
58 | parser.add_argument('--iterator-type', type=str, choices=iterator_type,
59 | default='serial', help='iterator type. If `balanced` '
60 |                         'is specified, data is sampled so that the same number of '
61 |                         'positive/negative labels is used during training.')
62 | parser.add_argument('--conv-layers', '-c', type=int, default=4,
63 | help='number of convolution layers')
64 | parser.add_argument('--n-layers', type=int, default=1,
65 | help='number of mlp layers')
66 | parser.add_argument('--batchsize', '-b', type=int, default=32,
67 | help='batch size')
68 | parser.add_argument('--gpu', '-g', type=int, default=-1,
69 | help='GPU ID to use. Negative value indicates '
70 |                         'not to use the GPU and to run the code on the CPU.')
71 | parser.add_argument('--out', '-o', type=str, default='result',
72 | help='path to output directory')
73 | parser.add_argument('--epoch', '-e', type=int, default=10,
74 | help='number of epochs')
75 | parser.add_argument('--unit-num', '-u', type=int, default=16,
76 | help='number of units in one layer of the model')
77 | parser.add_argument('--resume', '-r', type=str, default='',
78 | help='path to a trainer snapshot')
79 | parser.add_argument('--frequency', '-f', type=int, default=-1,
80 | help='Frequency of taking a snapshot')
81 | parser.add_argument('--dropout-ratio', '-d', type=float, default=0.25,
82 | help='dropout_ratio')
83 | parser.add_argument('--seed', type=int, default=0,
84 | help='random seed')
85 | parser.add_argument('--num-train', type=int, default=-1,
86 |                         help='number of training samples to be used; a '
87 |                         'negative value means all training data is used')
88 | args = parser.parse_args()
89 |
90 | method = args.method
91 | if args.label:
92 | labels = args.label
93 | class_num = len(labels) if isinstance(labels, list) else 1
94 | else:
95 | labels = None
96 | class_num = len(label_names)
97 |
98 | # Dataset preparation
99 | train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels)
100 |
101 | num_train = args.num_train # 100
102 | if num_train > 0:
103 | # reduce size of train data
104 | seed = args.seed # 0
105 | np.random.seed(seed)
106 | train_selected_label = np.random.permutation(np.arange(len(train)))[:num_train]
107 | print('num_train', num_train, len(train_selected_label), train_selected_label)
108 | train = NumpyTupleDataset(*train.features[train_selected_label, :])
109 | # Network
110 | predictor_ = predictor.build_predictor(
111 | method, args.unit_num, args.conv_layers, class_num, args.dropout_ratio,
112 | args.n_layers
113 | )
114 |
115 | iterator_type = args.iterator_type
116 | if iterator_type == 'serial':
117 | train_iter = I.SerialIterator(train, args.batchsize)
118 | elif iterator_type == 'balanced':
119 | if class_num > 1:
120 |             raise ValueError('BalancedSerialIterator can only be used for single-'
121 |                              'label classification; please specify the label to '
122 |                              'be predicted with the --label option.')
123 | train_iter = BalancedSerialIterator(
124 | train, args.batchsize, train.features[:, -1], ignore_labels=-1)
125 | train_iter.show_label_stats()
126 | else:
127 | raise ValueError('Invalid iterator type {}'.format(iterator_type))
128 | val_iter = I.SerialIterator(val, args.batchsize,
129 | repeat=False, shuffle=False)
130 | classifier = L.Classifier(predictor_,
131 | lossfun=F.sigmoid_cross_entropy,
132 | accfun=F.binary_accuracy)
133 | if args.gpu >= 0:
134 | chainer.cuda.get_device_from_id(args.gpu).use()
135 | classifier.to_gpu()
136 |
137 | optimizer = O.Adam()
138 | optimizer.setup(classifier)
139 |
140 | updater = training.StandardUpdater(
141 | train_iter, optimizer, device=args.gpu, converter=concat_mols)
142 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
143 |
144 | trainer.extend(E.Evaluator(val_iter, classifier,
145 | device=args.gpu, converter=concat_mols))
146 | trainer.extend(E.LogReport())
147 |
148 | # --- ROCAUC Evaluator ---
149 | train_eval_iter = I.SerialIterator(train, args.batchsize,
150 | repeat=False, shuffle=False)
151 | trainer.extend(ROCAUCEvaluator(
152 | train_eval_iter, classifier, eval_func=predictor_,
153 | device=args.gpu, converter=concat_mols, name='train'))
154 | trainer.extend(ROCAUCEvaluator(
155 | val_iter, classifier, eval_func=predictor_,
156 | device=args.gpu, converter=concat_mols, name='val'))
157 | trainer.extend(E.PrintReport([
158 | 'epoch', 'main/loss', 'main/accuracy', 'train/main/roc_auc',
159 | 'validation/main/loss', 'validation/main/accuracy',
160 | 'val/main/roc_auc', 'elapsed_time']))
161 |
162 | trainer.extend(E.ProgressBar(update_interval=10))
163 | if args.resume:
164 | chainer.serializers.load_npz(args.resume, trainer)
165 |
166 | trainer.run()
167 |
168 | with open(os.path.join(args.out, 'args.json'), 'w') as f:
169 | json.dump(vars(args), f, indent=4)
170 | chainer.serializers.save_npz(
171 | os.path.join(args.out, 'predictor.npz'), predictor_)
172 |
173 |
174 | if __name__ == '__main__':
175 | # Disable errors by RDKit occurred in preprocessing Tox21 dataset.
176 | lg = RDLogger.logger()
177 | lg.setLevel(RDLogger.CRITICAL)
178 | # show INFO level log from chainer chemistry
179 | logging.basicConfig(level=logging.INFO)
180 |
181 | main()
182 |
--------------------------------------------------------------------------------
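
train_tox21.py writes its argparse namespace to args.json and the trained weights to predictor.npz inside the --out directory; plot_precision_recall.py and the notebooks read both files back. A minimal reload sketch (results/example_run is a hypothetical output directory):

import json
import os

import chainer

from models import predictor  # assumes the repository root is on sys.path

out_dir = 'results/example_run'  # hypothetical --out directory of a finished run
with open(os.path.join(out_dir, 'args.json')) as f:
    train_args = json.load(f)

# class_num is 1 when a single --label was used (as in train_few_with_seeds.sh)
model = predictor.build_predictor(
    train_args['method'], train_args['unit_num'], train_args['conv_layers'],
    1, train_args['dropout_ratio'], train_args['n_layers'])
chainer.serializers.load_npz(os.path.join(out_dir, 'predictor.npz'), model)
# model.predict(atoms, adjs) now returns sigmoid probabilities for new minibatches.
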
/experiments/tox21/utils.py:
--------------------------------------------------------------------------------
1 | import numpy
2 |
3 |
4 | def save_npz(filepath, datasets):
5 | if not isinstance(datasets, (list, tuple)):
6 | datasets = (datasets, )
7 | numpy.savez(filepath, *datasets)
8 |
9 |
10 | def load_npz(filepath):
11 | load_data = numpy.load(filepath)
12 | result = []
13 | i = 0
14 | while True:
15 | key = 'arr_{}'.format(i)
16 | if key in load_data.keys():
17 | result.append(load_data[key])
18 | i += 1
19 | else:
20 | break
21 | if len(result) == 1:
22 | result = result[0]
23 | return result
24 |
--------------------------------------------------------------------------------
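
The two helpers above spread a tuple of arrays over the arr_0, arr_1, ... keys of an .npz file and read them back in order; a quick round-trip sketch (example.npz is an arbitrary file name):

import numpy

from utils import save_npz, load_npz  # assumes experiments/tox21 is the working directory

a = numpy.arange(6).reshape(2, 3)
b = numpy.ones(4, dtype=numpy.float32)

save_npz('example.npz', (a, b))      # stored as arr_0 and arr_1
restored = load_npz('example.npz')   # -> [a, b]; a single stored array would be unwrapped
assert numpy.array_equal(restored[0], a)
assert numpy.array_equal(restored[1], b)
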
/experiments/tox21/visualize-saliency-tox21.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": false
8 | },
9 | "outputs": [],
10 | "source": [
11 | "#from load import *\n",
12 | "%load_ext autoreload\n",
13 | "%autoreload 2\n",
14 | "import pickle\n",
15 | "from chainer.datasets import TupleDataset\n",
16 | "from chainer.dataset import concat_examples\n",
17 | "from chainer import functions as F, cuda\n",
18 | "from chainer import iterators as I\n",
19 | "from chainer import links as L\n",
20 | "from chainer import optimizers as O\n",
21 | "from chainer import training\n",
22 | "from ipywidgets import interact\n",
23 | "import chainer\n",
24 | "import cupy\n",
25 | "from tqdm import tqdm_notebook as tqdm\n",
26 | "from chainer import serializers\n",
27 | "\n",
28 | "\n",
29 | "from chainer_chemistry.datasets import NumpyTupleDataset\n",
30 | "from chainer_chemistry.dataset.converters import concat_mols\n",
31 | "\n",
32 | "import numpy\n",
33 | "import numpy as np\n",
34 | "\n",
35 | "from rdkit import RDLogger\n",
36 | "from rdkit import Chem\n",
37 | "from rdkit.Chem import rdchem\n",
38 | "from rdkit.Chem import rdDepictor\n",
39 | "from rdkit.Chem.Draw import rdMolDraw2D\n",
40 | "from rdkit.Chem.Draw import IPythonConsole\n",
41 | "from IPython.display import SVG\n",
42 | "from rdkit.Chem import Draw\n",
43 | "from rdkit import rdBase"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "metadata": {
50 | "collapsed": false
51 | },
52 | "outputs": [],
53 | "source": [
54 | "import sys\n",
55 | "import os\n",
56 | "\n",
57 | "sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath('__file__')))))\n",
58 | "from saliency.calculator.gradient_calculator import GradientCalculator\n",
59 | "from models import predictor\n",
60 | "\n",
61 | "sys.path.append(os.path.dirname(os.path.abspath('__file__')))\n",
62 | "import data\n"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 3,
68 | "metadata": {
69 | "collapsed": false
70 | },
71 | "outputs": [],
72 | "source": [
73 | "import logging\n",
74 | "\n",
75 |     "# Disable RDKit errors raised while preprocessing the Tox21 dataset.\n",
76 | "lg = RDLogger.logger()\n",
77 | "lg.setLevel(RDLogger.CRITICAL)\n",
78 | "# show INFO level log from chainer chemistry\n",
79 | "logging.basicConfig(level=logging.INFO)\n"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {
86 | "collapsed": false
87 | },
88 | "outputs": [],
89 | "source": [
90 | "import json\n",
91 | "\n",
92 | "# training result directory\n",
93 | "dirpath = './results/nfpdrop_srmmp'\n",
94 | "\n",
95 | "json_path = os.path.join(dirpath, 'args.json')\n",
96 | "if not os.path.exists(json_path):\n",
97 | " raise ValueError(\n",
98 | " 'json_path {} not found! Execute train_tox21.py beforehand.'.format(json_path))\n",
99 | "with open(json_path, 'r') as f:\n",
100 | " train_args = json.load(f)\n",
101 | "\n",
102 | "method = train_args['method']\n",
103 | "labels = train_args['label'] # 'pyridine'\n",
104 | "\n",
105 | "unit_num = train_args['unit_num']\n",
106 | "conv_layers = train_args['conv_layers']\n",
107 | "class_num = 1\n",
108 | "n_layers = train_args['n_layers']\n",
109 | "dropout_ratio = train_args['dropout_ratio']\n",
110 | "num_train = train_args['num_train']\n",
111 | "\n",
112 | "model_path = os.path.join(dirpath, 'predictor.npz')\n"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 5,
118 | "metadata": {
119 | "collapsed": false
120 | },
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "\n",
127 | "load from cache input/nfp_SR-MMP\n",
128 | "\n"
129 | ]
130 | }
131 | ],
132 | "source": [
133 |     "# Dataset preparation\n",
134 | "train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels)\n",
135 | "val_mols = [Chem.MolFromSmiles(smi) for smi in tqdm(val_smiles)]"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 6,
141 | "metadata": {
142 | "collapsed": false
143 | },
144 | "outputs": [],
145 | "source": [
146 | "visualize_only_positive = True\n",
147 | "if visualize_only_positive:\n",
148 | " # visualize all only label=1\n",
149 | " pos_index = val.features[:, -1][:, 0] == 1\n",
150 | " # print('pos_index', pos_index.shape, pos_index)\n",
151 | " target_dataset = NumpyTupleDataset(*val.features[pos_index, :])\n",
152 | " target_smiles = val_smiles[pos_index]\n",
153 | "else:\n",
154 | " # visualize all validation data\n",
155 | " target_dataset = val\n",
156 | " target_smiles = val_smiles"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 7,
162 | "metadata": {
163 | "collapsed": false
164 | },
165 | "outputs": [
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "dropout_ratio, n_layers 0.25 1\n",
171 | "Use NFPDrop predictor...\n"
172 | ]
173 | }
174 | ],
175 | "source": [
176 | "# --- model preparation ---\n",
177 | "\n",
178 | "model = predictor.build_predictor(\n",
179 | " method, unit_num, conv_layers, class_num, dropout_ratio, n_layers)\n",
180 | "\n",
181 | "classifier = L.Classifier(model,\n",
182 | " lossfun=F.sigmoid_cross_entropy,\n",
183 | " accfun=F.binary_accuracy)\n",
184 | "\n",
185 | "serializers.load_npz(model_path, model)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 8,
191 | "metadata": {
192 | "collapsed": false
193 | },
194 | "outputs": [],
195 | "source": [
196 | "def clip_original_size(saliency, num_atoms):\n",
197 |     "    \"\"\"`saliency` array is zero-padded; clip each row back to the original\n",
198 |     "    molecule's number of atoms.\n",
199 | " \"\"\"\n",
200 | " assert len(saliency) == len(num_atoms)\n",
201 | " saliency_list = []\n",
202 | " for i in range(len(saliency)):\n",
203 | " saliency_list.append(saliency[i, :num_atoms[i]])\n",
204 | " return saliency_list\n",
205 | "\n",
206 | "def preprocess_fun(*inputs):\n",
207 | " if len(inputs) == 3:\n",
208 | " atom, adj, t = inputs\n",
209 | " elif len(inputs) == 2:\n",
210 | " atom, adj = inputs\n",
211 | " # HACKING for now...\n",
212 | " # classifier.predictor.pick = True\n",
213 | " # result = classifier.predictor(atom, adj)\n",
214 | " atom_embed = classifier.predictor.graph_conv.embed(atom)\n",
215 | " if len(inputs) == 3:\n",
216 | " return atom_embed, adj, t\n",
217 | " elif len(inputs) == 2:\n",
218 | " return atom_embed, adj\n",
219 | " \n",
220 | "\n",
221 | "def eval_fun(*inputs):\n",
222 | " #atom_embed, adj, t = inputs\n",
223 | " if len(inputs) == 3:\n",
224 | " atom_embed, adj, t = inputs\n",
225 | " elif len(inputs) == 2:\n",
226 | " atom_embed, adj = inputs\n",
227 | " prob = classifier.predictor(atom_embed, adj)\n",
228 | " # print('embed', atom_embed.shape, 'prob', prob.shape)\n",
229 | " out = F.sum(prob)\n",
230 | " # return {'embed': atom_embed, 'out': out}\n",
231 | " return out\n"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "## Saliency calculation & visualization\n",
239 | "\n",
240 |     "You can calculate saliency and visualize it with the following steps:\n",
241 | "\n",
242 | "1. Instantiate Saliency Calculator\n",
243 |     "2. Call a `compute_xxx` method (xxx is vanilla/smooth/bayes) to calculate the saliency (gradients etc.)\n",
244 |     "3. Call the `transform` method to convert the saliency into a visualizable format.\n",
245 |     "4. Use a Visualizer class to visualize it (a condensed script version of these steps follows this notebook)."
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 9,
251 | "metadata": {
252 | "collapsed": false
253 | },
254 | "outputs": [],
255 | "source": [
256 | "# 1. instantiation\n",
257 | "gradient_calculator = GradientCalculator(\n",
258 | " classifier, eval_fun=eval_fun,\n",
259 | " # target_key='embed', eval_key='out',\n",
260 | " target_key=0,\n",
261 | ")"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 10,
267 | "metadata": {
268 | "collapsed": false
269 | },
270 | "outputs": [
271 | {
272 | "name": "stdout",
273 | "output_type": "stream",
274 | "text": [
275 | "saliency_arrays (1, 38, 41, 16)\n",
276 | "saliency (38, 41)\n"
277 | ]
278 | }
279 | ],
280 | "source": [
281 | "M = 40\n",
282 | "rates = np.array(list(range(1, 11))) * 0.1\n",
283 | "\n",
284 | "# --- VanillaGrad ---\n",
285 | "# 2. compute\n",
286 | "saliency_arrays_vanilla = gradient_calculator.compute_vanilla(\n",
287 | " target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun)\n",
288 | "# 3. transform\n",
289 | "saliency_vanilla = gradient_calculator.transform(\n",
290 | " saliency_arrays_vanilla, ch_axis=3, method='square')\n",
291 | "# saliency_arrays (1, 28, 43, 64) -> M, minibatch, max_atom, ch_dim\n",
292 | "print('saliency_arrays', saliency_arrays_vanilla.shape)\n",
293 | "# saliency (28, 43) -> minibatch, max_atom\n",
294 | "print('saliency', saliency_vanilla.shape)"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 11,
300 | "metadata": {
301 | "collapsed": false
302 | },
303 | "outputs": [
304 | {
305 | "data": {
306 | "image/svg+xml": [
307 |        "(SVG molecule rendering stripped from this dump)"
367 | ],
368 | "text/plain": [
369 | ""
370 | ]
371 | },
372 | "execution_count": 11,
373 | "metadata": {},
374 | "output_type": "execute_result"
375 | }
376 | ],
377 | "source": [
378 | "from saliency.visualizer.smiles_visualizer import SmilesVisualizer\n",
379 | "\n",
380 | "sv = SmilesVisualizer()\n",
381 | "# 4. visualize\n",
382 | "index = 2\n",
383 | "sv.visualize(saliency_vanilla[index], target_smiles[index], visualize_ratio=0.3)"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": 12,
389 | "metadata": {
390 | "collapsed": false
391 | },
392 | "outputs": [
393 | {
394 | "data": {
395 | "image/svg+xml": [
396 |        "(SVG molecule rendering stripped from this dump)"
518 | ],
519 | "text/plain": [
520 | ""
521 | ]
522 | },
523 | "metadata": {},
524 | "output_type": "display_data"
525 | },
526 | {
527 | "data": {
528 | "text/plain": [
529 | ""
530 | ]
531 | },
532 | "execution_count": 12,
533 | "metadata": {},
534 | "output_type": "execute_result"
535 | }
536 | ],
537 | "source": [
538 | "# interactive plot demo\n",
539 | "\n",
540 | "def sv_visualize(i, ratio):\n",
541 | " return sv.visualize(saliency_vanilla[i], target_smiles[i], visualize_ratio=ratio)\n",
542 | "\n",
543 | "interact(sv_visualize, i=(0, len(saliency_vanilla) - 1, 1), ratio=(0, 1.0, 0.1))"
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "execution_count": 19,
549 | "metadata": {
550 | "collapsed": true
551 | },
552 | "outputs": [],
553 | "source": [
554 | "os.makedirs('results/visualize/srmmp', exist_ok=True)"
555 | ]
556 | },
557 | {
558 | "cell_type": "code",
559 | "execution_count": 20,
560 | "metadata": {
561 | "collapsed": false
562 | },
563 | "outputs": [],
564 | "source": [
565 | "# --- SmoothGrad ---\n",
566 | "saliency_arrays_smooth = gradient_calculator.compute_smooth(\n",
567 | " target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,\n",
568 | " M=M, mode='absolute', scale=0.15)"
569 | ]
570 | },
571 | {
572 | "cell_type": "code",
573 | "execution_count": 21,
574 | "metadata": {
575 | "collapsed": false
576 | },
577 | "outputs": [],
578 | "source": [
579 | "# --- BayesGrad ---\n",
580 | "# `train=True` enables dropout, which corresponds to BayesGrad\n",
581 | "saliency_arrays_bayes = gradient_calculator.compute_vanilla(\n",
582 | " target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,\n",
583 | " M=M, train=True)"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": 22,
589 | "metadata": {
590 | "collapsed": false
591 | },
592 | "outputs": [
593 | {
594 | "name": "stdout",
595 | "output_type": "stream",
596 | "text": [
597 | "SR-MMP: 1\n"
598 | ]
599 | },
600 | {
601 | "data": {
602 | "image/svg+xml": [
603 |        "(SVG molecule rendering stripped from this dump)"
725 | ],
726 | "text/plain": [
727 | ""
728 | ]
729 | },
730 | "metadata": {},
731 | "output_type": "display_data"
732 | }
733 | ],
734 | "source": [
735 | "# Single plot demo\n",
736 | "from IPython.display import display, HTML\n",
737 | "\n",
738 | "def sv_visualize(i, ratio, lam, method, view):\n",
739 | " print('SR-MMP: ', target_dataset.features[i, -1][0])\n",
740 | " saliency_bayes = gradient_calculator.transform(\n",
741 | " saliency_arrays_bayes, ch_axis=3, method=method, lam=lam)\n",
742 | " \n",
743 | " if view == 'view':\n",
744 | " svg_bayes = sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio)\n",
745 | " # display(svg_vanilla, svg_smooth, svg_bayes)\n",
746 | " display(svg_bayes)\n",
747 | " elif view == 'save':\n",
748 | " os.makedirs('results/visualize', exist_ok=True)\n",
749 | " sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_bayes.png'.format(i))\n",
750 | " else:\n",
751 | " print(view, 'not supported')\n",
752 | "\n",
753 | "interact(sv_visualize, i=(0, len(target_dataset) - 1, 1), ratio=(0, 1.0, 0.1), lam=(-3.0, 3.1, 0.1), method=['square', 'abs', 'raw'], view=['view', 'save'])"
754 | ]
755 | },
756 | {
757 | "cell_type": "code",
758 | "execution_count": 25,
759 | "metadata": {
760 | "collapsed": false
761 | },
762 | "outputs": [
763 | {
764 | "name": "stdout",
765 | "output_type": "stream",
766 | "text": [
767 | "SR-MMP: 1\n"
768 | ]
769 | },
770 | {
771 | "data": {
772 | "image/svg+xml": [
773 |        "(SVG molecule rendering stripped from this dump)"
879 | ],
880 | "text/plain": [
881 | ""
882 | ]
883 | },
884 | "metadata": {},
885 | "output_type": "display_data"
886 | }
887 | ],
888 | "source": [
889 | "# Multiple plot demo\n",
890 | "from IPython.display import display, HTML\n",
891 | "\n",
892 | "def sv_visualize(i, ratio, lam, method, view):\n",
893 | " print('SR-MMP: ', target_dataset.features[i, -1][0])\n",
894 | " saliency_vanilla = gradient_calculator.transform(\n",
895 | " saliency_arrays_vanilla, ch_axis=3, method=method, lam=0)\n",
896 | " saliency_smooth = gradient_calculator.transform(\n",
897 | " saliency_arrays_smooth, ch_axis=3, method=method, lam=lam)\n",
898 | " saliency_bayes = gradient_calculator.transform(\n",
899 | " saliency_arrays_bayes, ch_axis=3, method=method, lam=lam)\n",
900 | " \n",
901 | " if view == 'view':\n",
902 | " svg_vanilla = sv.visualize(saliency_vanilla[i], target_smiles[i], visualize_ratio=ratio)\n",
903 | " svg_smooth = sv.visualize(saliency_smooth[i], target_smiles[i], visualize_ratio=ratio)\n",
904 | " svg_bayes = sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio)\n",
905 | " # display(svg_vanilla, svg_smooth, svg_bayes)\n",
906 | " display(svg_bayes)\n",
907 | " elif view == 'save':\n",
908 | " os.makedirs('results/visualize', exist_ok=True)\n",
909 | " sv.visualize(saliency_vanilla[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_vanilla.png'.format(i))\n",
910 | " sv.visualize(saliency_smooth[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_smooth.png'.format(i))\n",
911 | " sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_bayes.png'.format(i))\n",
912 | " else:\n",
913 | " print(view, 'not supported')\n",
914 | "\n",
915 | "interact(sv_visualize, i=(0, len(target_dataset) - 1, 1), ratio=(0, 1.0, 0.1), lam=(-3.0, 3.1, 0.1), method=['square', 'abs', 'raw'], view=['view', 'save'])"
916 | ]
917 | },
918 | {
919 | "cell_type": "code",
920 | "execution_count": null,
921 | "metadata": {
922 | "collapsed": true
923 | },
924 | "outputs": [],
925 | "source": []
926 | },
927 | {
928 | "cell_type": "markdown",
929 | "metadata": {},
930 | "source": [
931 | "## Visualization of specific molecules, including Tyrphostin 9"
932 | ]
933 | },
934 | {
935 | "cell_type": "code",
936 | "execution_count": 17,
937 | "metadata": {
938 | "collapsed": false
939 | },
940 | "outputs": [
941 | {
942 | "name": "stderr",
943 | "output_type": "stream",
944 | "text": [
945 | "100%|██████████| 4/4 [00:00<00:00, 378.38it/s]\n",
946 | "INFO:chainer_chemistry.dataset.parsers.csv_file_parser:Preprocess finished. FAIL 0, SUCCESS 4, TOTAL 4\n"
947 | ]
948 | }
949 | ],
950 | "source": [
951 | "import tempfile\n",
952 | "import pandas\n",
953 | "\n",
954 | "from chainer_chemistry.dataset.parsers import CSVFileParser\n",
955 | "from chainer_chemistry.dataset.preprocessors import preprocess_method_dict\n",
956 | "\n",
957 | "smiles_list = [\n",
958 | " 'CC(C)(C=O)Cc1cc(C(C)(C)C)c(O)c(C(C)(C)C)c1', # (a)\n",
959 | "# 'CC(C)(C=O)Cc1cc(C(C)(C)C)c(C)c(C(C)(C)C)c1',\n",
960 | " 'CC(C)(C)C1=C(C)C(C(C)(C)C)=CC(/C=C(C#N)\\C#N)=C1', # (b)\n",
961 | " 'CC(C)(C)C1=C(O)C(C(C)(C)C)=CC(/C=C(C#N)\\C#N)=C1', # (c)\n",
962 | " 'O=C1Nc2ccc(I)cc2C1=Cc1cc(Br)c(O)c(Br)c1' # (d) \n",
963 | "]\n",
964 | "\n",
965 | "preprocessor = preprocess_method_dict['nfp']()\n",
966 | "parser = CSVFileParser(preprocessor,\n",
967 | " labels=None, smiles_col='smiles')\n",
968 | "\n",
969 | "with tempfile.TemporaryDirectory() as dirpath:\n",
970 | " csv_path = os.path.join(dirpath, 'tmp.csv')\n",
971 | " df = pandas.DataFrame({'smiles': smiles_list})\n",
972 | " df.to_csv(csv_path)\n",
973 | " result = parser.parse(csv_path, return_smiles=True)\n",
974 | "\n",
975 | "custom_target_dataset, custom_target_smiles = result['dataset'], result['smiles']"
976 | ]
977 | },
978 | {
979 | "cell_type": "code",
980 | "execution_count": 26,
981 | "metadata": {
982 | "collapsed": false
983 | },
984 | "outputs": [
985 | {
986 | "data": {
987 | "image/svg+xml": [
988 |        "(SVG molecule rendering stripped from this dump)"
1051 | ],
1052 | "text/plain": [
1053 | ""
1054 | ]
1055 | },
1056 | "metadata": {},
1057 | "output_type": "display_data"
1058 | }
1059 | ],
1060 | "source": [
1061 | "# --- BayesGrad ---\n",
1062 | "# `train=True` enables dropout, which corresponds to BayesGrad\n",
1063 | "saliency_arrays_bayes = gradient_calculator.compute_vanilla(\n",
1064 | " custom_target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,\n",
1065 | " M=M, train=True)\n",
1066 | "\n",
1067 | "def sv_visualize(i, ratio, lam, method, view):\n",
1068 | " saliency_bayes = gradient_calculator.transform(\n",
1069 | " saliency_arrays_bayes, ch_axis=3, method=method, lam=lam)\n",
1070 | " \n",
1071 | " if view == 'view':\n",
1072 | " svg_bayes = sv.visualize(saliency_bayes[i], custom_target_smiles[i], visualize_ratio=ratio)\n",
1073 | " display(svg_bayes)\n",
1074 | " elif view == 'save':\n",
1075 | " os.makedirs('results/visualize', exist_ok=True)\n",
1076 | " sv.visualize(saliency_bayes[i], custom_target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/custom_{}_bayes.png'.format(i))\n",
1077 | " else:\n",
1078 | " print(view, 'not supported')\n",
1079 | "\n",
1080 | "interact(sv_visualize, i=(0, len(custom_target_dataset) - 1, 1), ratio=(0, 1.0, 0.1), lam=(-3.0, 3.1, 0.1), method=['square', 'abs', 'raw'], view=['view', 'save'])"
1081 | ]
1082 | },
1083 | {
1084 | "cell_type": "code",
1085 | "execution_count": null,
1086 | "metadata": {
1087 | "collapsed": true
1088 | },
1089 | "outputs": [],
1090 | "source": []
1091 | },
1092 | {
1093 | "cell_type": "code",
1094 | "execution_count": null,
1095 | "metadata": {
1096 | "collapsed": true
1097 | },
1098 | "outputs": [],
1099 | "source": []
1100 | }
1101 | ],
1102 | "metadata": {
1103 | "anaconda-cloud": {},
1104 | "kernelspec": {
1105 | "display_name": "Python [default]",
1106 | "language": "python",
1107 | "name": "python3"
1108 | },
1109 | "language_info": {
1110 | "codemirror_mode": {
1111 | "name": "ipython",
1112 | "version": 3
1113 | },
1114 | "file_extension": ".py",
1115 | "mimetype": "text/x-python",
1116 | "name": "python",
1117 | "nbconvert_exporter": "python",
1118 | "pygments_lexer": "ipython3",
1119 | "version": "3.5.2"
1120 | },
1121 | "widgets": {
1122 | "state": {
1123 | "4ddc337df552443b94c357e8015dae4f": {
1124 | "views": [
1125 | {
1126 | "cell_index": 21.0
1127 | }
1128 | ]
1129 | },
1130 | "86b5e4662cd34a549c070dc3e90515b3": {
1131 | "views": [
1132 | {
1133 | "cell_index": 12.0
1134 | }
1135 | ]
1136 | },
1137 | "8fa97c75fec14b5abc88618435d810c2": {
1138 | "views": [
1139 | {
1140 | "cell_index": 17.0
1141 | }
1142 | ]
1143 | },
1144 | "cb91b59b01c44714911d7b8b48dc954e": {
1145 | "views": [
1146 | {
1147 | "cell_index": 16.0
1148 | }
1149 | ]
1150 | },
1151 | "ec91304623f74358bd014a7ff21f88f4": {
1152 | "views": [
1153 | {
1154 | "cell_index": 4.0
1155 | }
1156 | ]
1157 | }
1158 | },
1159 | "version": "1.2.0"
1160 | }
1161 | },
1162 | "nbformat": 4,
1163 | "nbformat_minor": 2
1164 | }
1165 |
--------------------------------------------------------------------------------
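
The notebook's markdown cell lists the saliency workflow in four steps (instantiate a calculator, compute, transform, visualize). Condensed into a standalone script placed under experiments/tox21, and assuming a finished nfpdrop run on the SR-MMP label exists in a hypothetical results/example_run directory (the hyperparameters 16/4/0.25/1 mirror the script defaults and the values printed in the notebook output, and are assumptions), it looks roughly like this:

import os
import sys

from chainer import functions as F
from chainer import links as L
from chainer import serializers
from chainer_chemistry.dataset.converters import concat_mols

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from models import predictor
from saliency.calculator.gradient_calculator import GradientCalculator
from saliency.visualizer.smiles_visualizer import SmilesVisualizer

import data  # experiments/tox21/data.py

# model trained beforehand by train_tox21.py (hypothetical directory)
model = predictor.build_predictor('nfpdrop', 16, 4, 1, 0.25, 1)
serializers.load_npz('results/example_run/predictor.npz', model)
classifier = L.Classifier(model, lossfun=F.sigmoid_cross_entropy,
                          accfun=F.binary_accuracy)

train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset('nfpdrop', 'SR-MMP')

def preprocess_fun(*inputs):
    atom, adj, t = inputs
    # embed atom IDs first so gradients can be taken w.r.t. the embeddings
    return classifier.predictor.graph_conv.embed(atom), adj, t

def eval_fun(*inputs):
    atom_embed, adj, t = inputs
    return F.sum(classifier.predictor(atom_embed, adj))

calculator = GradientCalculator(classifier, eval_fun=eval_fun, target_key=0)      # 1. instantiate
saliency_arrays = calculator.compute_vanilla(                                     # 2. compute
    val, converter=concat_mols, preprocess_fn=preprocess_fun, M=10, train=True)   # train=True -> BayesGrad
saliency = calculator.transform(saliency_arrays, ch_axis=3, method='square')      # 3. transform
SmilesVisualizer().visualize(saliency[0], val_smiles[0], visualize_ratio=0.3)     # 4. visualize
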
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/models/__init__.py
--------------------------------------------------------------------------------
/models/ggnn_drop.py:
--------------------------------------------------------------------------------
1 | """
2 | Copied from Chainer Chemistry v0.3.0.
3 | The only difference is the addition of a dropout function.
4 | """
5 | import chainer
6 | from chainer import cuda
7 | from chainer import functions
8 | from chainer import links
9 |
10 | import chainer_chemistry
11 | from chainer_chemistry.config import MAX_ATOMIC_NUM
12 | from chainer_chemistry.links import EmbedAtomID
13 | from chainer_chemistry.links import GraphLinear
14 |
15 |
16 | class GGNNDrop(chainer.Chain):
17 | """Gated Graph Neural Networks (GGNN)
18 |
19 | See: Li, Y., Tarlow, D., Brockschmidt, M., & Zemel, R. (2015).\
20 | Gated graph sequence neural networks. \
21 |         `arXiv:1511.05493 <https://arxiv.org/abs/1511.05493>`_
22 |
23 | Args:
24 | out_dim (int): dimension of output feature vector
25 | hidden_dim (int): dimension of feature vector
26 | associated to each atom
27 | n_layers (int): number of layers
28 | n_atom_types (int): number of types of atoms
29 | concat_hidden (bool): If set to True, readout is executed in each layer
30 | and the result is concatenated
31 | weight_tying (bool): enable weight_tying or not
32 |
33 | """
34 | NUM_EDGE_TYPE = 4
35 |
36 | def __init__(self, out_dim, hidden_dim=16,
37 | n_layers=4, n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
38 | weight_tying=True, dropout_ratio=0):
39 | super(GGNNDrop, self).__init__()
40 | n_readout_layer = n_layers if concat_hidden else 1
41 | n_message_layer = 1 if weight_tying else n_layers
42 | with self.init_scope():
43 | # Update
44 | self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types)
45 | self.message_layers = chainer.ChainList(
46 | *[GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim)
47 | for _ in range(n_message_layer)]
48 | )
49 | self.update_layer = links.GRU(2 * hidden_dim, hidden_dim)
50 | # Readout
51 | self.i_layers = chainer.ChainList(
52 | *[GraphLinear(2 * hidden_dim, out_dim)
53 | for _ in range(n_readout_layer)]
54 | )
55 | self.j_layers = chainer.ChainList(
56 | *[GraphLinear(hidden_dim, out_dim)
57 | for _ in range(n_readout_layer)]
58 | )
59 | self.out_dim = out_dim
60 | self.hidden_dim = hidden_dim
61 | self.n_layers = n_layers
62 | self.concat_hidden = concat_hidden
63 | self.weight_tying = weight_tying
64 | self.dropout_ratio = dropout_ratio
65 |
66 | def update(self, h, adj, step=0):
67 | # --- Message & Update part ---
68 | # (minibatch, atom, ch)
69 | mb, atom, ch = h.shape
70 | out_ch = ch
71 | message_layer_index = 0 if self.weight_tying else step
72 | m = functions.reshape(self.message_layers[message_layer_index](h),
73 | (mb, atom, out_ch, self.NUM_EDGE_TYPE))
74 | # m: (minibatch, atom, ch, edge_type)
75 | # Transpose
76 | m = functions.transpose(m, (0, 3, 1, 2))
77 | # m: (minibatch, edge_type, atom, ch)
78 |
79 | adj = functions.reshape(adj, (mb * self.NUM_EDGE_TYPE, atom, atom))
80 | # (minibatch * edge_type, atom, out_ch)
81 | m = functions.reshape(m, (mb * self.NUM_EDGE_TYPE, atom, out_ch))
82 |
83 | m = chainer_chemistry.functions.matmul(adj, m)
84 |
85 | # (minibatch * edge_type, atom, out_ch)
86 | m = functions.reshape(m, (mb, self.NUM_EDGE_TYPE, atom, out_ch))
87 | # Take sum
88 | m = functions.sum(m, axis=1)
89 | # (minibatch, atom, out_ch)
90 |
91 | # --- Update part ---
92 | # Contraction
93 | h = functions.reshape(h, (mb * atom, ch))
94 |
95 | # Contraction
96 | m = functions.reshape(m, (mb * atom, ch))
97 |
98 | out_h = self.update_layer(functions.concat((h, m), axis=1))
99 | # Expansion
100 | out_h = functions.reshape(out_h, (mb, atom, ch))
101 | return out_h
102 |
103 | def readout(self, h, h0, step=0):
104 | # --- Readout part ---
105 | index = step if self.concat_hidden else 0
106 | # h, h0: (minibatch, atom, ch)
107 | g = functions.sigmoid(
108 | self.i_layers[index](functions.concat((h, h0), axis=2))) \
109 | * self.j_layers[index](h)
110 | g = functions.sum(g, axis=1) # sum along atom's axis
111 | return g
112 |
113 | def __call__(self, atom_array, adj):
114 | """Forward propagation
115 |
116 | Args:
117 |             atom_array (numpy.ndarray): minibatch of molecules
118 | represented with atom IDs (representing C, O, S, ...)
119 | `atom_array[mol_index, atom_index]` represents `mol_index`-th
120 | molecule's `atom_index`-th atomic number
121 |             adj (numpy.ndarray): minibatch of adjacency matrices with edge-type
122 | information
123 |
124 | Returns:
125 | ~chainer.Variable: minibatch of fingerprint
126 | """
127 | # reset state
128 | self.update_layer.reset_state()
129 | if atom_array.dtype == self.xp.int32:
130 |             h = self.embed(atom_array)  # (minibatch, max_num_atoms) -> (minibatch, max_num_atoms, hidden_dim)
131 | else:
132 | h = atom_array
133 | h0 = functions.copy(h, cuda.get_device_from_array(h.data).id)
134 | g_list = []
135 | for step in range(self.n_layers):
136 | h = chainer.functions.dropout(h, ratio=self.dropout_ratio)
137 | h = self.update(h, adj, step)
138 | if self.concat_hidden:
139 | g = self.readout(h, h0, step)
140 | g_list.append(g)
141 |
142 | if self.concat_hidden:
143 | return functions.concat(g_list, axis=1)
144 | else:
145 | g = self.readout(h, h0, 0)
146 | return g
147 |
--------------------------------------------------------------------------------
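
GGNNDrop above expects a padded int32 atom-ID array and an edge-type-resolved adjacency tensor; the reshape in update implies the adjacency shape (minibatch, NUM_EDGE_TYPE, atom, atom). A minimal forward pass on made-up inputs:

import numpy as np

from models.ggnn_drop import GGNNDrop  # assumes the repository root is on sys.path

# two molecules padded to 5 atoms; NUM_EDGE_TYPE (= 4) adjacency channels
atom_array = np.array([[6, 6, 7, 8, 0],
                       [6, 8, 0, 0, 0]], dtype=np.int32)
adj = np.zeros((2, GGNNDrop.NUM_EDGE_TYPE, 5, 5), dtype=np.float32)
adj[:, 0, 0, 1] = adj[:, 0, 1, 0] = 1.0  # a single bond of edge type 0, as an example

model = GGNNDrop(out_dim=8, hidden_dim=16, n_layers=4, dropout_ratio=0.25)
g = model(atom_array, adj)
print(g.shape)  # (2, 8): one fingerprint per molecule
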
/models/nfp_drop.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Neural Fingerprint
3 |
4 | Copied from Chainer Chemistry v0.3.0.
5 | The only difference is the addition of a dropout function.
6 | """
7 | import chainer
8 | from chainer import cuda, Variable
9 | from chainer import functions
10 | import numpy
11 |
12 | import chainer_chemistry
13 | from chainer_chemistry.config import MAX_ATOMIC_NUM
14 |
15 |
16 | class NFPUpdate(chainer.Chain):
17 | """NFP sub module for update part
18 |
19 | Args:
20 | in_channels (int): input channel dimension
21 | out_channels (int): output channel dimension
22 | max_degree (int): max degree of edge
23 | """
24 |
25 | def __init__(self, in_channels, out_channels, max_degree=6):
26 | super(NFPUpdate, self).__init__()
27 | num_degree_type = max_degree + 1
28 | with self.init_scope():
29 | self.graph_linears = chainer.ChainList(
30 | *[chainer_chemistry.links.GraphLinear(in_channels, out_channels)
31 | for _ in range(num_degree_type)]
32 | )
33 | self.max_degree = max_degree
34 | self.in_channels = in_channels
35 | self.out_channels = out_channels
36 |
37 | def __call__(self, h, adj, deg_conds):
38 | # h (minibatch, atom, ch)
39 | # h encodes each atom's info in ch axis of size hidden_dim
40 | # adjs (minibatch, atom, atom)
41 |
42 | # --- Message part ---
43 | # Take sum along adjacent atoms
44 |
45 | # fv (minibatch, atom, ch)
46 | fv = chainer_chemistry.functions.matmul(adj, h)
47 |
48 | # --- Update part ---
49 | # s0, s1, s2 = fv.shape
50 | if self.xp is numpy:
51 | zero_array = numpy.zeros(fv.shape, dtype=numpy.float32)
52 | else:
53 | zero_array = self.xp.zeros_like(fv)
54 |
55 | fvds = [functions.where(cond, fv, zero_array) for cond in deg_conds]
56 |
57 | out_h = 0
58 | for graph_linear, fvd in zip(self.graph_linears, fvds):
59 | out_h = out_h + graph_linear(fvd)
60 |
61 |         # out_h shape (minibatch, max_num_atoms, hidden_dim)
62 | out_h = functions.sigmoid(out_h)
63 | return out_h
64 |
65 |
66 | class NFPReadout(chainer.Chain):
67 | """NFP sub module for readout part
68 |
69 | Args:
70 | in_channels (int): dimension of feature vector associated to each
71 | atom (node)
72 | out_size (int): output dimension of feature vector associated to each
73 | molecule (graph)
74 | """
75 |
76 | def __init__(self, in_channels, out_size):
77 | super(NFPReadout, self).__init__()
78 | with self.init_scope():
79 | self.output_weight = chainer_chemistry.links.GraphLinear(
80 | in_channels, out_size)
81 | self.in_channels = in_channels
82 | self.out_size = out_size
83 |
84 | def __call__(self, h):
85 | # input h shape (minibatch, atom, ch)
86 | # return i shape (minibatch, ch)
87 |
88 | # --- Readout part ---
89 | i = self.output_weight(h)
90 | i = functions.softmax(i, axis=2) # softmax along channel axis
91 | i = functions.sum(i, axis=1) # sum along atom's axis
92 | return i
93 |
94 |
95 | class NFPDrop(chainer.Chain):
96 |
97 | """Neural Finger Print (NFP)
98 |
99 | Args:
100 | out_dim (int): dimension of output feature vector
101 | hidden_dim (int): dimension of feature vector
102 | associated to each atom
103 | max_degree (int): max degree of atoms
104 | when molecules are regarded as graphs
105 | n_atom_types (int): number of types of atoms
106 |         n_layers (int): number of layers
107 | """
108 |
109 | def __init__(self, out_dim, hidden_dim=16, n_layers=4, max_degree=6,
110 | n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False,
111 | dropout_ratio=0):
112 | super(NFPDrop, self).__init__()
113 | num_degree_type = max_degree + 1
114 | with self.init_scope():
115 | self.embed = chainer_chemistry.links.EmbedAtomID(
116 | in_size=n_atom_types, out_size=hidden_dim)
117 | self.layers = chainer.ChainList(
118 | *[NFPUpdate(hidden_dim, hidden_dim, max_degree=max_degree)
119 | for _ in range(n_layers)])
120 | self.read_out_layers = chainer.ChainList(
121 | *[NFPReadout(hidden_dim, out_dim)
122 | for _ in range(n_layers)])
123 | self.out_dim = out_dim
124 | self.hidden_dim = hidden_dim
125 | self.max_degree = max_degree
126 | self.num_degree_type = num_degree_type
127 | self.n_layers = n_layers
128 | self.concat_hidden = concat_hidden
129 | self.dropout_ratio = dropout_ratio
130 |
131 | def __call__(self, atom_array, adj):
132 | """Forward propagation
133 |
134 | Args:
135 |             atom_array (numpy.ndarray): minibatch of molecules
136 | represented with atom IDs (representing C, O, S, ...)
137 | `atom_array[mol_index, atom_index]` represents `mol_index`-th
138 | molecule's `atom_index`-th atomic number
139 |             adj (numpy.ndarray): minibatch of adjacency matrices
140 | `adj[mol_index]` represents `mol_index`-th molecule's
141 | adjacency matrix
142 |
143 | Returns:
144 | ~chainer.Variable: minibatch of fingerprint
145 | """
146 | if atom_array.dtype == self.xp.int32:
147 | # atom_array: (minibatch, atom)
148 | h = self.embed(atom_array)
149 | else:
150 | h = atom_array
151 | # h: (minibatch, atom, ch)
152 | g = 0
153 |
154 | # --- NFP update & readout ---
155 | # degree_mat: (minibatch, max_num_atoms)
156 | if isinstance(adj, Variable):
157 | adj_array = adj.data
158 | else:
159 | adj_array = adj
160 | degree_mat = self.xp.sum(adj_array, axis=1)
161 |         # deg_conds: (minibatch, atom, ch)
162 | deg_conds = [self.xp.broadcast_to(
163 | ((degree_mat - degree) == 0)[:, :, None], h.shape)
164 | for degree in range(1, self.num_degree_type + 1)]
165 | g_list = []
166 | for update, readout in zip(self.layers, self.read_out_layers):
167 | h = chainer.functions.dropout(h, ratio=self.dropout_ratio)
168 | h = update(h, adj, deg_conds)
169 | dg = readout(h)
170 | g = g + dg
171 | if self.concat_hidden:
172 | g_list.append(g)
173 |
174 | if self.concat_hidden:
175 | return functions.concat(g_list, axis=2)
176 | else:
177 | return g
178 |
--------------------------------------------------------------------------------
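
NFPDrop takes a plain (minibatch, atom, atom) adjacency matrix and, like GGNNDrop, its forward pass also accepts pre-embedded float features instead of int32 atom IDs (the atom_array.dtype branch above); the saliency calculators rely on exactly that when they call graph_conv.embed in their preprocess function. A small sketch with made-up inputs:

import numpy as np

from models.nfp_drop import NFPDrop  # assumes the repository root is on sys.path

atom_array = np.array([[6, 7, 8, 0]], dtype=np.int32)  # one molecule, padded to 4 atoms
adj = np.eye(4, dtype=np.float32)[None]                # (1, atom, atom) dummy adjacency

model = NFPDrop(out_dim=8, hidden_dim=16, n_layers=4, dropout_ratio=0.25)
g = model(atom_array, adj)   # atom IDs are embedded internally

# equivalent two-step call used by the saliency code:
h = model.embed(atom_array)  # (1, 4, 16) float features
g2 = model(h, adj)
print(g.shape, g2.shape)     # (1, 8) (1, 8)
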
/models/predictor.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | from chainer import functions as F
3 | import chainer.links as L
4 | import sys
5 | import os
6 |
7 | from chainer_chemistry.models import GGNN
8 | from chainer_chemistry.models import NFP
9 | from chainer_chemistry.models import SchNet
10 | from chainer_chemistry.models import WeaveNet
11 |
12 | sys.path.append(os.path.dirname(__file__))
13 | from models.nfp_drop import NFPDrop
14 | from models.ggnn_drop import GGNNDrop
15 |
16 |
17 | class MLPDrop(chainer.Chain):
18 | """Basic implementation for MLP with dropout"""
19 | # def __init__(self, hidden_dim, out_dim, n_layers=2, activation=F.relu):
20 | def __init__(self, out_dim, hidden_dim, n_layers=1, activation=F.relu,
21 | dropout_ratio=0.25):
22 | super(MLPDrop, self).__init__()
23 | if n_layers <= 0:
24 |             raise ValueError('n_layers must be a positive integer, but {} was given'
25 | .format(n_layers))
26 | layers = [L.Linear(None, hidden_dim) for i in range(n_layers - 1)]
27 | with self.init_scope():
28 | self.layers = chainer.ChainList(*layers)
29 | self.l_out = L.Linear(None, out_dim)
30 | self.activation = activation
31 | self.dropout_ratio = dropout_ratio
32 |
33 | def __call__(self, x):
34 | h = F.dropout(x, ratio=self.dropout_ratio)
35 | for l in self.layers:
36 | h = F.dropout(self.activation(l(h)), ratio=self.dropout_ratio)
37 | h = self.l_out(h)
38 | return h
39 |
40 |
41 | def build_predictor(method, n_unit, conv_layers, class_num,
42 | dropout_ratio=0.25, n_layers=1):
43 | print('dropout_ratio, n_layers', dropout_ratio, n_layers)
44 | mlp_class = MLPDrop
45 | if method == 'nfp':
46 | print('Use NFP predictor...')
47 | predictor = GraphConvPredictor(
48 | NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
49 | mlp_class(out_dim=class_num, hidden_dim=n_unit, dropout_ratio=dropout_ratio,
50 | n_layers=n_layers))
51 | elif method == 'nfpdrop':
52 | print('Use NFPDrop predictor...')
53 | predictor = GraphConvPredictor(
54 | NFPDrop(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers,
55 | dropout_ratio=dropout_ratio),
56 | mlp_class(out_dim=class_num, hidden_dim=n_unit,
57 | dropout_ratio=dropout_ratio,
58 | n_layers=n_layers))
59 | elif method == 'ggnn':
60 | print('Use GGNN predictor...')
61 | predictor = GraphConvPredictor(
62 | GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers),
63 | mlp_class(out_dim=class_num, hidden_dim=n_unit,
64 | dropout_ratio=dropout_ratio, n_layers=n_layers))
65 | elif method == 'ggnndrop':
66 | print('Use GGNNDrop predictor...')
67 | predictor = GraphConvPredictor(
68 | GGNNDrop(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers,
69 | dropout_ratio=dropout_ratio),
70 | mlp_class(out_dim=class_num, hidden_dim=n_unit,
71 | dropout_ratio=dropout_ratio, n_layers=n_layers))
72 | elif method == 'schnet':
73 | print('Use SchNet predictor...')
74 | predictor = SchNet(out_dim=class_num, hidden_dim=n_unit,
75 | n_layers=conv_layers, readout_hidden_dim=n_unit)
76 | elif method == 'weavenet':
77 | print('Use WeaveNet predictor...')
78 | n_atom = 20
79 | n_sub_layer = 1
80 | weave_channels = [50] * conv_layers
81 | predictor = GraphConvPredictor(
82 | WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit,
83 | n_sub_layer=n_sub_layer, n_atom=n_atom),
84 | mlp_class(out_dim=class_num, hidden_dim=n_unit,
85 | dropout_ratio=dropout_ratio, n_layers=n_layers))
86 | else:
87 | raise ValueError('[ERROR] Invalid predictor: method={}'.format(method))
88 | return predictor
89 |
90 |
91 | class GraphConvPredictor(chainer.Chain):
92 | """Wrapper class that combines a graph convolution and MLP."""
93 |
94 | def __init__(self, graph_conv, mlp):
95 | """Constructor
96 |
97 | Args:
98 | graph_conv: graph convolution network to obtain molecule feature
99 | representation
100 |             mlp: multi-layer perceptron, used as the final fully connected layers
101 | """
102 |
103 | super(GraphConvPredictor, self).__init__()
104 | with self.init_scope():
105 | self.graph_conv = graph_conv
106 | self.mlp = mlp
107 |
108 | def __call__(self, atoms, adjs):
109 | x = self.graph_conv(atoms, adjs)
110 | x = self.mlp(x)
111 | return x
112 |
113 | def predict(self, atoms, adjs):
114 | with chainer.no_backprop_mode(), chainer.using_config('train', False):
115 | x = self.__call__(atoms, adjs)
116 | return F.sigmoid(x)
117 |
--------------------------------------------------------------------------------
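
build_predictor wires one of the graph convolutions into MLPDrop through GraphConvPredictor. __call__ returns raw logits with dropout active under Chainer's default train config, while predict switches train mode off and applies a sigmoid; the saliency code therefore calls the model directly when it wants stochastic (BayesGrad) forward passes. A short sketch on dummy inputs (shapes as in the GGNNDrop example above):

import numpy as np

from models.predictor import build_predictor  # assumes the repository root is on sys.path

model = build_predictor('ggnndrop', n_unit=16, conv_layers=4, class_num=1,
                        dropout_ratio=0.25, n_layers=1)

atom_array = np.array([[6, 6, 7, 0]], dtype=np.int32)
adj = np.zeros((1, 4, 4, 4), dtype=np.float32)  # (minibatch, edge_type, atom, atom)

logits = model(atom_array, adj)         # stochastic: dropout is on in train mode
probs = model.predict(atom_array, adj)  # deterministic sigmoid probabilities
print(logits.shape, probs.shape)        # (1, 1) (1, 1)
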
/requirements.txt:
--------------------------------------------------------------------------------
1 | chainer==4.2.0
2 | chainer-chemistry==0.4.0
3 | matplotlib==2.2.2
4 | future==0.16.0
5 | cairosvg==2.1.3
6 | ipython==5.1.0
7 |
--------------------------------------------------------------------------------
/saliency/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/saliency/__init__.py
--------------------------------------------------------------------------------
/saliency/calculator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/saliency/calculator/__init__.py
--------------------------------------------------------------------------------
/saliency/calculator/base_calculator.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from abc import ABCMeta
3 | from abc import abstractmethod
4 |
5 | import chainer
6 | from chainer import cuda
7 | from chainer.dataset.convert import concat_examples, \
8 | _concat_arrays_with_padding
9 | from chainer.iterators import SerialIterator
10 |
11 | from future.utils import with_metaclass
12 | import numpy
13 |
14 |
15 | _sampling_axis = 0
16 |
17 |
18 | def _to_tuple(x):
19 | if not isinstance(x, tuple):
20 | x = (x,)
21 | return x
22 |
23 |
24 | def _to_variable(x):
25 | if not isinstance(x, chainer.Variable):
26 | x = chainer.Variable(x)
27 | return x
28 |
29 |
30 | def _extract_numpy(x):
31 | if isinstance(x, chainer.Variable):
32 | x = x.data
33 | return cuda.to_cpu(x)
34 |
35 |
36 | class BaseCalculator(with_metaclass(ABCMeta, object)):
37 |
38 | def __init__(self, model):
39 | self.model = model # type: chainer.Chain
40 | self._device = cuda.get_device_from_array(*model.params()).id
41 | # print('device', self._device)
42 |
43 | def compute(
44 | self, data, M=1, method='vanilla', batchsize=16,
45 | converter=concat_examples, retain_inputs=False, preprocess_fn=None,
46 | postprocess_fn=None, train=False):
47 | method_dict = {
48 | 'vanilla': self.compute_vanilla,
49 | 'smooth': self.compute_smooth,
50 | 'bayes': self.compute_bayes,
51 | }
52 | return method_dict[method](
53 | data, batchsize=batchsize, M=M, converter=converter,
54 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn,
55 | postprocess_fn=postprocess_fn, train=train)
56 |
57 | def compute_vanilla(self, data, batchsize=16, M=1,
58 | converter=concat_examples, retain_inputs=False,
59 | preprocess_fn=None, postprocess_fn=None, train=False):
60 | """VanillaGrad"""
61 | saliency_list = []
62 | for _ in range(M):
63 | with chainer.using_config('train', train):
64 | saliency = self._forward(
65 | data, fn=self._compute_core, batchsize=batchsize,
66 | converter=converter,
67 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn,
68 | postprocess_fn=postprocess_fn)
69 | saliency_list.append(cuda.to_cpu(saliency))
70 | return numpy.stack(saliency_list, axis=_sampling_axis)
71 |
72 | def compute_smooth(self, data, M=10, batchsize=16,
73 | converter=concat_examples, retain_inputs=False,
74 | preprocess_fn=None, postprocess_fn=None, train=False,
75 | scale=0.15, mode='relative'):
76 | """SmoothGrad
77 | Reference
78 | https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L54
79 | """
80 |
81 | def smooth_fn(*inputs):
82 | #TODO: support cupy input
83 | target_array = inputs[self.target_key].data
84 | xp = cuda.get_array_module(target_array)
85 |
86 | noise = xp.random.normal(
87 | 0, scale, inputs[self.target_key].data.shape)
88 | if mode == 'absolute':
89 | # `scale` is used as is
90 | pass
91 | elif mode == 'relative':
92 |                 # `scale_axis` is used to calculate `max` and `min` of target_array.
93 |                 # By default, all axes except the batch axis are treated as `scale_axis`.
94 | scale_axis = tuple(range(1, target_array.ndim))
95 | noise = noise * (xp.max(target_array, axis=scale_axis, keepdims=True)
96 | - xp.min(target_array, axis=scale_axis, keepdims=True))
97 | # print('[DEBUG] noise', noise.shape)
98 | else:
99 | raise ValueError("[ERROR] Unexpected value mode={}"
100 | .format(mode))
101 | inputs[self.target_key].data += noise
102 | return self._compute_core(*inputs)
103 |
104 | saliency_list = []
105 | for _ in range(M):
106 | with chainer.using_config('train', train):
107 | saliency = self._forward(
108 | data, fn=smooth_fn, batchsize=batchsize,
109 | converter=converter,
110 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn,
111 | postprocess_fn=postprocess_fn)
112 | saliency_array = cuda.to_cpu(saliency)
113 | saliency_list.append(saliency_array)
114 | return numpy.stack(saliency_list, axis=_sampling_axis)
115 |
116 | def compute_bayes(self, data, M=10, batchsize=16,
117 | converter=concat_examples, retain_inputs=False,
118 | preprocess_fn=None, postprocess_fn=None, train=True):
119 | """BayesGrad"""
120 |         warnings.warn('`compute_bayes` may be deleted in the future; '
121 |                       'please use `compute_vanilla` with train=True instead.')
122 | assert train
123 |         # This is just an alias of `compute_vanilla` with `train=True`.
124 |         # It may be removed in the future.
125 | return self.compute_vanilla(
126 | data, M=M, batchsize=batchsize, converter=converter,
127 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn,
128 | postprocess_fn=postprocess_fn, train=True)
129 |
130 | def transform(self, saliency_arrays, method='raw', lam=0, ch_axis=2):
131 | if method == 'raw':
132 | h = numpy.sum(saliency_arrays, axis=ch_axis)
133 | elif method == 'abs':
134 | h = numpy.sum(numpy.abs(saliency_arrays), axis=ch_axis)
135 | elif method == 'square':
136 | h = numpy.sum(saliency_arrays ** 2, axis=ch_axis)
137 | else:
138 |             raise ValueError('Unsupported method: {}'.format(method))
139 |
140 | sampling_axis = _sampling_axis
141 | if lam == 0:
142 | return numpy.mean(h, axis=sampling_axis)
143 | else:
144 | if h.shape[sampling_axis] == 1:
145 | # VanillaGrad does not support LCB/UCB calculation
146 | raise ValueError(
147 |                     'saliency_arrays.shape[{}] must be larger than 1'.format(sampling_axis))
148 | return numpy.mean(h, axis=sampling_axis) + lam * numpy.std(
149 | h, axis=sampling_axis)
150 |
151 | @abstractmethod
152 | def _compute_core(self, *inputs):
153 | raise NotImplementedError
154 |
155 | def _forward(self, data, fn=None, batchsize=16,
156 | converter=concat_examples, retain_inputs=False,
157 | preprocess_fn=None, postprocess_fn=None):
158 |         """Forward `data` through `fn` in minibatches
159 |
160 | Args:
161 | data: "train_x array" or "chainer dataset"
162 | fn (Callable): Main function to forward. Its input argument is
163 | either Variable, cupy.ndarray or numpy.ndarray, and returns
164 | Variable.
165 | batchsize (int): batch size
166 |             converter (Callable): converts a batch of `data` into `inputs`
167 |             retain_inputs (bool): If True, this instance keeps the inputs in
168 |                 `self.inputs`.
169 | preprocess_fn (Callable): Its input is numpy.ndarray or
170 | cupy.ndarray, it can return either Variable, cupy.ndarray or
171 | numpy.ndarray
172 | postprocess_fn (Callable): Its input argument is Variable,
173 | but this method may return either Variable, cupy.ndarray or
174 | numpy.ndarray.
175 |
176 | Returns (tuple or numpy.ndarray): forward result
177 |
178 | """
179 | input_list = None
180 | output_list = None
181 | it = SerialIterator(data, batch_size=batchsize, repeat=False,
182 | shuffle=False)
183 | for batch in it:
184 | inputs = converter(batch, self._device)
185 | inputs = _to_tuple(inputs)
186 |
187 | if preprocess_fn:
188 | inputs = preprocess_fn(*inputs)
189 | inputs = _to_tuple(inputs)
190 |
191 |             inputs = tuple(_to_variable(x) for x in inputs)  # tuple, not generator: reused below
192 |
193 | outputs = fn(*inputs)
194 |
195 | # Init
196 | if retain_inputs:
197 | if input_list is None:
198 | input_list = [[] for _ in range(len(inputs))]
199 | for j, input in enumerate(inputs):
200 | input_list[j].append(cuda.to_cpu(input))
201 | if output_list is None:
202 | output_list = [[] for _ in range(len(outputs))]
203 |
204 | if postprocess_fn:
205 | outputs = postprocess_fn(*outputs)
206 | outputs = _to_tuple(outputs)
207 | for j, output in enumerate(outputs):
208 | output_list[j].append(_extract_numpy(output))
209 |
210 | if retain_inputs:
211 | self.inputs = [numpy.concatenate(
212 | in_array) for in_array in input_list]
213 |
214 | result = [_concat(output) for output in output_list]
215 |
216 | # result = [numpy.concatenate(output) for output in output_list]
217 | if len(result) == 1:
218 | return result[0]
219 | else:
220 | return result
221 |
222 |
223 | def _concat(batch_list):
224 | try:
225 | return numpy.concatenate(batch_list)
226 |     except Exception:
227 |         # There is a case where each input has a different shape;
228 |         # we cannot concatenate them into a single array in this case.
229 |
230 | elem_list = [elem for batch in batch_list for elem in batch]
231 | return _concat_arrays_with_padding(elem_list, padding=0)
232 |
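For reference, `transform` collapses the stack of M sampled saliency maps returned by `compute`: channels are aggregated with `raw`, `abs`, or `square`, then the sampling axis is reduced by a mean, optionally shifted by `lam` standard deviations (the LCB/UCB variants mentioned above). A standalone numpy sketch with made-up shapes:

import numpy

# toy shapes (assumptions): M dropout samples, a batch of molecules, atoms, channels
M, batch, n_atom, ch = 30, 4, 20, 16
saliency_arrays = numpy.random.randn(M, batch, n_atom, ch)

h = numpy.sum(saliency_arrays ** 2, axis=3)   # method='square'; channel axis is last here
mean = numpy.mean(h, axis=0)                  # lam == 0: plain mean over the M samples
ucb = mean + 1.0 * numpy.std(h, axis=0)       # lam == 1.0: mean + lam * std (UCB)
print(ucb.shape)                              # (4, 20): one score per atom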
--------------------------------------------------------------------------------
/saliency/calculator/gradient_calculator.py:
--------------------------------------------------------------------------------
1 | from saliency.calculator.base_calculator import BaseCalculator
2 |
3 |
4 | class GradientCalculator(BaseCalculator):
5 |
6 | def __init__(self, model, eval_fun=None, eval_key=None, target_key=0,
7 | multiply_target=False):
8 | super(GradientCalculator, self).__init__(model)
9 | # self.model = model
10 | # self._device = cuda.get_array_module(model)
11 | self.eval_fun = eval_fun
12 | self.eval_key = eval_key
13 | self.target_key = target_key
14 |
15 | self.multiply_target = multiply_target
16 |
17 | def _compute_core(self, *inputs):
18 | # outputs = fn(*inputs)
19 | # outputs = _to_tuple(outputs)
20 | self.model.cleargrads()
21 | result = self.eval_fun(*inputs)
22 | if self.eval_key is None:
23 | eval_var = result
24 | elif isinstance(self.eval_key, str):
25 | eval_var = result[self.eval_key]
26 | else:
27 | raise TypeError('Unexpected type {} for eval_key'
28 | .format(type(self.eval_key)))
29 |         # TODO: Consider how to deal with the case when eval_var is not scalar:
30 |         # 1. take the sum
31 |         # 2. raise an error (default behavior)
32 |         # I think option 1 "take the sum" is better, since the gradient is then
33 |         # calculated independently for each output element.
34 | eval_var.backward(retain_grad=True)
35 |
36 | if self.target_key is None:
37 | target_var = inputs
38 | elif isinstance(self.target_key, int):
39 | target_var = inputs[self.target_key]
40 | else:
41 | raise TypeError('Unexpected type {} for target_key'
42 | .format(type(self.target_key)))
43 | saliency = target_var.grad
44 | if self.multiply_target:
45 | saliency *= target_var.data
46 | outputs = (saliency,)
47 | return outputs
48 |
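A usage sketch (not code from the repository; `model`, `dataset`, and the float-feature assumption are placeholders): GradientCalculator supplies the `_compute_core` used by BaseCalculator's compute/transform pipeline, so the gradient of the scalar returned by `eval_fun` with respect to the input selected by `target_key` becomes the saliency.

import chainer.functions as F
from saliency.calculator.gradient_calculator import GradientCalculator

# `features` is assumed to be a float array (e.g. an embedded atom representation);
# gradients are not defined with respect to raw integer atom IDs.
def eval_fun(features, adjs):
    # scalar objective: sum of the model outputs over molecules and tasks
    return F.sum(model(features, adjs))

calc = GradientCalculator(model, eval_fun=eval_fun, target_key=0)

# method='bayes' keeps dropout active (train=True) and repeats the pass M times
saliency_samples = calc.compute(dataset, M=30, method='bayes')
# choose ch_axis to match the layout of the saliency arrays
saliency = calc.transform(saliency_samples, method='square', lam=0, ch_axis=2)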
--------------------------------------------------------------------------------
/saliency/calculator/integrated_gradients_calculator.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from chainer import Variable
3 |
4 | from saliency.calculator.base_calculator import BaseCalculator
5 | from saliency.calculator.gradient_calculator import GradientCalculator
6 |
7 |
8 | class IntegratedGradientsCalculator(GradientCalculator):
9 |
10 | def __init__(self, model, eval_fun=None, eval_key=None, target_key=0,
11 | baseline=None, steps=25):
12 | super(IntegratedGradientsCalculator, self).__init__(
13 | model, eval_fun=eval_fun, eval_key=eval_key, target_key=target_key,
14 | multiply_target=False
15 | )
16 | # self.target_key = target_key
17 | self.baseline = baseline or 0.
18 | self.steps = steps
19 |
20 | def get_target_var(self, inputs):
21 | if self.target_key is None:
22 | target_var = inputs
23 | elif isinstance(self.target_key, int):
24 | target_var = inputs[self.target_key]
25 | else:
26 | raise TypeError('Unexpected type {} for target_key'
27 | .format(type(self.target_key)))
28 | return target_var
29 |
30 | def _compute_core(self, *inputs):
31 |
32 | total_grads = 0.
33 | target_var = self.get_target_var(inputs)
34 | base = self.baseline
35 | diff = target_var.data - base
36 | for alpha in numpy.linspace(0., 1., self.steps):
37 | # TODO: consider case target_key=None
38 | interpolated_inputs = (
39 | Variable(base + alpha * diff) if self.target_key == i else elem
40 | for i, elem in enumerate(inputs))
41 | total_grads += super(
42 | IntegratedGradientsCalculator, self)._compute_core(
43 | *interpolated_inputs)[0]
44 | saliency = total_grads * diff
45 | return saliency,
46 |
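A self-contained numpy sketch (toy model and values, not the class above) of the quantity being estimated. Note that `_compute_core` sums the interpolated gradients without dividing by `steps`; that only rescales every attribution by a constant factor, which is harmless once the visualizer min-max normalizes the saliency.

import numpy

def f(x):                      # toy scalar model: f(x) = sum(x ** 2)
    return numpy.sum(x ** 2)

def grad_f(x):                 # its analytic gradient
    return 2 * x

x = numpy.array([1.0, -2.0, 3.0])
baseline = numpy.zeros_like(x)
steps = 25

grads = [grad_f(baseline + alpha * (x - baseline))
         for alpha in numpy.linspace(0., 1., steps)]
# textbook integrated gradients: (x - baseline) times the average interpolated gradient
ig = (x - baseline) * numpy.mean(grads, axis=0)
print(ig)                      # approx. [1., 4., 9.]; the attributions sum to f(x) - f(baseline)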
--------------------------------------------------------------------------------
/saliency/calculator/occlusion_calculator.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import numpy
4 |
5 | import chainer
6 | import six
7 |
8 | from saliency.calculator.base_calculator import BaseCalculator
9 |
10 |
11 | def _to_tuple(x):
12 | if isinstance(x, int):
13 | x = (x,)
14 | elif isinstance(x, (list, tuple)):
15 | x = tuple(x)
16 | else:
17 | raise TypeError('Unexpected type {}'.format(type(x)))
18 | return x
19 |
20 |
21 | class OcclusionCalculator(BaseCalculator):
22 |
23 | def __init__(self, model, eval_fun=None, eval_key=None,
24 | enable_backprop=False, size=1, slide_axis=(2, 3),
25 | target_key=0):
26 | super(OcclusionCalculator, self).__init__(model)
27 | # self.model = model
28 | # self._device = cuda.get_array_module(model)
29 | self.eval_fun = eval_fun
30 | self.eval_key = eval_key
31 | self.enable_backprop = enable_backprop
32 | self.slide_axis = _to_tuple(slide_axis)
33 | size = _to_tuple(size)
34 |         if len(self.slide_axis) != len(size):
35 | size = size * len(self.slide_axis)
36 | self.size = size
37 | print('slide_axis', self.slide_axis, 'size', self.size)
38 | self.target_key = target_key
39 |
40 | def _compute_core(self, *inputs):
41 | if self.target_key is None:
42 | target_var = inputs
43 | elif isinstance(self.target_key, int):
44 | target_var = inputs[self.target_key]
45 | else:
46 | raise TypeError('Unexpected type {} for target_key'
47 | .format(type(self.target_key)))
48 |
49 | def _extract_score(result):
50 | if self.eval_key is None:
51 | score = result
52 | elif isinstance(self.eval_key, str):
53 | score = result[self.eval_key]
54 | else:
55 | raise TypeError('Unexpected type {} for eval_key'
56 | .format(type(self.eval_key)))
57 | return score
58 |
59 | # Usually, backward() is not necessary for calculating occlusion
60 | with chainer.using_config('enable_backprop', self.enable_backprop):
61 | original_result = self.eval_fun(*inputs)
62 | original_score = _extract_score(original_result)
63 |
64 | # TODO: xp and value assign dynamically
65 | xp = numpy
66 | value = 0.
67 |
68 | # fill with `value`
69 | target_dim = target_var.ndim
70 | batch_size = target_var.shape[0]
71 | occlusion_window_shape = [1] * target_dim
72 | occlusion_window_shape[0] = batch_size
73 | for axis, size in zip(self.slide_axis, self.size):
74 | occlusion_window_shape[axis] = size
75 | occlusion_scores_shape = [1] * target_dim
76 | occlusion_scores_shape[0] = batch_size
77 | for axis, size in zip(self.slide_axis, self.size):
78 | occlusion_scores_shape[axis] = target_var.shape[axis]
79 | # print('[DEBUG] occlusion_shape', occlusion_window_shape)
80 | occlusion_window = xp.ones(occlusion_window_shape, dtype=target_var.dtype) * value
81 | # print('[DEBUG] occlusion_window.shape', occlusion_window.shape)
82 | occlusion_scores = xp.zeros(occlusion_scores_shape, dtype=xp.float32)
83 | # print('[DEBUG] occlusion_scores.shape', occlusion_scores.shape)
84 |
85 | def _extract_index(slide_axis, size, start_indices):
86 | colon = slice(None)
87 | index = [colon] * target_dim
88 |             for axis, sz, start in zip(slide_axis, size, start_indices):
89 |                 index[axis] = slice(start, start + sz, 1)
90 |             return tuple(index)
91 |
92 | end_list = [target_var.data.shape[axis] - size for axis, size
93 | in zip(self.slide_axis, self.size)]
94 |
95 | for start in itertools.product(*[six.moves.range(end) for end in end_list]):
96 | target_var_occluded = target_var.data.copy()
97 | occlude_index = _extract_index(self.slide_axis, self.size, start)
98 | target_var_occluded[occlude_index] = occlusion_window
99 |
100 | # Usually, backward() is not necessary for calculating occlusion
101 | with chainer.using_config('enable_backprop', self.enable_backprop):
102 | # TODO: consider case target_key=None
103 | occluded_inputs = (target_var_occluded if self.target_key == i else elem
104 | for i, elem in enumerate(inputs))
105 | occluded_result = self.eval_fun(*occluded_inputs)
106 | occluded_score = _extract_score(occluded_result)
107 | score_diff_var = original_score - occluded_score
108 | # TODO: expand_dim dynamically
109 | score_diff = xp.broadcast_to(score_diff_var.data[:, :, None], occlusion_window_shape)
110 | occlusion_scores[occlude_index] += score_diff
111 | outputs = (occlusion_scores,)
112 | return outputs
113 |
114 |
115 | if __name__ == '__main__':
116 |     # TODO: add a test, e.g.:
117 |     #   oc = OcclusionCalculator(model, eval_fun=model)
118 |     #   saliency_array = oc.compute_vanilla(data)
119 |     #   saliency = oc.transform(saliency_array)
120 |     raise NotImplementedError()
121 |
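A standalone sketch of the occlusion idea implemented above (toy 2-D input and toy scoring function, not the class itself): slide a window over the input, overwrite it with a fill value, and credit every covered position with the resulting drop in the model's score.

import numpy

def score(x):                              # toy "model": just sums the input
    return x.sum()

x = numpy.arange(16.0).reshape(4, 4)
size, fill_value = 2, 0.0
saliency = numpy.zeros_like(x)

for i in range(x.shape[0] - size + 1):
    for j in range(x.shape[1] - size + 1):
        occluded = x.copy()
        occluded[i:i + size, j:j + size] = fill_value
        # positions whose occlusion hurts the score most receive the highest saliency
        saliency[i:i + size, j:j + size] += score(x) - score(occluded)

print(saliency)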
--------------------------------------------------------------------------------
/saliency/visualizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/saliency/visualizer/__init__.py
--------------------------------------------------------------------------------
/saliency/visualizer/base_visualizer.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 | from abc import abstractmethod
3 |
4 | from future.utils import with_metaclass
5 |
6 |
7 | class BaseVisualizer(with_metaclass(ABCMeta, object)):
8 |
9 | @abstractmethod
10 | def visualize(self, *args, **kwargs):
11 | raise NotImplementedError
12 |
--------------------------------------------------------------------------------
/saliency/visualizer/smiles_visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy
2 |
3 | from rdkit import Chem
4 | from rdkit.Chem import rdDepictor
5 | from rdkit.Chem.Draw import rdMolDraw2D
6 |
7 |
8 | from saliency.visualizer.base_visualizer import BaseVisualizer
9 |
10 |
11 | def _convert_to_2d(axes, nrows, ncols):
12 | if nrows == 1 and ncols == 1:
13 | axes = numpy.array([[axes]])
14 | elif nrows == 1:
15 | axes = axes[None, :]
16 | elif ncols == 1:
17 | axes = axes[:, None]
18 | else:
19 | pass
20 | assert axes.ndim == 2
21 | return axes
22 |
23 |
24 | def is_visible(begin, end):
25 | if begin <= 0 or end <= 0:
26 | return 0
27 | elif begin >= 1 or end >= 1:
28 | return 1
29 | else:
30 | return (begin + end) * 0.5
31 |
32 |
33 | # Default color function
34 | def red(x):
35 | # return in RGB order
36 | # x=0 -> 1, 1, 1 (white)
37 | # x=1 -> 1, 0, 0 (red)
38 | return 1., 1. - x, 1. - x
39 |
40 |
41 | def min_max_scaler(saliency):
42 |     """Normalize saliency values to the range [0, 1]."""
43 | maxv = numpy.max(saliency)
44 | minv = numpy.min(saliency)
45 | if maxv == minv:
46 | saliency = numpy.zeros_like(saliency)
47 | else:
48 | saliency = (saliency - minv) / (maxv - minv)
49 | return saliency
50 |
51 |
52 | class SmilesVisualizer(BaseVisualizer):
53 |
54 | def visualize(self, saliency, smiles, save_filepath=None,
55 | visualize_ratio=1.0, color_fn=red, scaler=min_max_scaler, legend=''):
56 | mol = Chem.MolFromSmiles(smiles)
57 | num_atoms = mol.GetNumAtoms()
58 | rdDepictor.Compute2DCoords(mol)
59 | Chem.SanitizeMol(mol)
60 | Chem.Kekulize(mol)
61 | n_atoms = mol.GetNumAtoms()
62 | # highlight = list(range(n_atoms))
63 |
64 | # --- type check ---
65 | assert saliency.ndim == 1
66 | # Cut saliency array for unnecessary tail part
67 | saliency = saliency[:num_atoms]
68 | # Normalize to [0, 1]
69 | saliency = scaler(saliency)
70 | # normed_saliency = copy.deepcopy(saliency)
71 |
72 | if visualize_ratio < 1.0:
73 | threshold_index = int(n_atoms * visualize_ratio)
74 | idx = numpy.argsort(saliency)
75 | idx = numpy.flip(idx, axis=0)
76 | # set threshold to top `visualize_ratio` saliency
77 | threshold = saliency[idx[threshold_index]]
78 | saliency = numpy.where(saliency < threshold, 0., saliency)
79 | else:
80 | threshold = numpy.min(saliency)
81 |
82 |         highlight_atoms = [int(i) for i in numpy.where(saliency >= threshold)[0]]
83 | atom_colors = {i: color_fn(e) for i, e in enumerate(saliency)}
84 | bondlist = [bond.GetIdx() for bond in mol.GetBonds()]
85 |
86 | def color_bond(bond):
87 | begin = saliency[bond.GetBeginAtomIdx()]
88 | end = saliency[bond.GetEndAtomIdx()]
89 | return color_fn(is_visible(begin, end))
90 | bondcolorlist = {i: color_bond(bond) for i, bond in enumerate(mol.GetBonds())}
91 | drawer = rdMolDraw2D.MolDraw2DSVG(500, 375)
92 | drawer.DrawMolecule(
93 | mol, highlightAtoms=highlight_atoms,
94 | highlightAtomColors=atom_colors, highlightBonds=bondlist,
95 | highlightBondColors=bondcolorlist, legend=legend)
96 | drawer.FinishDrawing()
97 | svg = drawer.GetDrawingText()
98 | if save_filepath:
99 |             extension = save_filepath.split('.')[-1]
100 |             if extension == 'svg':
101 |                 print('saving svg to {}'.format(save_filepath))
102 |                 with open(save_filepath, 'w') as f:
103 |                     f.write(svg)
104 |             elif extension == 'png':
105 |                 import cairosvg
106 |                 print('saving png to {}'.format(save_filepath))
107 |                 # cairosvg.svg2png(
108 |                 #     url=svg_save_filepath, write_to=save_filepath)
109 |                 # print('svg type', type(svg))
110 |                 cairosvg.svg2png(bytestring=svg, write_to=save_filepath)
111 |             else:
112 |                 raise ValueError(
113 |                     'Unsupported extension {} for save_filepath {}'
114 |                     .format(extension, save_filepath))
115 | else:
116 | from IPython.core.display import SVG
117 | return SVG(svg.replace('svg:', ''))
118 |
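A usage sketch with assumed inputs: given a 1-D array holding one saliency value per heavy atom, the visualizer shades atoms and bonds by the min-max-normalized scores. Saving to PNG goes through cairosvg (listed in requirements.txt); without `save_filepath` an IPython SVG object is returned for inline notebook display.

from saliency.visualizer.smiles_visualizer import SmilesVisualizer

# `saliency_per_atom` is assumed to be a 1-D numpy array, one value per heavy atom
visualizer = SmilesVisualizer()
visualizer.visualize(saliency_per_atom, 'c1ccccc1O',     # phenol, 7 heavy atoms
                     save_filepath='phenol_saliency.png',
                     visualize_ratio=0.6,                # keep roughly the top 60% of atoms
                     legend='bayesgrad saliency')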
--------------------------------------------------------------------------------