├── .gitignore ├── LICENSE ├── README.md ├── assets ├── delaney_solubility │ ├── 19.png │ ├── 3.png │ └── 54.png ├── tox21_pyridine │ ├── 13_bayes.png │ ├── 21_bayes.png │ └── 6_bayes.png └── tox21_srmmp │ ├── 27_bayes.png │ ├── 2_bayes.png │ └── 3_bayes.png ├── experiments ├── delaney │ ├── plot.py │ └── train.py └── tox21 │ ├── calc_prcauc_with_seeds.py │ ├── calc_prcauc_with_seeds.sh │ ├── data.py │ ├── plot_precision_recall.py │ ├── train_few_with_seeds.sh │ ├── train_tox21.py │ ├── utils.py │ ├── visualize-saliency-pyrigine.ipynb │ └── visualize-saliency-tox21.ipynb ├── models ├── __init__.py ├── ggnn_drop.py ├── nfp_drop.py └── predictor.py ├── requirements.txt └── saliency ├── __init__.py ├── calculator ├── __init__.py ├── base_calculator.py ├── gradient_calculator.py ├── integrated_gradients_calculator.py └── occlusion_calculator.py └── visualizer ├── __init__.py ├── base_visualizer.py └── smiles_visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | .idea/ 103 | experiments/tox21/results/ 104 | experiments/delaney/results/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Preferred Networks, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesgrad 2 | BayesGrad: Explaining Predictions of Graph Convolutional Networks 3 | 4 | The paper is available on arXiv, [https://arxiv.org/abs/1807.01985](https://arxiv.org/abs/1807.01985). 5 | 6 |

7 | 8 | 9 | 10 |

11 |

12 | From left: Tox21 pyridine (C5H5N), Tox21 SR-MMP, and Delaney solubility visualizations.
13 |

14 |
15 | ## Citation
16 | If you find our work useful in your research, please consider citing:
17 |
18 | ```
19 | @article{akita2018bayesgrad,
20 |   title={BayesGrad: Explaining Predictions of Graph Convolutional Networks},
21 |   author={Akita, Hirotaka and Nakago, Kosuke and Komatsu, Tomoki and Sugawara, Yohei and Maeda, Shin-ichi and Baba, Yukino and Kashima, Hisashi},
22 |   journal={arXiv preprint arXiv:1807.01985},
23 |   year={2018}
24 | }
25 | ```
26 |
27 | ## Setup
28 |
29 | [Chainer Chemistry](https://github.com/pfnet-research/chainer-chemistry) [1] is used in our code.
30 | It is an extension library for the deep learning framework [Chainer](https://github.com/chainer/chainer) [2],
31 | and it supports several graph convolutional neural networks together with chemical dataset management.
32 |
33 | The experiments were executed under the following environment:
34 |
35 | - OS: Linux
36 | - python: 3.6.1
37 | - conda version: 4.4.4
38 |
39 | ```bash
40 | conda create -n bayesgrad python=3.6.1
41 | source activate bayesgrad
42 | pip install chainer==4.2.0
43 | pip install chainer-chemistry==0.4.0
44 | conda install -c rdkit rdkit==2017.09.3.0
45 | pip install matplotlib==2.2.2
46 | pip install future==0.16.0
47 | pip install cairosvg==2.1.3
48 | pip install ipython==5.1.0
49 | ```
50 |
51 | [Note]
52 | Please install the specified python and rdkit versions.
53 | The latest python and rdkit versions may not work well, as discussed [here](https://github.com/pfnet-research/chainer-chemistry/issues/138).
54 | If you face an error, try
55 | ```bash
56 | conda install libgcc
57 | ```
58 |
59 | If you want to use a GPU, please install `cupy` as well.
60 | ```bash
61 | # XX should be CUDA version (80, 90 or 91)
62 | pip install cupy-cudaXX==4.2.0
63 | ```
64 |
65 | ## Experiments
66 |
67 | Each experiment can be executed as follows.
68 |
69 | ### Tox21 Pyridine experiment
70 | Experiments described in Section 4.1 of the paper. The Tox21 [3] dataset is used.
71 |
72 | ```bash
73 | cd experiments/tox21
74 | ```
75 |
76 | #### Training with all training data, plotting the precision-recall curve
77 |
78 | Set `-g -1` to use CPU, `-g 0` to use GPU.
79 | ```bash
80 | python train_tox21.py --iterator-type=balanced --label=pyridine --method=ggnndrop --epoch=50 --unit-num=16 --n-layers=1 -b 32 --conv-layers=4 --num-train=-1 --dropout-ratio=0.25 --out=results/ggnndrop_pyridine -g 0
81 | python plot_precision_recall.py --dirpath=results/ggnndrop_pyridine
82 | ```
83 |
84 | #### Visualization with trained model
85 | See `visualize-saliency-pyrigine.ipynb`.
86 |
87 |
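For reference, the core computation inside the notebook follows the same flow as `plot_precision_recall.py`. Below is a minimal sketch (not the notebook's exact code): the hyperparameters must match those used for training, and the model path assumes the training command above wrote `predictor.npz` into `results/ggnndrop_pyridine`.

```python
# Run from experiments/tox21/ after training as above.
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..', '..')))  # repository root

from chainer import functions as F, links as L, serializers
from chainer_chemistry.dataset.converters import concat_mols

import data
from models import predictor
from saliency.calculator.gradient_calculator import GradientCalculator

# Validation split of the artificial pyridine dataset.
train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset('ggnndrop', 'pyridine')

# Rebuild the predictor with the training hyperparameters and load the trained weights.
model = predictor.build_predictor('ggnndrop', 16, 4, 1, dropout_ratio=0.25, n_layers=1)
serializers.load_npz('results/ggnndrop_pyridine/predictor.npz', model)
classifier = L.Classifier(model, lossfun=F.sigmoid_cross_entropy, accfun=F.binary_accuracy)

def preprocess_fn(atom, adj, t):
    # Embed the discrete atom IDs so that gradients are taken w.r.t. the continuous embedding.
    return classifier.predictor.graph_conv.embed(atom), adj, t

def eval_fun(atom_embed, adj, t):
    return F.sum(classifier.predictor(atom_embed, adj))

calculator = GradientCalculator(classifier, eval_fun=eval_fun, target_key=0)

# BayesGrad: vanilla gradients averaged over M dropout samples (train=True keeps dropout active).
saliency_arrays = calculator.compute_vanilla(
    val, converter=concat_mols, preprocess_fn=preprocess_fn, M=30, train=True)
saliency = calculator.transform(saliency_arrays, ch_axis=3, method='square', lam=0)

# saliency[i, :n] holds per-atom scores for molecule i (arrays are zero-padded to the largest molecule).
num_atoms = [len(a) for a in val.features[:, 0]]
per_atom_saliency = [saliency[i, :n] for i, n in enumerate(num_atoms)]
```

The per-atom scores can then be drawn onto the molecule with `saliency/visualizer/smiles_visualizer.py` (see `experiments/delaney/plot.py` for a complete drawing example).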

88 | 89 | 90 | 91 |

92 |
93 | Our method successfully focuses on pyridine (C5H5N) substructures.
94 |
95 | #### Training 30 different models with few training data, calculating PRC-AUC score
96 | Pass `-1` as the argument to use CPU, `0` to use GPU.
97 |
98 | Note that this experiment takes time (around 2.5 hours with a GPU in our environment),
99 | since it trains 30 different models.
100 |
101 | ```bash
102 | bash -x ./train_few_with_seeds.sh 0
103 | bash -x ./calc_prcauc_with_seeds.sh 0
104 | ```
105 |
106 | Then see `results/ggnndrop_pyridine_numtrain1000_seed0-29/prcauc_stats_absolute_0.15.csv` for the results.
107 |
108 | ### Tox21 SR-MMP experiment
109 | Experiments described in Section 4.2 of the paper. The Tox21 [3] dataset is used.
110 |
111 | ```bash
112 | cd experiments/tox21
113 | ```
114 |
115 | #### Training the model
116 | Set `-g -1` to use CPU, `-g 0` to use GPU.
117 | ```bash
118 | python train_tox21.py --iterator-type=balanced --label=SR-MMP --method=nfpdrop --epoch=200 --unit-num=16 --n-layers=1 -b 32 --conv-layers=4 --num-train=-1 --dropout-ratio=0.25 --out=results/nfpdrop_srmmp -g 0
119 | ```
120 |
121 | #### Visualization of tox21 data & Tyrphostin 9 with trained model
122 | See `visualize-saliency-tox21.ipynb`.
123 |
124 | Jupyter notebook interactive visualization:
125 |

126 | 127 |

128 | 129 | Several picked images: 130 |

131 | 132 | 133 | 134 |

135 |
136 | Toxicity mechanisms are still an active research topic, and it is difficult to quantitatively analyze these results.
137 | We hope these visualizations help to analyze and establish further knowledge about toxicity.
138 |
139 | ### Solubility experiment
140 | Experiments described in Section 4.3 of the paper. The ESOL [4] dataset (provided by MoleculeNet [5]) is used.
141 |
142 | ```bash
143 | cd experiments/delaney
144 | ```
145 |
146 | #### Training the model
147 | Set `-g -1` to use CPU, `-g 0` to use GPU.
148 |
149 | ```bash
150 | python train.py -e 100 -n 3 --method=nfpdrop -g 0
151 | ```
152 |
153 | #### Visualization with trained model
154 | ```bash
155 | python plot.py --dirpath=./results/nfpdrop_M30_conv3_unit32_b32
156 | ```
157 |
158 |
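The directory name passed to `plot.py` encodes the training configuration; it is generated by `get_dir_path` in `train.py`:

```python
# From experiments/delaney/train.py. With the defaults used above
# (--method=nfpdrop, M=30 dropout samples, -n 3 conv layers, 32 units, batch size 32)
# this yields results/nfpdrop_M30_conv3_unit32_b32, containing one sub-directory
# per cross-validation split ("0-5", "1-5", ..., "4-5").
dir_path = "results/{}_M{}_conv{}_unit{}_b{}".format(method, M, conv_layers, n_unit, batchsize)
```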

159 | 160 | 161 | 162 |

163 |
164 | Red indicates atoms estimated to be hydrophilic, and blue indicates atoms estimated to be hydrophobic.
165 | The figure above is consistent with fundamental physicochemical knowledge, as explained in the paper.
166 |
167 | ## Saliency Calculation
168 |
169 | Although only results of the gradient method [6, 7, 8] are reported in the paper,
170 | this repository contains saliency calculation code for several other algorithms as well.
171 |
172 | SmoothGrad [8] and/or BayesGrad (ours) can be applied on top of the following algorithms:
173 |
174 | - Vanilla Gradients [6, 7]
175 | - Integrated Gradients [9]
176 | - Occlusion [10]
177 |
178 | The code design is inspired by [PAIR-code/saliency](https://github.com/PAIR-code/saliency). A minimal command-line example is shown after the references below.
179 |
180 | ## License
181 |
182 | Our code is released under the MIT License (see the [LICENSE](https://github.com/pfnet-research/bayesgrad/blob/master/LICENSE) file for details).
183 |
184 | ## Reference
185 |
186 | [1] pfnet-research. chainer-chemistry. https://github.com/pfnet-research/chainer-chemistry
187 |
188 | [2] Seiya Tokui, Kenta Oono, Shohei Hido, and Justin Clayton. Chainer: a next-generation open source framework for deep learning. In *Proceedings of Workshop on Machine Learning Systems (LearningSys) in Advances in Neural Information Processing Systems (NIPS) 28*, 2015.
189 |
190 | [3] Ruili Huang, Menghang Xia, Dac-Trung Nguyen, Tongan Zhao, Srilatha Sakamuru, Jinghua Zhao, Sampada A Shahane, Anna Rossoshek, and Anton Simeonov. Tox21 challenge to build predictive models of nuclear receptor and stress response pathways as mediated by exposure to environmental chemicals and drugs. Frontiers in Environmental Science, 3:85, 2016.
191 |
192 | [4] John S. Delaney. ESOL: Estimating aqueous solubility directly from molecular structure. Journal of Chemical Information and Computer Sciences, 44(3):1000–1005, 2004. PMID: 15154768.
193 |
194 | [5] Zhenqin Wu, Bharath Ramsundar, Evan N. Feinberg, Joseph Gomes, Caleb Geniesse, Aneesh S. Pappu, Karl Leswing, and Vijay Pande. MoleculeNet: A Benchmark for Molecular Machine Learning. arXiv preprint arXiv:1703.00564, 2017.
195 |
196 | [6] Dumitru Erhan, Yoshua Bengio, Aaron Courville, and Pascal Vincent. Visualizing Higher-Layer Features of a Deep Network. 2009.
197 |
198 | [7] Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. Deep inside convolutional networks: Visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034, 2013.
199 |
200 | [8] Daniel Smilkov, Nikhil Thorat, Been Kim, Fernanda Viegas, and Martin Wattenberg. SmoothGrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825, 2017.
201 |
202 | [9] Mukund Sundararajan, Ankur Taly, and Qiqi Yan. Axiomatic attribution for deep networks. In Doina Precup and Yee Whye Teh (eds.),
203 | Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pp. 3319–3328, International Convention Centre, Sydney, Australia, 06–11 Aug 2017. PMLR.
204 | URL http://proceedings.mlr.press/v70/sundararajan17a.html.
205 |
206 | [10] Matthew D Zeiler and Rob Fergus. Visualizing and understanding convolutional networks. In
207 | European conference on computer vision, pp. 818–833. Springer, 2014.
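As a concrete entry point to the Saliency Calculation section above, `experiments/tox21/plot_precision_recall.py` exposes these algorithms through its `--calculator` option (`gradient` is the default; the directory is assumed to contain a model trained as in the Tox21 pyridine section):

```bash
cd experiments/tox21
python plot_precision_recall.py --dirpath=results/ggnndrop_pyridine --calculator=integrated_gradients
python plot_precision_recall.py --dirpath=results/ggnndrop_pyridine --calculator=occlusion
```

Each run writes `precision_recall_<calculator>.png` into the given directory.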
208 | -------------------------------------------------------------------------------- /assets/delaney_solubility/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/delaney_solubility/19.png -------------------------------------------------------------------------------- /assets/delaney_solubility/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/delaney_solubility/3.png -------------------------------------------------------------------------------- /assets/delaney_solubility/54.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/delaney_solubility/54.png -------------------------------------------------------------------------------- /assets/tox21_pyridine/13_bayes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_pyridine/13_bayes.png -------------------------------------------------------------------------------- /assets/tox21_pyridine/21_bayes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_pyridine/21_bayes.png -------------------------------------------------------------------------------- /assets/tox21_pyridine/6_bayes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_pyridine/6_bayes.png -------------------------------------------------------------------------------- /assets/tox21_srmmp/27_bayes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_srmmp/27_bayes.png -------------------------------------------------------------------------------- /assets/tox21_srmmp/2_bayes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_srmmp/2_bayes.png -------------------------------------------------------------------------------- /assets/tox21_srmmp/3_bayes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/assets/tox21_srmmp/3_bayes.png -------------------------------------------------------------------------------- /experiments/delaney/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import os 5 | import sys 6 | 7 | import matplotlib 8 | matplotlib.use('agg') 9 | import matplotlib.pyplot as plt 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 12 | from saliency.visualizer.smiles_visualizer import SmilesVisualizer 13 | 14 | 15 | def visualize(dir_path): 16 | 
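"""Load the saliency arrays and prediction results saved in `dir_path` and render a per-atom saliency image for each test molecule (BayesGrad only by default; the VanillaGrad/SmoothGrad calls below are commented out)."""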
parent_dir = os.path.dirname(dir_path) 17 | saliency_vanilla = np.load(os.path.join(dir_path, "saliency_vanilla.npy")) 18 | saliency_smooth = np.load(os.path.join(dir_path, "saliency_smooth.npy")) 19 | saliency_bayes = np.load(os.path.join(dir_path, "saliency_bayes.npy")) 20 | 21 | visualizer = SmilesVisualizer() 22 | os.makedirs(os.path.join(parent_dir, "result_vanilla"), exist_ok=True) 23 | os.makedirs(os.path.join(parent_dir, "result_smooth"), exist_ok=True) 24 | os.makedirs(os.path.join(parent_dir, "result_bayes"), exist_ok=True) 25 | 26 | test_idx = np.load(os.path.join(dir_path, "test_idx.npy")) 27 | answer = np.load(os.path.join(dir_path, "answer.npy")) 28 | output = np.load(os.path.join(dir_path, "output.npy")) 29 | 30 | smiles_all = np.load(os.path.join(parent_dir, "smiles.npy")) 31 | 32 | def calc_range(saliency): 33 | vmax = float('-inf') 34 | vmin = float('inf') 35 | for v in saliency: 36 | vmax = max(vmax, np.max(v)) 37 | vmin = min(vmin, np.min(v)) 38 | return vmin, vmax 39 | 40 | v_range_vanilla = calc_range(saliency_vanilla) 41 | v_range_smooth = calc_range(saliency_smooth) 42 | v_range_bayes = calc_range(saliency_bayes) 43 | 44 | def get_scaler(v_range): 45 | def scaler(saliency_): 46 | saliency = np.copy(saliency_) 47 | minv, maxv = v_range 48 | if maxv == minv: 49 | saliency = np.zeros_like(saliency) 50 | else: 51 | pos = saliency >= 0.0 52 | saliency[pos] = saliency[pos]/maxv 53 | nega = saliency < 0.0 54 | saliency[nega] = saliency[nega]/(np.abs(minv)) 55 | return saliency 56 | return scaler 57 | 58 | scaler_vanilla = get_scaler(v_range_vanilla) 59 | scaler_smooth = get_scaler(v_range_smooth) 60 | scaler_bayes = get_scaler(v_range_bayes) 61 | 62 | def color(x): 63 | if x > 0: 64 | # Red for positive value 65 | return 1., 1. - x, 1. - x 66 | else: 67 | # Blue for negative value 68 | x *= -1 69 | return 1. - x, 1. - x, 1. 
70 | 71 | for i, id in enumerate(test_idx): 72 | smiles = smiles_all[id] 73 | out = output[i] 74 | ans = answer[i] 75 | # legend = "t:{}, p:{}".format(ans, out) 76 | legend = '' 77 | ext = '.png' # '.svg' 78 | # visualizer.visualize( 79 | # saliency_vanilla[id], smiles, save_filepath=os.path.join(parent_dir, "result_vanilla", str(id) + ext), 80 | # visualize_ratio=1.0, legend=legend, scaler=scaler_vanilla, color_fn=color) 81 | # visualizer.visualize( 82 | # saliency_smooth[id], smiles, save_filepath=os.path.join(parent_dir, "result_smooth", str(id) + ext), 83 | # visualize_ratio=1.0, legend=legend, scaler=scaler_smooth, color_fn=color) 84 | visualizer.visualize( 85 | saliency_bayes[id], smiles, save_filepath=os.path.join(parent_dir, "result_bayes", str(id) + ext), 86 | visualize_ratio=1.0, legend=legend, scaler=scaler_bayes, color_fn=color) 87 | 88 | 89 | def plot_result(prediction, answer, save_filepath='result.png'): 90 | plt.scatter(prediction, answer, marker='.') 91 | plt.plot([-100, 100], [-100, 100], c='r') 92 | max_v = max(np.max(prediction), np.max(answer)) 93 | min_v = min(np.min(prediction), np.min(answer)) 94 | plt.xlim([min_v-0.1, max_v+0.1]) 95 | plt.xlabel("prediction") 96 | plt.ylim([min_v-0.1, max_v+0.1]) 97 | plt.ylabel("ground truth") 98 | plt.savefig(save_filepath) 99 | plt.close() 100 | 101 | 102 | def main(): 103 | parser = argparse.ArgumentParser( 104 | description='Regression with own dataset.') 105 | parser.add_argument('--dirpath', '-d', type=str, default='./results/M_30_3_32_32') 106 | args = parser.parse_args() 107 | path = args.dirpath 108 | n_split = 5 109 | output = [] 110 | answer = [] 111 | for i in range(n_split): 112 | suffix = str(i) + "-" + str(n_split) 113 | output.append(np.load(os.path.join(path, suffix, "output.npy"))) 114 | answer.append(np.load(os.path.join(path, suffix, "answer.npy"))) 115 | output = np.concatenate(output) 116 | answer = np.concatenate(answer) 117 | 118 | plot_result(output, answer, save_filepath=os.path.join(path, "result.png")) 119 | for i in range(n_split): 120 | suffix = str(i) + "-" + str(n_split) 121 | print(suffix) 122 | visualize(os.path.join(path, suffix)) 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | 128 | -------------------------------------------------------------------------------- /experiments/delaney/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | from sklearn.preprocessing import StandardScaler 8 | 9 | import matplotlib as mpl 10 | mpl.use('Agg') 11 | 12 | from chainer import functions as F, cuda, Variable 13 | from chainer import iterators as I 14 | from chainer import optimizers as O 15 | from chainer import training 16 | from chainer.training import extensions as E 17 | from chainer.datasets import SubDataset 18 | from chainer import serializers 19 | 20 | from chainer_chemistry.dataset.converters import concat_mols 21 | from chainer_chemistry.dataset.preprocessors import preprocess_method_dict 22 | from chainer_chemistry.datasets import NumpyTupleDataset 23 | from chainer_chemistry.datasets.molnet import get_molnet_dataset 24 | from chainer_chemistry.models.prediction import Regressor 25 | 26 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 27 | from models import predictor 28 | from saliency.calculator.gradient_calculator import GradientCalculator 29 | from plot import plot_result 30 
| 31 | 32 | def save_result(dataset, model, dir_path, M): 33 | regressor = Regressor(model, lossfun=F.mean_squared_error) 34 | # model.to_cpu() 35 | 36 | def preprocess_fun(*inputs): 37 | atom, adj, t = inputs 38 | # HACKING for now... 39 | atom_embed = regressor.predictor.graph_conv.embed(atom) 40 | return atom_embed, adj, t 41 | 42 | def eval_fun(*inputs): 43 | atom_embed, adj, t = inputs 44 | prob = regressor.predictor(atom_embed, adj) 45 | out = F.sum(prob) 46 | return out 47 | 48 | gradient_calculator = GradientCalculator( 49 | regressor, eval_fun=eval_fun, 50 | target_key=0, multiply_target=True 51 | ) 52 | 53 | def clip_original_size(saliency_, num_atoms_): 54 | """`saliency` array is 0 padded, this method align to have original 55 | molecule's length 56 | """ 57 | assert len(saliency_) == len(num_atoms_) 58 | saliency_list = [] 59 | for i in range(len(saliency_)): 60 | saliency_list.append(saliency_[i, :num_atoms_[i]]) 61 | return saliency_list 62 | 63 | atoms = dataset.features[:, 0] 64 | num_atoms = [len(a) for a in atoms] 65 | 66 | print('calculating saliency... M={}'.format(M)) 67 | # --- VanillaGrad --- 68 | saliency_arrays = gradient_calculator.compute_vanilla( 69 | dataset, converter=concat_mols, preprocess_fn=preprocess_fun) 70 | saliency = gradient_calculator.transform( 71 | saliency_arrays, ch_axis=3, method='raw') 72 | saliency_vanilla = clip_original_size(saliency, num_atoms) 73 | np.save(os.path.join(dir_path, "saliency_vanilla"), saliency_vanilla) 74 | 75 | # --- SmoothGrad --- 76 | saliency_arrays = gradient_calculator.compute_smooth( 77 | dataset, converter=concat_mols, preprocess_fn=preprocess_fun, M=M) 78 | saliency = gradient_calculator.transform( 79 | saliency_arrays, ch_axis=3, method='raw') 80 | saliency_smooth = clip_original_size(saliency, num_atoms) 81 | np.save(os.path.join(dir_path, "saliency_smooth"), saliency_smooth) 82 | 83 | # --- BayesGrad --- 84 | # train=True corresponds to BayesGrad 85 | saliency_arrays = gradient_calculator.compute_vanilla( 86 | dataset, converter=concat_mols, preprocess_fn=preprocess_fun, M=M, 87 | train=True) 88 | saliency = gradient_calculator.transform( 89 | saliency_arrays, ch_axis=3, method='raw', lam=0) 90 | saliency_bayes = clip_original_size(saliency, num_atoms) 91 | np.save(os.path.join(dir_path, "saliency_bayes"), saliency_bayes) 92 | 93 | 94 | def get_dir_path(batchsize, n_unit, conv_layers, M, method): 95 | dir_path = "results/{}_M{}_conv{}_unit{}_b{}".format(method, M, conv_layers, n_unit, batchsize) 96 | dir_path = os.path.join("./", dir_path) 97 | return dir_path 98 | 99 | 100 | def train(gpu, method, epoch, batchsize, n_unit, conv_layers, dataset, smiles, M, n_split, split_idx, order): 101 | n = len(dataset) 102 | assert len(order) == n 103 | left_idx = (n // n_split) * split_idx 104 | is_right_most_split = (n_split == split_idx + 1) 105 | if is_right_most_split: 106 | test_order = order[left_idx:] 107 | train_order = order[:left_idx] 108 | else: 109 | right_idx = (n // n_split) * (split_idx + 1) 110 | test_order = order[left_idx:right_idx] 111 | train_order = np.concatenate([order[:left_idx], order[right_idx:]]) 112 | 113 | new_order = np.concatenate([train_order, test_order]) 114 | n_train = len(train_order) 115 | 116 | # Standard Scaler for labels 117 | ss = StandardScaler() 118 | labels = dataset.get_datasets()[-1] 119 | train_label = labels[new_order[:n_train]] 120 | ss = ss.fit(train_label) # fit only by train 121 | labels = ss.transform(dataset.get_datasets()[-1]) 122 | dataset = 
NumpyTupleDataset(*(dataset.get_datasets()[:-1] + (labels,))) 123 | 124 | dataset_train = SubDataset(dataset, 0, n_train, new_order) 125 | dataset_test = SubDataset(dataset, n_train, n, new_order) 126 | 127 | # Network 128 | model = predictor.build_predictor( 129 | method, n_unit, conv_layers, 1, dropout_ratio=0.25, n_layers=1) 130 | 131 | train_iter = I.SerialIterator(dataset_train, batchsize) 132 | val_iter = I.SerialIterator(dataset_test, batchsize, repeat=False, shuffle=False) 133 | 134 | def scaled_abs_error(x0, x1): 135 | if isinstance(x0, Variable): 136 | x0 = cuda.to_cpu(x0.data) 137 | if isinstance(x1, Variable): 138 | x1 = cuda.to_cpu(x1.data) 139 | scaled_x0 = ss.inverse_transform(cuda.to_cpu(x0)) 140 | scaled_x1 = ss.inverse_transform(cuda.to_cpu(x1)) 141 | diff = scaled_x0 - scaled_x1 142 | return np.mean(np.absolute(diff), axis=0)[0] 143 | 144 | regressor = Regressor( 145 | model, lossfun=F.mean_squared_error, 146 | metrics_fun={'abs_error': scaled_abs_error}, device=gpu) 147 | 148 | optimizer = O.Adam(alpha=0.0005) 149 | optimizer.setup(regressor) 150 | 151 | updater = training.StandardUpdater(train_iter, optimizer, device=gpu, converter=concat_mols) 152 | 153 | dir_path = get_dir_path(batchsize, n_unit, conv_layers, M, method) 154 | dir_path = os.path.join(dir_path, str(split_idx) + "-" + str(n_split)) 155 | os.makedirs(dir_path, exist_ok=True) 156 | print('creating ', dir_path) 157 | np.save(os.path.join(dir_path, "test_idx"), np.array(test_order)) 158 | 159 | trainer = training.Trainer(updater, (epoch, 'epoch'), out=dir_path) 160 | trainer.extend(E.Evaluator(val_iter, regressor, device=gpu, 161 | converter=concat_mols)) 162 | trainer.extend(E.LogReport()) 163 | trainer.extend(E.PrintReport(['epoch', 'main/loss', 'main/abs_error', 164 | 'validation/main/loss', 165 | 'validation/main/abs_error', 166 | 'elapsed_time'])) 167 | trainer.extend(E.ProgressBar()) 168 | trainer.run() 169 | 170 | # --- Plot regression evaluation result --- 171 | dataset_test = SubDataset(dataset, n_train, n, new_order) 172 | batch_all = concat_mols(dataset_test, device=gpu) 173 | serializers.save_npz(os.path.join(dir_path, "model.npz"), model) 174 | result = model(batch_all[0], batch_all[1]) 175 | result = ss.inverse_transform(cuda.to_cpu(result.data)) 176 | answer = ss.inverse_transform(cuda.to_cpu(batch_all[2])) 177 | plot_result(result, answer, save_filepath=os.path.join(dir_path, "result.png")) 178 | 179 | # --- Plot regression evaluation result end --- 180 | np.save(os.path.join(dir_path, "output.npy"), result) 181 | np.save(os.path.join(dir_path, "answer.npy"), answer) 182 | smiles_part = np.array(smiles)[test_order] 183 | np.save(os.path.join(dir_path, "smiles.npy"), smiles_part) 184 | 185 | # calculate saliency and save it. 
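# save_result computes VanillaGrad, SmoothGrad and BayesGrad saliency maps for the whole dataset (BayesGrad averages M dropout samples) and stores them as saliency_*.npy under dir_path.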
186 | save_result(dataset, model, dir_path, M) 187 | 188 | 189 | def main(): 190 | # Supported preprocessing/network list 191 | parser = argparse.ArgumentParser( 192 | description='Regression with own dataset.') 193 | parser.add_argument('--gpu', '-g', type=int, default=-1) 194 | parser.add_argument('--method', type=str, default='nfpdrop', 195 | choices=['nfpdrop', 'ggnndrop', 'nfp', 'ggnn']) 196 | parser.add_argument('--epoch', '-e', type=int, default=20) 197 | parser.add_argument('--seed', '-s', type=int, default=777) 198 | parser.add_argument('--layer', '-n', type=int, default=3) 199 | parser.add_argument('--batchsize', '-b', type=int, default=32) 200 | parser.add_argument('--m', '-m', type=int, default=30) 201 | args = parser.parse_args() 202 | 203 | dataset_name = 'delaney' 204 | # labels = "measured log solubility in mols per litre" 205 | labels = None 206 | 207 | # Dataset preparation 208 | print('Preprocessing dataset...') 209 | method = args.method 210 | if 'nfp' in method: 211 | preprocess_method = 'nfp' 212 | elif 'ggnn' in method: 213 | preprocess_method = 'ggnn' 214 | else: 215 | raise ValueError('Unexpected method', method) 216 | preprocessor = preprocess_method_dict[preprocess_method]() 217 | data = get_molnet_dataset( 218 | dataset_name, preprocessor, labels=labels, return_smiles=True, 219 | frac_train=1.0, frac_valid=0.0, frac_test=0.0) 220 | dataset = data['dataset'][0] 221 | smiles = data['smiles'][0] 222 | 223 | epoch = args.epoch 224 | gpu = args.gpu 225 | 226 | n_unit_list = [32] 227 | random_state = np.random.RandomState(args.seed) 228 | n = len(dataset) 229 | 230 | M = args.m 231 | order = np.arange(n) 232 | random_state.shuffle(order) 233 | batchsize = args.batchsize 234 | for n_unit in n_unit_list: 235 | n_layer = args.layer 236 | n_split = 5 237 | for idx in range(n_split): 238 | print('Start training: ', idx+1, "/", n_split) 239 | dir_path = get_dir_path(batchsize, n_unit, n_layer, M, method) 240 | os.makedirs(dir_path, exist_ok=True) 241 | np.save(os.path.join(dir_path, "smiles.npy"), np.array(smiles)) 242 | train(gpu, method, epoch, batchsize, n_unit, n_layer, dataset, smiles, M, n_split, idx, order) 243 | 244 | 245 | if __name__ == '__main__': 246 | main() 247 | -------------------------------------------------------------------------------- /experiments/tox21/calc_prcauc_with_seeds.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate statistics by seeds. 
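For each seed, the corresponding trained model is loaded and the PRC-AUC of VanillaGrad, SmoothGrad, BayesGrad and BayesSmoothGrad saliency (evaluated against pyridine substructure atoms on the validation set) is computed. Per-seed values and mean/std statistics across seeds are written to CSV, together with an averaged precision-recall plot.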
3 | """ 4 | import argparse 5 | 6 | import matplotlib 7 | import pandas 8 | 9 | matplotlib.use('agg') 10 | import matplotlib.pyplot as plt 11 | 12 | from chainer import functions as F 13 | from chainer import links as L 14 | from chainer import serializers 15 | from chainer_chemistry.dataset.converters import concat_mols 16 | from chainer_chemistry.datasets import NumpyTupleDataset 17 | import numpy as np 18 | from rdkit import RDLogger, Chem 19 | from sklearn.metrics import auc 20 | from tqdm import tqdm 21 | 22 | import sys 23 | import os 24 | import logging 25 | 26 | 27 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 28 | from models import predictor 29 | from saliency.calculator.gradient_calculator import GradientCalculator 30 | 31 | import data 32 | from data import PYRIDINE_SMILES 33 | from plot_precision_recall import calc_recall_precision 34 | 35 | 36 | def calc_prc_auc(saliency, rates, haspiindex): 37 | recall, precision = calc_recall_precision(saliency, rates, haspiindex) 38 | print('recall', recall) 39 | print('precision', precision) 40 | prcauc = auc(recall, precision) 41 | print('prcauc', prcauc) 42 | return recall, precision, prcauc 43 | 44 | 45 | def parse(): 46 | parser = argparse.ArgumentParser( 47 | description='Multitask Learning with Tox21.') 48 | parser.add_argument('--method', '-m', type=str, 49 | default='nfp', help='graph convolution model to use ' 50 | 'as a predictor.') 51 | parser.add_argument('--label', '-l', type=str, 52 | default='pyridine', help='target label for logistic ' 53 | 'regression. Use all labels if this option ' 54 | 'is not specified.') 55 | parser.add_argument('--conv-layers', '-c', type=int, default=4, 56 | help='number of convolution layers') 57 | parser.add_argument('--n-layers', type=int, default=1, 58 | help='number of mlp layers') 59 | parser.add_argument('--batchsize', '-b', type=int, default=128, 60 | help='batch size') 61 | parser.add_argument('--gpu', '-g', type=int, default=-1, 62 | help='GPU ID to use. 
Negative value indicates ' 63 | 'not to use GPU and to run the code in CPU.') 64 | parser.add_argument('--out', '-o', type=str, default='result', 65 | help='path to output directory') 66 | parser.add_argument('--epoch', '-e', type=int, default=10, 67 | help='number of epochs') 68 | parser.add_argument('--unit-num', '-u', type=int, default=16, 69 | help='number of units in one layer of the model') 70 | parser.add_argument('--dropout-ratio', '-d', type=float, default=0.25, 71 | help='dropout_ratio') 72 | parser.add_argument('--seeds', type=int, default=0, 73 | help='number of seed to use for calculation') 74 | parser.add_argument('--num-train', type=int, default=-1, 75 | help='number of training data to be used, ' 76 | 'negative value indicates use all train data') 77 | parser.add_argument('--scale', type=float, default=0.15, help='scale for smoothgrad') 78 | parser.add_argument('--mode', type=str, default='absolute', help='mode for smoothgrad') 79 | args = parser.parse_args() 80 | return args 81 | 82 | 83 | def main(method, labels, unit_num, conv_layers, class_num, n_layers, 84 | dropout_ratio, model_path_list, save_dir_path, scale=0.15, mode='relative', M=5): 85 | # Dataset preparation 86 | train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels) 87 | 88 | # --- model preparation --- 89 | model = predictor.build_predictor( 90 | method, unit_num, conv_layers, class_num, dropout_ratio, n_layers) 91 | 92 | classifier = L.Classifier(model, 93 | lossfun=F.sigmoid_cross_entropy, 94 | accfun=F.binary_accuracy) 95 | 96 | target_dataset = val 97 | target_smiles = val_smiles 98 | 99 | val_mols = [Chem.MolFromSmiles(smi) for smi in tqdm(val_smiles)] 100 | 101 | pi = Chem.MolFromSmarts(PYRIDINE_SMILES) 102 | piindex = np.where(np.array([mol.HasSubstructMatch(pi) for mol in val_mols]) == True) 103 | haspi = np.array(val_mols)[piindex] 104 | 105 | # It only extracts one substructure, not expected behavior 106 | # haspiindex = [set(mol.GetSubstructMatch(pi)) for mol in haspi] 107 | def flatten_tuple(x): 108 | return [element for tupl in x for element in tupl] 109 | haspiindex = [flatten_tuple(mol.GetSubstructMatches(pi)) for mol in haspi] 110 | print('piindex', piindex) 111 | print('haspi', haspi.shape) 112 | print('haspiindex', haspiindex) 113 | print('haspiindex length', [len(k) for k in haspiindex]) 114 | 115 | pyrigine_dataset = NumpyTupleDataset(*target_dataset.features[piindex, :]) 116 | pyrigine_smiles = target_smiles[piindex] 117 | 118 | atoms = pyrigine_dataset.features[:, 0] 119 | num_atoms = [len(a) for a in atoms] 120 | 121 | def clip_original_size(saliency, num_atoms): 122 | """`saliency` array is 0 padded, this method align to have original 123 | molecule's length 124 | """ 125 | assert len(saliency) == len(num_atoms) 126 | saliency_list = [] 127 | for i in range(len(saliency)): 128 | saliency_list.append(saliency[i, :num_atoms[i]]) 129 | return saliency_list 130 | 131 | def preprocess_fun(*inputs): 132 | atom, adj, t = inputs 133 | # HACKING for now... 
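# The discrete atom IDs are embedded here, outside eval_fun, so that the saliency calculator can take gradients with respect to the continuous atom embedding (passed as target_key=0 below).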
134 | # classifier.predictor.pick = True 135 | # result = classifier.predictor(atom, adj) 136 | atom_embed = classifier.predictor.graph_conv.embed(atom) 137 | return atom_embed, adj, t 138 | 139 | def eval_fun(*inputs): 140 | atom_embed, adj, t = inputs 141 | prob = classifier.predictor(atom_embed, adj) 142 | # print('embed', atom_embed.shape, 'prob', prob.shape) 143 | out = F.sum(prob) 144 | # return {'embed': atom_embed, 'out': out} 145 | return out 146 | 147 | gradient_calculator = GradientCalculator( 148 | classifier, eval_fun=eval_fun, 149 | # target_key='embed', eval_key='out', 150 | target_key=0, 151 | ) 152 | 153 | print('M', M) 154 | # rates = np.array(list(range(1, 11))) * 0.1 155 | num = 20 156 | # rates = np.linspace(0, 1, num=num+1)[1:] 157 | rates = np.linspace(0.1, 1, num=num) 158 | print('rates', len(rates), rates) 159 | 160 | fig = plt.figure(figsize=(7, 5), dpi=200) 161 | 162 | precisions_vanilla = [] 163 | precisions_smooth = [] 164 | precisions_bayes = [] 165 | precisions_bayes_smooth = [] 166 | prcauc_vanilla = [] 167 | prcauc_smooth = [] 168 | prcauc_bayes = [] 169 | prcauc_bayes_smooth = [] 170 | 171 | prcauc_diff_smooth_vanilla = [] 172 | prcauc_diff_bayes_vanilla = [] 173 | for model_path in model_path_list: 174 | serializers.load_npz(model_path, model) 175 | 176 | # --- VanillaGrad --- 177 | saliency_arrays = gradient_calculator.compute_vanilla( 178 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun) 179 | saliency = gradient_calculator.transform( 180 | saliency_arrays, ch_axis=3, method='square') 181 | # saliency_arrays (1, 28, 43, 64) -> M, batch_size, max_atom, ch_dim 182 | print('saliency_arrays', saliency_arrays.shape) 183 | # saliency (28, 43) -> batch_size, max_atom 184 | print('saliency', saliency.shape) 185 | saliency_vanilla = clip_original_size(saliency, num_atoms) 186 | 187 | # recall & precision 188 | print('vanilla') 189 | naiverecall, naiveprecision, naiveprcauc = calc_prc_auc(saliency_vanilla, rates, haspiindex) 190 | precisions_vanilla.append(naiveprecision) 191 | prcauc_vanilla.append(naiveprcauc) 192 | 193 | # --- SmoothGrad --- 194 | saliency_arrays = gradient_calculator.compute_smooth( 195 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun, 196 | M=M, scale=scale, mode=mode) 197 | saliency = gradient_calculator.transform( 198 | saliency_arrays, ch_axis=3, method='square') 199 | 200 | saliency_smooth = clip_original_size(saliency, num_atoms) 201 | 202 | # recall & precision 203 | print('smooth') 204 | smoothrecall, smoothprecision, smoothprcauc = calc_prc_auc(saliency_smooth, rates, haspiindex) 205 | precisions_smooth.append(smoothprecision) 206 | prcauc_smooth.append(smoothprcauc) 207 | 208 | # --- BayesGrad --- 209 | saliency_arrays = gradient_calculator.compute_vanilla( 210 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun, train=True, M=M) 211 | saliency = gradient_calculator.transform( 212 | saliency_arrays, ch_axis=3, method='square', lam=0) 213 | saliency_bayes = clip_original_size(saliency, num_atoms) 214 | 215 | bgrecall0, bgprecision0, bayesprcauc = calc_prc_auc(saliency_bayes, rates, haspiindex) 216 | precisions_bayes.append(bgprecision0) 217 | prcauc_bayes.append(bayesprcauc) 218 | prcauc_diff_smooth_vanilla.append(smoothprcauc - naiveprcauc) 219 | prcauc_diff_bayes_vanilla.append(bayesprcauc - naiveprcauc) 220 | 221 | # --- BayesSmoothGrad --- 222 | saliency_arrays = gradient_calculator.compute_smooth( 223 | pyrigine_dataset, converter=concat_mols, 
preprocess_fn=preprocess_fun, 224 | M=M, scale=scale, mode=mode, train=True) 225 | saliency = gradient_calculator.transform( 226 | saliency_arrays, ch_axis=3, method='square') 227 | saliency_bayes_smooth = clip_original_size(saliency, num_atoms) 228 | # recall & precision 229 | print('bayes smooth') 230 | bayes_smoothrecall, bayes_smoothprecision, bayes_smoothprcauc = calc_prc_auc(saliency_bayes_smooth, rates, haspiindex) 231 | precisions_bayes_smooth.append(bayes_smoothprecision) 232 | prcauc_bayes_smooth.append(bayes_smoothprcauc) 233 | 234 | precisions_vanilla = np.array(precisions_vanilla) 235 | precisions_smooth = np.array(precisions_smooth) 236 | precisions_bayes = np.array(precisions_bayes) 237 | precisions_bayes_smooth = np.array(precisions_bayes_smooth) 238 | 239 | df = pandas.DataFrame({ 240 | 'model_path': model_path, 241 | 'prcauc_vanilla': prcauc_vanilla, 242 | 'prcauc_smooth': prcauc_smooth, 243 | 'prcauc_bayes': prcauc_bayes, 244 | 'prcauc_bayes_smooth': prcauc_bayes_smooth, 245 | 'prcauc_diff_smooth_vanilla': prcauc_diff_smooth_vanilla, 246 | 'prcauc_diff_bayes_vanilla': prcauc_diff_bayes_vanilla 247 | }) 248 | save_csv_path = save_dir_path + '/prcauc_{}_{}.csv'.format(mode, scale) 249 | print('save to ', save_csv_path) 250 | df.to_csv(save_csv_path) 251 | 252 | prcauc_vanilla = np.array(prcauc_vanilla) 253 | prcauc_smooth = np.array(prcauc_smooth) 254 | prcauc_bayes = np.array(prcauc_bayes) 255 | prcauc_bayes_smooth = np.array(prcauc_bayes_smooth) 256 | prcauc_diff_smooth_vanilla = np.array(prcauc_diff_smooth_vanilla) 257 | prcauc_diff_bayes_vanilla = np.array(prcauc_diff_bayes_vanilla) 258 | 259 | def show_avg_std(array, tag=''): 260 | print('{}: mean {:8.03}, std {:8.03}' 261 | .format(tag, np.mean(array, axis=0), np.std(array, axis=0))) 262 | return {'method': tag, 'mean': np.mean(array, axis=0), 'std': np.std(array, axis=0)} 263 | 264 | df = pandas.DataFrame([ 265 | show_avg_std(prcauc_vanilla, tag='vanilla'), 266 | show_avg_std(prcauc_smooth, tag='smooth'), 267 | show_avg_std(prcauc_bayes, tag='bayes'), 268 | show_avg_std(prcauc_bayes_smooth, tag='bayes_smooth'), 269 | show_avg_std(prcauc_diff_smooth_vanilla, tag='diff_smooth_vanilla'), 270 | show_avg_std(prcauc_diff_bayes_vanilla, tag='diff_bayes_vanilla'), 271 | ]) 272 | save_csv_path = save_dir_path + '/prcauc_stats_{}_{}.csv'.format(mode, scale) 273 | print('save to ', save_csv_path) 274 | df.to_csv(save_csv_path) 275 | # import IPython; IPython.embed() 276 | 277 | def _plot_with_errorbar(x, precisions, color='blue', alpha=None, label=None): 278 | y = np.mean(precisions, axis=0) 279 | plt.errorbar(x, y, yerr=np.std(precisions, axis=0), fmt='ro', ecolor=color) # fmt='' 280 | plt.plot(x, y, 'k-', color=color, label=label, alpha=alpha) 281 | 282 | alpha = 0.5 283 | _plot_with_errorbar(rates, precisions_vanilla, color='blue', alpha=alpha, label='VanillaGrad') 284 | _plot_with_errorbar(rates, precisions_smooth, color='green', alpha=alpha, label='SmoothGrad') 285 | _plot_with_errorbar(rates, precisions_bayes, color='yellow', alpha=alpha, label='BayesGrad') 286 | _plot_with_errorbar(rates, precisions_bayes_smooth, color='orange', alpha=alpha, label='BayesSmoothGrad') 287 | plt.legend() 288 | plt.xlabel("recall") 289 | plt.ylabel("precision") 290 | save_path = os.path.join(save_dir_path, 'artificial_pr.png') 291 | print('saved to ', save_path) 292 | plt.savefig(save_path) 293 | 294 | 295 | if __name__ == '__main__': 296 | # Disable errors by RDKit occurred in preprocessing Tox21 dataset. 
297 | lg = RDLogger.logger() 298 | lg.setLevel(RDLogger.CRITICAL) 299 | # show INFO level log from chainer chemistry 300 | logging.basicConfig(level=logging.INFO) 301 | 302 | args = parse() 303 | # --- config --- 304 | method = args.method 305 | labels = args.label 306 | unit_num = args.unit_num 307 | conv_layers = args.conv_layers 308 | class_num = 1 309 | n_layers = args.n_layers 310 | dropout_ratio = args.dropout_ratio 311 | num_train = args.num_train 312 | seeds = args.seeds 313 | 314 | root = '.' 315 | model_path_list = [] 316 | for i in range(seeds): 317 | dir_path = '{}/results/{}_{}_numtrain{}_seed{}'.format( 318 | root, method, labels, num_train, i) 319 | model_path = os.path.join(dir_path, 'predictor.npz') 320 | model_path_list.append(model_path) 321 | 322 | save_dir_path = '{}/results/{}_{}_numtrain{}_seed0-{}'.format( 323 | root, method, labels, num_train, seeds-1) 324 | if not os.path.exists(save_dir_path): 325 | os.mkdir(save_dir_path) 326 | 327 | # --- config end --- 328 | 329 | main(method, labels, unit_num, conv_layers, class_num, n_layers, 330 | dropout_ratio, model_path_list, save_dir_path, args.scale, args.mode, M=100) 331 | -------------------------------------------------------------------------------- /experiments/tox21/calc_prcauc_with_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eu 4 | 5 | gpu=${1:--1} 6 | label=pyridine 7 | seeds=30 8 | num_train=1000 9 | unit_num=16 10 | 11 | # mode=relative 12 | mode=absolute 13 | method=ggnndrop 14 | ratio=0.25 15 | 16 | for scale in 0.05 0.10 0.15 0.20; do 17 | python calc_prcauc_with_seeds.py -g ${gpu} --label=${label} --unit-num=${unit_num} --n-layers=1 --dropout-ratio=${ratio} --num-train=${num_train} --seed=${seeds} --method=${method} --scale=${scale} --mode=${mode} 18 | done 19 | -------------------------------------------------------------------------------- /experiments/tox21/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy 4 | from chainer_chemistry.dataset.preprocessors import preprocess_method_dict 5 | from chainer_chemistry import datasets as D 6 | from chainer_chemistry.datasets.numpy_tuple_dataset import NumpyTupleDataset 7 | from rdkit import Chem 8 | from tqdm import tqdm 9 | 10 | import utils 11 | 12 | 13 | class _CacheNamePolicy(object): 14 | 15 | train_file_name = 'train.npz' 16 | val_file_name = 'val.npz' 17 | test_file_name = 'test.npz' 18 | smiles_file_name = 'smiles.npz' 19 | 20 | def _get_cache_directory_path(self, method, labels, prefix): 21 | if labels: 22 | return os.path.join(prefix, '{}_{}'.format(method, labels)) 23 | else: 24 | return os.path.join(prefix, '{}_all'.format(method)) 25 | 26 | def __init__(self, method, labels, prefix='input'): 27 | self.method = method 28 | self.labels = labels 29 | self.prefix = prefix 30 | self.cache_dir = self._get_cache_directory_path(method, labels, prefix) 31 | 32 | def get_train_file_path(self): 33 | return os.path.join(self.cache_dir, self.train_file_name) 34 | 35 | def get_val_file_path(self): 36 | return os.path.join(self.cache_dir, self.val_file_name) 37 | 38 | def get_test_file_path(self): 39 | return os.path.join(self.cache_dir, self.test_file_name) 40 | 41 | def get_smiles_path(self): 42 | return os.path.join(self.cache_dir, self.smiles_file_name) 43 | 44 | def create_cache_directory(self): 45 | try: 46 | os.makedirs(self.cache_dir) 47 | except OSError: 48 | if not os.path.isdir(self.cache_dir): 
49 | raise 50 | 51 | 52 | PYRIDINE_SMILES = 'c1ccncc1' 53 | 54 | 55 | def hassubst(mol, smart=PYRIDINE_SMILES): 56 | return numpy.array(int(mol.HasSubstructMatch(Chem.MolFromSmarts(smart)))).astype('int32') 57 | 58 | 59 | def load_dataset(method, labels, prefix='input'): 60 | method = 'nfp' if 'nfp' in method else method # to deal with nfpdrop 61 | method = 'ggnn' if 'ggnn' in method else method # to deal with ggnndrop 62 | policy = _CacheNamePolicy(method, labels, prefix) 63 | train_path = policy.get_train_file_path() 64 | val_path = policy.get_val_file_path() 65 | test_path = policy.get_test_file_path() 66 | smiles_path = policy.get_smiles_path() 67 | 68 | train, val, test = None, None, None 69 | train_smiles, val_smiles, test_smiles = None, None, None 70 | print() 71 | if os.path.exists(policy.cache_dir): 72 | print('load from cache {}'.format(policy.cache_dir)) 73 | train = NumpyTupleDataset.load(train_path) 74 | val = NumpyTupleDataset.load(val_path) 75 | test = NumpyTupleDataset.load(test_path) 76 | train_smiles, val_smiles, test_smiles = utils.load_npz(smiles_path) 77 | if train is None or val is None or test is None: 78 | print('preprocessing dataset...') 79 | preprocessor = preprocess_method_dict[method]() 80 | if labels == 'pyridine': 81 | train, val, test, train_smiles, val_smiles, test_smiles = D.get_tox21( 82 | preprocessor, labels=None, return_smiles=True) 83 | print('converting label into pyridine...') 84 | # --- Pyridine = 1 --- 85 | train_pyridine_label = [ 86 | hassubst(Chem.MolFromSmiles(smi), smart=PYRIDINE_SMILES) for smi in tqdm(train_smiles)] 87 | val_pyridine_label = [ 88 | hassubst(Chem.MolFromSmiles(smi), smart=PYRIDINE_SMILES) for smi in tqdm(val_smiles)] 89 | test_pyridine_label = [ 90 | hassubst(Chem.MolFromSmiles(smi), smart=PYRIDINE_SMILES) for smi in tqdm(test_smiles)] 91 | 92 | train_pyridine_label = numpy.array(train_pyridine_label)[:, None] 93 | val_pyridine_label = numpy.array(val_pyridine_label)[:, None] 94 | test_pyridine_label = numpy.array(test_pyridine_label)[:, None] 95 | print('train positive/negative', numpy.sum(train_pyridine_label == 1), numpy.sum(train_pyridine_label == 0)) 96 | train = NumpyTupleDataset(*train.features[:, :-1], train_pyridine_label) 97 | val = NumpyTupleDataset(*val.features[:, :-1], val_pyridine_label) 98 | test = NumpyTupleDataset(*test.features[:, :-1], test_pyridine_label) 99 | else: 100 | train, val, test, train_smiles, val_smiles, test_smiles = D.get_tox21( 101 | preprocessor, labels=labels, return_smiles=True) 102 | 103 | # Cache dataset 104 | policy.create_cache_directory() 105 | NumpyTupleDataset.save(train_path, train) 106 | NumpyTupleDataset.save(val_path, val) 107 | NumpyTupleDataset.save(test_path, test) 108 | train_smiles = numpy.array(train_smiles) 109 | val_smiles = numpy.array(val_smiles) 110 | test_smiles = numpy.array(test_smiles) 111 | utils.save_npz(smiles_path, (train_smiles, val_smiles, test_smiles)) 112 | return train, val, test, train_smiles, val_smiles, test_smiles 113 | -------------------------------------------------------------------------------- /experiments/tox21/plot_precision_recall.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import argparse 5 | import json 6 | 7 | try: 8 | import matplotlib 9 | matplotlib.use('agg') 10 | except: 11 | pass 12 | import matplotlib.pyplot as plt 13 | 14 | from chainer import functions as F 15 | from chainer import links as L 16 | from tqdm import tqdm 17 | from chainer 
import serializers 18 | import numpy as np 19 | from rdkit import RDLogger, Chem 20 | 21 | from chainer_chemistry.dataset.converters import concat_mols 22 | from chainer_chemistry.datasets import NumpyTupleDataset 23 | 24 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 25 | from models import predictor 26 | from saliency.calculator.gradient_calculator import GradientCalculator 27 | from saliency.calculator.integrated_gradients_calculator import IntegratedGradientsCalculator 28 | from saliency.calculator.occlusion_calculator import OcclusionCalculator 29 | 30 | import data 31 | from data import PYRIDINE_SMILES, hassubst 32 | 33 | 34 | def percentile_index(ar, num): 35 | """ 36 | ar (numpy.ndarray): array 37 | num (float): rate 38 | 39 | Extract `num` rate of largest index in this array. 40 | """ 41 | threshold = int(len(ar) * num) 42 | idx = np.argsort(ar) 43 | return idx[-threshold:] 44 | 45 | 46 | def calc_recall_precision_for_rate(grads, rate, haspiindex): 47 | recall_list = [] 48 | hit_rate_list = [] 49 | for i in range(len(grads)): 50 | largest_index = percentile_index(grads[i], float(rate)) 51 | set_largest_index = set(largest_index) 52 | hit_index = set_largest_index.intersection(haspiindex[i]) 53 | hit_num = len(hit_index) 54 | hit_rate = float(hit_num) / float(len(set_largest_index)) 55 | 56 | recall_list.append(float(hit_num) / len(haspiindex[i])) 57 | hit_rate_list.append(hit_rate) 58 | recall = np.mean(np.array(recall_list)) 59 | precision = np.mean(np.array(hit_rate_list)) 60 | return recall, precision 61 | 62 | 63 | def calc_recall_precision(grads, rates, haspiindex): 64 | r_list = [] 65 | p_list = [] 66 | for rate in rates: 67 | r, p = calc_recall_precision_for_rate(grads, rate, haspiindex) 68 | r_list.append(r) 69 | p_list.append(p) 70 | return r_list, p_list 71 | 72 | 73 | def parse(): 74 | parser = argparse.ArgumentParser( 75 | description='Multitask Learning with Tox21.') 76 | parser.add_argument('--batchsize', '-b', type=int, default=128, 77 | help='batch size') 78 | parser.add_argument('--gpu', '-g', type=int, default=-1, 79 | help='GPU ID to use. 
Negative value indicates ' 80 | 'not to use GPU and to run the code in CPU.') 81 | parser.add_argument('--dirpath', '-d', type=str, default='results', 82 | help='path to train results directory') 83 | parser.add_argument('--calculator', type=str, default='gradient') 84 | args = parser.parse_args() 85 | return args 86 | 87 | 88 | def main(method, labels, unit_num, conv_layers, class_num, n_layers, 89 | dropout_ratio, model_path, save_path): 90 | # Dataset preparation 91 | train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels) 92 | 93 | # --- model preparation --- 94 | model = predictor.build_predictor( 95 | method, unit_num, conv_layers, class_num, dropout_ratio, n_layers) 96 | 97 | classifier = L.Classifier(model, 98 | lossfun=F.sigmoid_cross_entropy, 99 | accfun=F.binary_accuracy) 100 | 101 | print('Loading model parameter from ', model_path) 102 | serializers.load_npz(model_path, model) 103 | 104 | target_dataset = val 105 | target_smiles = val_smiles 106 | 107 | val_mols = [Chem.MolFromSmiles(smi) for smi in tqdm(val_smiles)] 108 | 109 | pyridine_mol = Chem.MolFromSmarts(PYRIDINE_SMILES) 110 | pyridine_index = np.where(np.array([mol.HasSubstructMatch(pyridine_mol) for mol in val_mols]) == True) 111 | val_pyridine_mols = np.array(val_mols)[pyridine_index] 112 | 113 | # It only extracts one substructure, not expected behavior 114 | # val_pyridine_pos = [set(mol.GetSubstructMatch(pi)) for mol in val_pyridine_mols] 115 | def flatten_tuple(x): 116 | return [element for tupl in x for element in tupl] 117 | 118 | val_pyridine_pos = [flatten_tuple(mol.GetSubstructMatches(pyridine_mol)) for mol in val_pyridine_mols] 119 | 120 | # print('pyridine_index', pyridine_index) 121 | # print('val_pyridine_mols', val_pyridine_mols.shape) 122 | # print('val_pyridine_pos', val_pyridine_pos) 123 | # print('val_pyridine_pos length', [len(k) for k in val_pyridine_pos]) 124 | 125 | pyrigine_dataset = NumpyTupleDataset(*target_dataset.features[pyridine_index, :]) 126 | pyrigine_smiles = target_smiles[pyridine_index] 127 | print('pyrigine_dataset', len(pyrigine_dataset), len(pyrigine_smiles)) 128 | 129 | atoms = pyrigine_dataset.features[:, 0] 130 | num_atoms = [len(a) for a in atoms] 131 | 132 | def clip_original_size(saliency, num_atoms): 133 | """`saliency` array is 0 padded, this method align to have original 134 | molecule's length 135 | """ 136 | assert len(saliency) == len(num_atoms) 137 | saliency_list = [] 138 | for i in range(len(saliency)): 139 | saliency_list.append(saliency[i, :num_atoms[i]]) 140 | return saliency_list 141 | 142 | def preprocess_fun(*inputs): 143 | atom, adj, t = inputs 144 | # HACKING for now... 
145 | atom_embed = classifier.predictor.graph_conv.embed(atom) 146 | return atom_embed, adj, t 147 | 148 | def eval_fun(*inputs): 149 | atom_embed, adj, t = inputs 150 | prob = classifier.predictor(atom_embed, adj) 151 | out = F.sum(prob) 152 | return out 153 | 154 | calculator_method = args.calculator 155 | print('calculator method', calculator_method) 156 | if calculator_method == 'gradient': 157 | # option1: Gradient 158 | calculator = GradientCalculator( 159 | classifier, eval_fun=eval_fun, 160 | # target_key='embed', eval_key='out', 161 | target_key=0, 162 | # multiply_target=True # this will calculate grad * input 163 | ) 164 | elif calculator_method == 'integrated_gradients': 165 | # option2: IntegratedGradients 166 | calculator = IntegratedGradientsCalculator( 167 | classifier, eval_fun=eval_fun, 168 | # target_key='embed', eval_key='out', 169 | target_key=0, steps=10 170 | ) 171 | elif calculator_method == 'occlusion': 172 | # option3: Occlusion 173 | def eval_fun_occlusion(*inputs): 174 | atom_embed, adj, t = inputs 175 | prob = classifier.predictor(atom_embed, adj) 176 | # Do not take sum, instead return batch-wise score 177 | out = F.sigmoid(prob) 178 | return out 179 | calculator = OcclusionCalculator( 180 | classifier, eval_fun=eval_fun_occlusion, 181 | # target_key='embed', eval_key='out', 182 | target_key=0, slide_axis=1 183 | ) 184 | else: 185 | raise ValueError("[ERROR] Unexpected value calculator_method={}".format(calculator_method)) 186 | 187 | M = 100 188 | num = 20 189 | rates = np.linspace(0.1, 1, num=num) 190 | print('M', M) 191 | 192 | # --- VanillaGrad --- 193 | saliency_arrays = calculator.compute_vanilla( 194 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun) 195 | saliency = calculator.transform( 196 | saliency_arrays, ch_axis=3, method='square') 197 | # saliency_arrays -> M, batch_size, max_atom, ch_dim 198 | # print('saliency_arrays', saliency_arrays.shape) 199 | # saliency -> batch_size, max_atom 200 | # print('saliency', saliency.shape) 201 | saliency_vanilla = clip_original_size(saliency, num_atoms) 202 | 203 | # recall & precision 204 | vanilla_recall, vanilla_precision = calc_recall_precision(saliency_vanilla, rates, val_pyridine_pos) 205 | print('vanilla_recall', vanilla_recall) 206 | print('vanilla_precision', vanilla_precision) 207 | 208 | # --- SmoothGrad --- 209 | saliency_arrays = calculator.compute_smooth( 210 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun, 211 | M=M, 212 | mode='absolute', scale=0.15 # previous implementation 213 | # mode='relative', scale=0.05 214 | ) 215 | saliency = calculator.transform( 216 | saliency_arrays, ch_axis=3, method='square') 217 | 218 | saliency_smooth = clip_original_size(saliency, num_atoms) 219 | 220 | # recall & precision 221 | smooth_recall, smooth_precision = calc_recall_precision(saliency_smooth, rates, val_pyridine_pos) 222 | print('smooth_recall', smooth_recall) 223 | print('smooth_precision', smooth_precision) 224 | 225 | # --- BayesGrad --- 226 | # bayes grad is calculated by compute_vanilla with train=True 227 | saliency_arrays = calculator.compute_vanilla( 228 | pyrigine_dataset, converter=concat_mols, preprocess_fn=preprocess_fun, 229 | M=M, train=True) 230 | saliency = calculator.transform( 231 | saliency_arrays, ch_axis=3, method='square', lam=0) 232 | saliency_bayes = clip_original_size(saliency, num_atoms) 233 | 234 | bayes_recall, bayes_precision = calc_recall_precision(saliency_bayes, rates, val_pyridine_pos) 235 | print('bayes_recall', 
bayes_recall) 236 | print('bayes_precision', bayes_precision) 237 | 238 | plt.figure(figsize=(7, 5), dpi=200) 239 | plt.plot(vanilla_recall, vanilla_precision, 'k-', color='blue', label='VanillaGrad') 240 | plt.plot(smooth_recall, smooth_precision, 'k-', color='green', label='SmoothGrad') 241 | plt.plot(bayes_recall, bayes_precision, 'k-', color='red', label='BayesGrad(Ours)') 242 | plt.axhline(y=vanilla_precision[-1], color='gray', linestyle='--') 243 | plt.legend() 244 | plt.xlabel("recall") 245 | plt.ylabel("precision") 246 | if save_path: 247 | print('saved to ', save_path) 248 | plt.savefig(save_path) 249 | # plt.savefig('artificial_pr.eps') 250 | else: 251 | plt.show() 252 | 253 | 254 | if __name__ == '__main__': 255 | # Disable errors by RDKit occurred in preprocessing Tox21 dataset. 256 | lg = RDLogger.logger() 257 | lg.setLevel(RDLogger.CRITICAL) 258 | # show INFO level log from chainer chemistry 259 | logging.basicConfig(level=logging.INFO) 260 | 261 | args = parse() 262 | # --- extracting configs --- 263 | dirpath = args.dirpath 264 | json_path = os.path.join(dirpath, 'args.json') 265 | if not os.path.exists(json_path): 266 | raise ValueError( 267 | 'json_path {} not found! Execute train_tox21.py beforehand.'.format(json)) 268 | with open(json_path, 'r') as f: 269 | train_args = json.load(f) 270 | 271 | method = train_args['method'] 272 | labels = train_args['label'] # 'pyridine' 273 | 274 | unit_num = train_args['unit_num'] 275 | conv_layers = train_args['conv_layers'] 276 | class_num = 1 277 | n_layers = train_args['n_layers'] 278 | dropout_ratio = train_args['dropout_ratio'] 279 | num_train = train_args['num_train'] 280 | # seed = train_args['seed'] 281 | # --- extracting configs end --- 282 | 283 | model_path = os.path.join(dirpath, 'predictor.npz') 284 | save_path = os.path.join( 285 | dirpath, 'precision_recall_{}.png'.format(args.calculator)) 286 | # --- config end --- 287 | 288 | main(method, labels, unit_num, conv_layers, class_num, n_layers, 289 | dropout_ratio, model_path, save_path) 290 | -------------------------------------------------------------------------------- /experiments/tox21/train_few_with_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eu 4 | 5 | gpu=${1:--1} 6 | # seed start to end 7 | start=0 8 | end=30 9 | num_train=1000 10 | unit_num=16 11 | epoch=100 12 | label=pyridine 13 | method=ggnndrop 14 | ratio=0.25 15 | 16 | for ((seed=$start; seed < $end; seed++)); do 17 | python train_tox21.py -g ${gpu} --iterator-type=balanced --label=${label} --method=${method} --epoch=${epoch} --unit-num=${unit_num} --n-layers=1 -b 32 --conv-layers=4 --num-train=${num_train} --seed=${seed} --dropout-ratio=${ratio} --out=results/${method}_${label}_numtrain${num_train}_seed${seed} 18 | done 19 | -------------------------------------------------------------------------------- /experiments/tox21/train_tox21.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Tox21 data training 4 | Additionally, artificial pyridine dataset training is supported. 
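A typical invocation (taken from train_few_with_seeds.sh above; all flags are
defined by the argument parser below, and the output directory name is only an
example):

    python train_tox21.py --gpu=-1 --iterator-type=balanced --label=pyridine \
        --method=ggnndrop --epoch=100 --unit-num=16 --n-layers=1 -b 32 \
        --conv-layers=4 --num-train=1000 --seed=0 --dropout-ratio=0.25 \
        --out=results/ggnndrop_pyridine_numtrain1000_seed0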
5 | """ 6 | 7 | from __future__ import print_function 8 | 9 | import logging 10 | import sys 11 | import os 12 | 13 | try: 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | except ImportError: 17 | pass 18 | 19 | import numpy as np 20 | import argparse 21 | import chainer 22 | from chainer import functions as F 23 | from chainer import iterators as I 24 | from chainer import links as L 25 | from chainer import optimizers as O 26 | from chainer import training 27 | from chainer.training import extensions as E 28 | import json 29 | from rdkit import RDLogger, Chem 30 | 31 | from chainer_chemistry.datasets import NumpyTupleDataset 32 | from chainer_chemistry.dataset.converters import concat_mols 33 | from chainer_chemistry import datasets as D 34 | from chainer_chemistry.iterators.balanced_serial_iterator import BalancedSerialIterator # NOQA 35 | from chainer_chemistry.training.extensions.roc_auc_evaluator import ROCAUCEvaluator 36 | 37 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 38 | from models import predictor 39 | 40 | import data 41 | 42 | 43 | def main(): 44 | # Supported preprocessing/network list 45 | method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'nfpdrop', 'ggnndrop'] 46 | label_names = D.get_tox21_label_names() + ['pyridine'] 47 | iterator_type = ['serial', 'balanced'] 48 | 49 | parser = argparse.ArgumentParser( 50 | description='Multitask Learning with Tox21.') 51 | parser.add_argument('--method', '-m', type=str, choices=method_list, 52 | default='nfp', help='graph convolution model to use ' 53 | 'as a predictor.') 54 | parser.add_argument('--label', '-l', type=str, choices=label_names, 55 | default='', help='target label for logistic ' 56 | 'regression. Use all labels if this option ' 57 | 'is not specified.') 58 | parser.add_argument('--iterator-type', type=str, choices=iterator_type, 59 | default='serial', help='iterator type. If `balanced` ' 60 | 'is specified, data is sampled to take same number of' 61 | 'positive/negative labels during training.') 62 | parser.add_argument('--conv-layers', '-c', type=int, default=4, 63 | help='number of convolution layers') 64 | parser.add_argument('--n-layers', type=int, default=1, 65 | help='number of mlp layers') 66 | parser.add_argument('--batchsize', '-b', type=int, default=32, 67 | help='batch size') 68 | parser.add_argument('--gpu', '-g', type=int, default=-1, 69 | help='GPU ID to use. 
Negative value indicates ' 70 | 'not to use GPU and to run the code in CPU.') 71 | parser.add_argument('--out', '-o', type=str, default='result', 72 | help='path to output directory') 73 | parser.add_argument('--epoch', '-e', type=int, default=10, 74 | help='number of epochs') 75 | parser.add_argument('--unit-num', '-u', type=int, default=16, 76 | help='number of units in one layer of the model') 77 | parser.add_argument('--resume', '-r', type=str, default='', 78 | help='path to a trainer snapshot') 79 | parser.add_argument('--frequency', '-f', type=int, default=-1, 80 | help='Frequency of taking a snapshot') 81 | parser.add_argument('--dropout-ratio', '-d', type=float, default=0.25, 82 | help='dropout_ratio') 83 | parser.add_argument('--seed', type=int, default=0, 84 | help='random seed') 85 | parser.add_argument('--num-train', type=int, default=-1, 86 | help='number of training data to be used, ' 87 | 'negative value indicates use all train data') 88 | args = parser.parse_args() 89 | 90 | method = args.method 91 | if args.label: 92 | labels = args.label 93 | class_num = len(labels) if isinstance(labels, list) else 1 94 | else: 95 | labels = None 96 | class_num = len(label_names) 97 | 98 | # Dataset preparation 99 | train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels) 100 | 101 | num_train = args.num_train # 100 102 | if num_train > 0: 103 | # reduce size of train data 104 | seed = args.seed # 0 105 | np.random.seed(seed) 106 | train_selected_label = np.random.permutation(np.arange(len(train)))[:num_train] 107 | print('num_train', num_train, len(train_selected_label), train_selected_label) 108 | train = NumpyTupleDataset(*train.features[train_selected_label, :]) 109 | # Network 110 | predictor_ = predictor.build_predictor( 111 | method, args.unit_num, args.conv_layers, class_num, args.dropout_ratio, 112 | args.n_layers 113 | ) 114 | 115 | iterator_type = args.iterator_type 116 | if iterator_type == 'serial': 117 | train_iter = I.SerialIterator(train, args.batchsize) 118 | elif iterator_type == 'balanced': 119 | if class_num > 1: 120 | raise ValueError('BalancedSerialIterator can be used with only one' 121 | 'label classification, please specify label to' 122 | 'be predicted by --label option.') 123 | train_iter = BalancedSerialIterator( 124 | train, args.batchsize, train.features[:, -1], ignore_labels=-1) 125 | train_iter.show_label_stats() 126 | else: 127 | raise ValueError('Invalid iterator type {}'.format(iterator_type)) 128 | val_iter = I.SerialIterator(val, args.batchsize, 129 | repeat=False, shuffle=False) 130 | classifier = L.Classifier(predictor_, 131 | lossfun=F.sigmoid_cross_entropy, 132 | accfun=F.binary_accuracy) 133 | if args.gpu >= 0: 134 | chainer.cuda.get_device_from_id(args.gpu).use() 135 | classifier.to_gpu() 136 | 137 | optimizer = O.Adam() 138 | optimizer.setup(classifier) 139 | 140 | updater = training.StandardUpdater( 141 | train_iter, optimizer, device=args.gpu, converter=concat_mols) 142 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) 143 | 144 | trainer.extend(E.Evaluator(val_iter, classifier, 145 | device=args.gpu, converter=concat_mols)) 146 | trainer.extend(E.LogReport()) 147 | 148 | # --- ROCAUC Evaluator --- 149 | train_eval_iter = I.SerialIterator(train, args.batchsize, 150 | repeat=False, shuffle=False) 151 | trainer.extend(ROCAUCEvaluator( 152 | train_eval_iter, classifier, eval_func=predictor_, 153 | device=args.gpu, converter=concat_mols, name='train')) 154 | 
trainer.extend(ROCAUCEvaluator( 155 | val_iter, classifier, eval_func=predictor_, 156 | device=args.gpu, converter=concat_mols, name='val')) 157 | trainer.extend(E.PrintReport([ 158 | 'epoch', 'main/loss', 'main/accuracy', 'train/main/roc_auc', 159 | 'validation/main/loss', 'validation/main/accuracy', 160 | 'val/main/roc_auc', 'elapsed_time'])) 161 | 162 | trainer.extend(E.ProgressBar(update_interval=10)) 163 | if args.resume: 164 | chainer.serializers.load_npz(args.resume, trainer) 165 | 166 | trainer.run() 167 | 168 | with open(os.path.join(args.out, 'args.json'), 'w') as f: 169 | json.dump(vars(args), f, indent=4) 170 | chainer.serializers.save_npz( 171 | os.path.join(args.out, 'predictor.npz'), predictor_) 172 | 173 | 174 | if __name__ == '__main__': 175 | # Disable errors by RDKit occurred in preprocessing Tox21 dataset. 176 | lg = RDLogger.logger() 177 | lg.setLevel(RDLogger.CRITICAL) 178 | # show INFO level log from chainer chemistry 179 | logging.basicConfig(level=logging.INFO) 180 | 181 | main() 182 | -------------------------------------------------------------------------------- /experiments/tox21/utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | 4 | def save_npz(filepath, datasets): 5 | if not isinstance(datasets, (list, tuple)): 6 | datasets = (datasets, ) 7 | numpy.savez(filepath, *datasets) 8 | 9 | 10 | def load_npz(filepath): 11 | load_data = numpy.load(filepath) 12 | result = [] 13 | i = 0 14 | while True: 15 | key = 'arr_{}'.format(i) 16 | if key in load_data.keys(): 17 | result.append(load_data[key]) 18 | i += 1 19 | else: 20 | break 21 | if len(result) == 1: 22 | result = result[0] 23 | return result 24 | -------------------------------------------------------------------------------- /experiments/tox21/visualize-saliency-tox21.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "#from load import *\n", 12 | "%load_ext autoreload\n", 13 | "%autoreload 2\n", 14 | "import pickle\n", 15 | "from chainer.datasets import TupleDataset\n", 16 | "from chainer.dataset import concat_examples\n", 17 | "from chainer import functions as F, cuda\n", 18 | "from chainer import iterators as I\n", 19 | "from chainer import links as L\n", 20 | "from chainer import optimizers as O\n", 21 | "from chainer import training\n", 22 | "from ipywidgets import interact\n", 23 | "import chainer\n", 24 | "import cupy\n", 25 | "from tqdm import tqdm_notebook as tqdm\n", 26 | "from chainer import serializers\n", 27 | "\n", 28 | "\n", 29 | "from chainer_chemistry.datasets import NumpyTupleDataset\n", 30 | "from chainer_chemistry.dataset.converters import concat_mols\n", 31 | "\n", 32 | "import numpy\n", 33 | "import numpy as np\n", 34 | "\n", 35 | "from rdkit import RDLogger\n", 36 | "from rdkit import Chem\n", 37 | "from rdkit.Chem import rdchem\n", 38 | "from rdkit.Chem import rdDepictor\n", 39 | "from rdkit.Chem.Draw import rdMolDraw2D\n", 40 | "from rdkit.Chem.Draw import IPythonConsole\n", 41 | "from IPython.display import SVG\n", 42 | "from rdkit.Chem import Draw\n", 43 | "from rdkit import rdBase" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "import sys\n", 55 | "import os\n", 56 | "\n", 57 | 
"sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath('__file__')))))\n", 58 | "from saliency.calculator.gradient_calculator import GradientCalculator\n", 59 | "from models import predictor\n", 60 | "\n", 61 | "sys.path.append(os.path.dirname(os.path.abspath('__file__')))\n", 62 | "import data\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": { 69 | "collapsed": false 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "import logging\n", 74 | "\n", 75 | "# Disable errors by RDKit occurred in preprocessing Tox21 dataset.\n", 76 | "lg = RDLogger.logger()\n", 77 | "lg.setLevel(RDLogger.CRITICAL)\n", 78 | "# show INFO level log from chainer chemistry\n", 79 | "logging.basicConfig(level=logging.INFO)\n" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": { 86 | "collapsed": false 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "import json\n", 91 | "\n", 92 | "# training result directory\n", 93 | "dirpath = './results/nfpdrop_srmmp'\n", 94 | "\n", 95 | "json_path = os.path.join(dirpath, 'args.json')\n", 96 | "if not os.path.exists(json_path):\n", 97 | " raise ValueError(\n", 98 | " 'json_path {} not found! Execute train_tox21.py beforehand.'.format(json_path))\n", 99 | "with open(json_path, 'r') as f:\n", 100 | " train_args = json.load(f)\n", 101 | "\n", 102 | "method = train_args['method']\n", 103 | "labels = train_args['label'] # 'pyridine'\n", 104 | "\n", 105 | "unit_num = train_args['unit_num']\n", 106 | "conv_layers = train_args['conv_layers']\n", 107 | "class_num = 1\n", 108 | "n_layers = train_args['n_layers']\n", 109 | "dropout_ratio = train_args['dropout_ratio']\n", 110 | "num_train = train_args['num_train']\n", 111 | "\n", 112 | "model_path = os.path.join(dirpath, 'predictor.npz')\n" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 5, 118 | "metadata": { 119 | "collapsed": false 120 | }, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "\n", 127 | "load from cache input/nfp_SR-MMP\n", 128 | "\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "# Pyridine Dataset preparation\n", 134 | "train, val, test, train_smiles, val_smiles, test_smiles = data.load_dataset(method, labels)\n", 135 | "val_mols = [Chem.MolFromSmiles(smi) for smi in tqdm(val_smiles)]" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 6, 141 | "metadata": { 142 | "collapsed": false 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "visualize_only_positive = True\n", 147 | "if visualize_only_positive:\n", 148 | " # visualize all only label=1\n", 149 | " pos_index = val.features[:, -1][:, 0] == 1\n", 150 | " # print('pos_index', pos_index.shape, pos_index)\n", 151 | " target_dataset = NumpyTupleDataset(*val.features[pos_index, :])\n", 152 | " target_smiles = val_smiles[pos_index]\n", 153 | "else:\n", 154 | " # visualize all validation data\n", 155 | " target_dataset = val\n", 156 | " target_smiles = val_smiles" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 7, 162 | "metadata": { 163 | "collapsed": false 164 | }, 165 | "outputs": [ 166 | { 167 | "name": "stdout", 168 | "output_type": "stream", 169 | "text": [ 170 | "dropout_ratio, n_layers 0.25 1\n", 171 | "Use NFPDrop predictor...\n" 172 | ] 173 | } 174 | ], 175 | "source": [ 176 | "# --- model preparation ---\n", 177 | "\n", 178 | "model = predictor.build_predictor(\n", 179 | " method, unit_num, conv_layers, 
class_num, dropout_ratio, n_layers)\n", 180 | "\n", 181 | "classifier = L.Classifier(model,\n", 182 | " lossfun=F.sigmoid_cross_entropy,\n", 183 | " accfun=F.binary_accuracy)\n", 184 | "\n", 185 | "serializers.load_npz(model_path, model)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 8, 191 | "metadata": { 192 | "collapsed": false 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "def clip_original_size(saliency, num_atoms):\n", 197 | " \"\"\"`saliency` array is 0 padded, this method align to have original\n", 198 | " molecule's length\n", 199 | " \"\"\"\n", 200 | " assert len(saliency) == len(num_atoms)\n", 201 | " saliency_list = []\n", 202 | " for i in range(len(saliency)):\n", 203 | " saliency_list.append(saliency[i, :num_atoms[i]])\n", 204 | " return saliency_list\n", 205 | "\n", 206 | "def preprocess_fun(*inputs):\n", 207 | " if len(inputs) == 3:\n", 208 | " atom, adj, t = inputs\n", 209 | " elif len(inputs) == 2:\n", 210 | " atom, adj = inputs\n", 211 | " # HACKING for now...\n", 212 | " # classifier.predictor.pick = True\n", 213 | " # result = classifier.predictor(atom, adj)\n", 214 | " atom_embed = classifier.predictor.graph_conv.embed(atom)\n", 215 | " if len(inputs) == 3:\n", 216 | " return atom_embed, adj, t\n", 217 | " elif len(inputs) == 2:\n", 218 | " return atom_embed, adj\n", 219 | " \n", 220 | "\n", 221 | "def eval_fun(*inputs):\n", 222 | " #atom_embed, adj, t = inputs\n", 223 | " if len(inputs) == 3:\n", 224 | " atom_embed, adj, t = inputs\n", 225 | " elif len(inputs) == 2:\n", 226 | " atom_embed, adj = inputs\n", 227 | " prob = classifier.predictor(atom_embed, adj)\n", 228 | " # print('embed', atom_embed.shape, 'prob', prob.shape)\n", 229 | " out = F.sum(prob)\n", 230 | " # return {'embed': atom_embed, 'out': out}\n", 231 | " return out\n" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "## Saliency calculation & visualization\n", 239 | "\n", 240 | "You can calculate saliency and visualize it in following steps:\n", 241 | "\n", 242 | "1. Instantiate Saliency Calculator\n", 243 | "2. `compute_xxx` (xxx is vanilla/smooth/bayes) method to calculate saliency (gradient etc)\n", 244 | "3. `transform` method to convert saliency into visualizable format.\n", 245 | "4. Use Visualizer class to visualize it." 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 9, 251 | "metadata": { 252 | "collapsed": false 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "# 1. instantiation\n", 257 | "gradient_calculator = GradientCalculator(\n", 258 | " classifier, eval_fun=eval_fun,\n", 259 | " # target_key='embed', eval_key='out',\n", 260 | " target_key=0,\n", 261 | ")" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 10, 267 | "metadata": { 268 | "collapsed": false 269 | }, 270 | "outputs": [ 271 | { 272 | "name": "stdout", 273 | "output_type": "stream", 274 | "text": [ 275 | "saliency_arrays (1, 38, 41, 16)\n", 276 | "saliency (38, 41)\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "M = 40\n", 282 | "rates = np.array(list(range(1, 11))) * 0.1\n", 283 | "\n", 284 | "# --- VanillaGrad ---\n", 285 | "# 2. compute\n", 286 | "saliency_arrays_vanilla = gradient_calculator.compute_vanilla(\n", 287 | " target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun)\n", 288 | "# 3. 
transform\n", 289 | "saliency_vanilla = gradient_calculator.transform(\n", 290 | " saliency_arrays_vanilla, ch_axis=3, method='square')\n", 291 | "# saliency_arrays (1, 28, 43, 64) -> M, minibatch, max_atom, ch_dim\n", 292 | "print('saliency_arrays', saliency_arrays_vanilla.shape)\n", 293 | "# saliency (28, 43) -> minibatch, max_atom\n", 294 | "print('saliency', saliency_vanilla.shape)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 11, 300 | "metadata": { 301 | "collapsed": false 302 | }, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "image/svg+xml": [ 307 | "\n", 308 | " \n", 309 | "\n", 310 | "\n", 311 | "\n", 312 | "\n", 313 | "\n", 314 | "\n", 315 | "\n", 316 | "\n", 317 | "\n", 318 | "\n", 319 | "\n", 320 | "\n", 321 | "\n", 322 | "\n", 323 | "\n", 324 | "\n", 325 | "\n", 326 | "\n", 327 | "\n", 328 | "\n", 329 | "\n", 330 | "\n", 331 | "\n", 332 | "\n", 333 | "\n", 334 | "\n", 335 | "\n", 336 | "\n", 337 | "\n", 338 | "\n", 339 | "\n", 340 | "\n", 341 | "\n", 342 | "\n", 343 | "\n", 344 | "\n", 345 | "\n", 346 | "\n", 347 | "\n", 348 | "\n", 349 | "\n", 350 | "\n", 351 | "\n", 352 | "\n", 353 | "\n", 354 | "\n", 355 | "\n", 356 | "\n", 357 | "\n", 358 | "\n", 359 | "\n", 360 | "\n", 361 | "\n", 362 | "\n", 363 | "\n", 364 | "OH\n", 365 | "OH\n", 366 | "" 367 | ], 368 | "text/plain": [ 369 | "" 370 | ] 371 | }, 372 | "execution_count": 11, 373 | "metadata": {}, 374 | "output_type": "execute_result" 375 | } 376 | ], 377 | "source": [ 378 | "from saliency.visualizer.smiles_visualizer import SmilesVisualizer\n", 379 | "\n", 380 | "sv = SmilesVisualizer()\n", 381 | "# 4. visualize\n", 382 | "index = 2\n", 383 | "sv.visualize(saliency_vanilla[index], target_smiles[index], visualize_ratio=0.3)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 12, 389 | "metadata": { 390 | "collapsed": false 391 | }, 392 | "outputs": [ 393 | { 394 | "data": { 395 | "image/svg+xml": [ 396 | "\n", 397 | " \n", 398 | "\n", 399 | "\n", 400 | "\n", 401 | "\n", 402 | "\n", 403 | "\n", 404 | "\n", 405 | "\n", 406 | "\n", 407 | "\n", 408 | "\n", 409 | "\n", 410 | "\n", 411 | "\n", 412 | "\n", 413 | "\n", 414 | "\n", 415 | "\n", 416 | "\n", 417 | "\n", 418 | "\n", 419 | "\n", 420 | "\n", 421 | "\n", 422 | "\n", 423 | "\n", 424 | "\n", 425 | "\n", 426 | "\n", 427 | "\n", 428 | "\n", 429 | "\n", 430 | "\n", 431 | "\n", 432 | "\n", 433 | "\n", 434 | "\n", 435 | "\n", 436 | "\n", 437 | "\n", 438 | "\n", 439 | "\n", 440 | "\n", 441 | "\n", 442 | "\n", 443 | "\n", 444 | "\n", 445 | "\n", 446 | "\n", 447 | "\n", 448 | "\n", 449 | "\n", 450 | "\n", 451 | "\n", 452 | "\n", 453 | "\n", 454 | "\n", 455 | "\n", 456 | "\n", 457 | "\n", 458 | "\n", 459 | "\n", 460 | "\n", 461 | "\n", 462 | "\n", 463 | "\n", 464 | "\n", 465 | "\n", 466 | "\n", 467 | "\n", 468 | "\n", 469 | "\n", 470 | "\n", 471 | "\n", 472 | "\n", 473 | "\n", 474 | "\n", 475 | "\n", 476 | "\n", 477 | "\n", 478 | "\n", 479 | "\n", 480 | "\n", 481 | "\n", 482 | "\n", 483 | "\n", 484 | "\n", 485 | "\n", 486 | "\n", 487 | "\n", 488 | "\n", 489 | "\n", 490 | "\n", 491 | "\n", 492 | "\n", 493 | "\n", 494 | "\n", 495 | "\n", 496 | "\n", 497 | "\n", 498 | "\n", 499 | "\n", 500 | "\n", 501 | "\n", 502 | "\n", 503 | "\n", 504 | "\n", 505 | "\n", 506 | "\n", 507 | "\n", 508 | "\n", 509 | "\n", 510 | "\n", 511 | "\n", 512 | "N\n", 513 | "N\n", 514 | "OH\n", 515 | "N\n", 516 | "N\n", 517 | "" 518 | ], 519 | "text/plain": [ 520 | "" 521 | ] 522 | }, 523 | "metadata": {}, 524 | "output_type": "display_data" 525 | }, 
526 | { 527 | "data": { 528 | "text/plain": [ 529 | "" 530 | ] 531 | }, 532 | "execution_count": 12, 533 | "metadata": {}, 534 | "output_type": "execute_result" 535 | } 536 | ], 537 | "source": [ 538 | "# interactive plot demo\n", 539 | "\n", 540 | "def sv_visualize(i, ratio):\n", 541 | " return sv.visualize(saliency_vanilla[i], target_smiles[i], visualize_ratio=ratio)\n", 542 | "\n", 543 | "interact(sv_visualize, i=(0, len(saliency_vanilla) - 1, 1), ratio=(0, 1.0, 0.1))" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": 19, 549 | "metadata": { 550 | "collapsed": true 551 | }, 552 | "outputs": [], 553 | "source": [ 554 | "os.makedirs('results/visualize/srmmp', exist_ok=True)" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 20, 560 | "metadata": { 561 | "collapsed": false 562 | }, 563 | "outputs": [], 564 | "source": [ 565 | "# --- SmoothGrad ---\n", 566 | "saliency_arrays_smooth = gradient_calculator.compute_smooth(\n", 567 | " target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,\n", 568 | " M=M, mode='absolute', scale=0.15)" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 21, 574 | "metadata": { 575 | "collapsed": false 576 | }, 577 | "outputs": [], 578 | "source": [ 579 | "# --- BayesGrad ---\n", 580 | "# `train=True` enables dropout, which corresponds to BayesGrad\n", 581 | "saliency_arrays_bayes = gradient_calculator.compute_vanilla(\n", 582 | " target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,\n", 583 | " M=M, train=True)" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 22, 589 | "metadata": { 590 | "collapsed": false 591 | }, 592 | "outputs": [ 593 | { 594 | "name": "stdout", 595 | "output_type": "stream", 596 | "text": [ 597 | "SR-MMP: 1\n" 598 | ] 599 | }, 600 | { 601 | "data": { 602 | "image/svg+xml": [ 603 | "\n", 604 | " \n", 605 | "\n", 606 | "\n", 607 | "\n", 608 | "\n", 609 | "\n", 610 | "\n", 611 | "\n", 612 | "\n", 613 | "\n", 614 | "\n", 615 | "\n", 616 | "\n", 617 | "\n", 618 | "\n", 619 | "\n", 620 | "\n", 621 | "\n", 622 | "\n", 623 | "\n", 624 | "\n", 625 | "\n", 626 | "\n", 627 | "\n", 628 | "\n", 629 | "\n", 630 | "\n", 631 | "\n", 632 | "\n", 633 | "\n", 634 | "\n", 635 | "\n", 636 | "\n", 637 | "\n", 638 | "\n", 639 | "\n", 640 | "\n", 641 | "\n", 642 | "\n", 643 | "\n", 644 | "\n", 645 | "\n", 646 | "\n", 647 | "\n", 648 | "\n", 649 | "\n", 650 | "\n", 651 | "\n", 652 | "\n", 653 | "\n", 654 | "\n", 655 | "\n", 656 | "\n", 657 | "\n", 658 | "\n", 659 | "\n", 660 | "\n", 661 | "\n", 662 | "\n", 663 | "\n", 664 | "\n", 665 | "\n", 666 | "\n", 667 | "\n", 668 | "\n", 669 | "\n", 670 | "\n", 671 | "\n", 672 | "\n", 673 | "\n", 674 | "\n", 675 | "\n", 676 | "\n", 677 | "\n", 678 | "\n", 679 | "\n", 680 | "\n", 681 | "\n", 682 | "\n", 683 | "\n", 684 | "\n", 685 | "\n", 686 | "\n", 687 | "\n", 688 | "\n", 689 | "\n", 690 | "\n", 691 | "\n", 692 | "\n", 693 | "\n", 694 | "\n", 695 | "\n", 696 | "\n", 697 | "\n", 698 | "\n", 699 | "\n", 700 | "\n", 701 | "\n", 702 | "\n", 703 | "\n", 704 | "\n", 705 | "\n", 706 | "\n", 707 | "\n", 708 | "\n", 709 | "\n", 710 | "\n", 711 | "\n", 712 | "\n", 713 | "\n", 714 | "\n", 715 | "\n", 716 | "\n", 717 | "\n", 718 | "\n", 719 | "N\n", 720 | "N\n", 721 | "OH\n", 722 | "N\n", 723 | "N\n", 724 | "" 725 | ], 726 | "text/plain": [ 727 | "" 728 | ] 729 | }, 730 | "metadata": {}, 731 | "output_type": "display_data" 732 | } 733 | ], 734 | "source": [ 735 | "# Single plot demo\n", 736 | "from 
IPython.display import display, HTML\n", 737 | "\n", 738 | "def sv_visualize(i, ratio, lam, method, view):\n", 739 | " print('SR-MMP: ', target_dataset.features[i, -1][0])\n", 740 | " saliency_bayes = gradient_calculator.transform(\n", 741 | " saliency_arrays_bayes, ch_axis=3, method=method, lam=lam)\n", 742 | " \n", 743 | " if view == 'view':\n", 744 | " svg_bayes = sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio)\n", 745 | " # display(svg_vanilla, svg_smooth, svg_bayes)\n", 746 | " display(svg_bayes)\n", 747 | " elif view == 'save':\n", 748 | " os.makedirs('results/visualize', exist_ok=True)\n", 749 | " sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_bayes.png'.format(i))\n", 750 | " else:\n", 751 | " print(view, 'not supported')\n", 752 | "\n", 753 | "interact(sv_visualize, i=(0, len(target_dataset) - 1, 1), ratio=(0, 1.0, 0.1), lam=(-3.0, 3.1, 0.1), method=['square', 'abs', 'raw'], view=['view', 'save'])" 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": 25, 759 | "metadata": { 760 | "collapsed": false 761 | }, 762 | "outputs": [ 763 | { 764 | "name": "stdout", 765 | "output_type": "stream", 766 | "text": [ 767 | "SR-MMP: 1\n" 768 | ] 769 | }, 770 | { 771 | "data": { 772 | "image/svg+xml": [ 773 | "\n", 774 | " \n", 775 | "\n", 776 | "\n", 777 | "\n", 778 | "\n", 779 | "\n", 780 | "\n", 781 | "\n", 782 | "\n", 783 | "\n", 784 | "\n", 785 | "\n", 786 | "\n", 787 | "\n", 788 | "\n", 789 | "\n", 790 | "\n", 791 | "\n", 792 | "\n", 793 | "\n", 794 | "\n", 795 | "\n", 796 | "\n", 797 | "\n", 798 | "\n", 799 | "\n", 800 | "\n", 801 | "\n", 802 | "\n", 803 | "\n", 804 | "\n", 805 | "\n", 806 | "\n", 807 | "\n", 808 | "\n", 809 | "\n", 810 | "\n", 811 | "\n", 812 | "\n", 813 | "\n", 814 | "\n", 815 | "\n", 816 | "\n", 817 | "\n", 818 | "\n", 819 | "\n", 820 | "\n", 821 | "\n", 822 | "\n", 823 | "\n", 824 | "\n", 825 | "\n", 826 | "\n", 827 | "\n", 828 | "\n", 829 | "\n", 830 | "\n", 831 | "\n", 832 | "\n", 833 | "\n", 834 | "\n", 835 | "\n", 836 | "\n", 837 | "\n", 838 | "\n", 839 | "\n", 840 | "\n", 841 | "\n", 842 | "\n", 843 | "\n", 844 | "\n", 845 | "\n", 846 | "\n", 847 | "\n", 848 | "\n", 849 | "\n", 850 | "\n", 851 | "\n", 852 | "\n", 853 | "\n", 854 | "\n", 855 | "\n", 856 | "\n", 857 | "\n", 858 | "\n", 859 | "\n", 860 | "\n", 861 | "\n", 862 | "\n", 863 | "\n", 864 | "\n", 865 | "\n", 866 | "\n", 867 | "\n", 868 | "\n", 869 | "\n", 870 | "\n", 871 | "O\n", 872 | "N+\n", 873 | "O-\n", 874 | "N\n", 875 | "F\n", 876 | "N\n", 877 | "NH\n", 878 | "" 879 | ], 880 | "text/plain": [ 881 | "" 882 | ] 883 | }, 884 | "metadata": {}, 885 | "output_type": "display_data" 886 | } 887 | ], 888 | "source": [ 889 | "# Multiple plot demo\n", 890 | "from IPython.display import display, HTML\n", 891 | "\n", 892 | "def sv_visualize(i, ratio, lam, method, view):\n", 893 | " print('SR-MMP: ', target_dataset.features[i, -1][0])\n", 894 | " saliency_vanilla = gradient_calculator.transform(\n", 895 | " saliency_arrays_vanilla, ch_axis=3, method=method, lam=0)\n", 896 | " saliency_smooth = gradient_calculator.transform(\n", 897 | " saliency_arrays_smooth, ch_axis=3, method=method, lam=lam)\n", 898 | " saliency_bayes = gradient_calculator.transform(\n", 899 | " saliency_arrays_bayes, ch_axis=3, method=method, lam=lam)\n", 900 | " \n", 901 | " if view == 'view':\n", 902 | " svg_vanilla = sv.visualize(saliency_vanilla[i], target_smiles[i], visualize_ratio=ratio)\n", 903 | " svg_smooth = 
sv.visualize(saliency_smooth[i], target_smiles[i], visualize_ratio=ratio)\n", 904 | " svg_bayes = sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio)\n", 905 | " # display(svg_vanilla, svg_smooth, svg_bayes)\n", 906 | " display(svg_bayes)\n", 907 | " elif view == 'save':\n", 908 | " os.makedirs('results/visualize', exist_ok=True)\n", 909 | " sv.visualize(saliency_vanilla[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_vanilla.png'.format(i))\n", 910 | " sv.visualize(saliency_smooth[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_smooth.png'.format(i))\n", 911 | " sv.visualize(saliency_bayes[i], target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/{}_bayes.png'.format(i))\n", 912 | " else:\n", 913 | " print(view, 'not supported')\n", 914 | "\n", 915 | "interact(sv_visualize, i=(0, len(target_dataset) - 1, 1), ratio=(0, 1.0, 0.1), lam=(-3.0, 3.1, 0.1), method=['square', 'abs', 'raw'], view=['view', 'save'])" 916 | ] 917 | }, 918 | { 919 | "cell_type": "code", 920 | "execution_count": null, 921 | "metadata": { 922 | "collapsed": true 923 | }, 924 | "outputs": [], 925 | "source": [] 926 | }, 927 | { 928 | "cell_type": "markdown", 929 | "metadata": {}, 930 | "source": [ 931 | "## Visualization of specific molecules, including Tyrphostin 9" 932 | ] 933 | }, 934 | { 935 | "cell_type": "code", 936 | "execution_count": 17, 937 | "metadata": { 938 | "collapsed": false 939 | }, 940 | "outputs": [ 941 | { 942 | "name": "stderr", 943 | "output_type": "stream", 944 | "text": [ 945 | "100%|██████████| 4/4 [00:00<00:00, 378.38it/s]\n", 946 | "INFO:chainer_chemistry.dataset.parsers.csv_file_parser:Preprocess finished. FAIL 0, SUCCESS 4, TOTAL 4\n" 947 | ] 948 | } 949 | ], 950 | "source": [ 951 | "import tempfile\n", 952 | "import pandas\n", 953 | "\n", 954 | "from chainer_chemistry.dataset.parsers import CSVFileParser\n", 955 | "from chainer_chemistry.dataset.preprocessors import preprocess_method_dict\n", 956 | "\n", 957 | "smiles_list = [\n", 958 | " 'CC(C)(C=O)Cc1cc(C(C)(C)C)c(O)c(C(C)(C)C)c1', # (a)\n", 959 | "# 'CC(C)(C=O)Cc1cc(C(C)(C)C)c(C)c(C(C)(C)C)c1',\n", 960 | " 'CC(C)(C)C1=C(C)C(C(C)(C)C)=CC(/C=C(C#N)\\C#N)=C1', # (b)\n", 961 | " 'CC(C)(C)C1=C(O)C(C(C)(C)C)=CC(/C=C(C#N)\\C#N)=C1', # (c)\n", 962 | " 'O=C1Nc2ccc(I)cc2C1=Cc1cc(Br)c(O)c(Br)c1' # (d) \n", 963 | "]\n", 964 | "\n", 965 | "preprocessor = preprocess_method_dict['nfp']()\n", 966 | "parser = CSVFileParser(preprocessor,\n", 967 | " labels=None, smiles_col='smiles')\n", 968 | "\n", 969 | "with tempfile.TemporaryDirectory() as dirpath:\n", 970 | " csv_path = os.path.join(dirpath, 'tmp.csv')\n", 971 | " df = pandas.DataFrame({'smiles': smiles_list})\n", 972 | " df.to_csv(csv_path)\n", 973 | " result = parser.parse(csv_path, return_smiles=True)\n", 974 | "\n", 975 | "custom_target_dataset, custom_target_smiles = result['dataset'], result['smiles']" 976 | ] 977 | }, 978 | { 979 | "cell_type": "code", 980 | "execution_count": 26, 981 | "metadata": { 982 | "collapsed": false 983 | }, 984 | "outputs": [ 985 | { 986 | "data": { 987 | "image/svg+xml": [ 988 | "\n", 989 | " \n", 990 | "\n", 991 | "\n", 992 | "\n", 993 | "\n", 994 | "\n", 995 | "\n", 996 | "\n", 997 | "\n", 998 | "\n", 999 | "\n", 1000 | "\n", 1001 | "\n", 1002 | "\n", 1003 | "\n", 1004 | "\n", 1005 | "\n", 1006 | "\n", 1007 | "\n", 1008 | "\n", 1009 | "\n", 1010 | "\n", 1011 | "\n", 1012 | "\n", 1013 | "\n", 1014 | "\n", 1015 | "\n", 1016 | "\n", 1017 | "\n", 
1018 | "\n", 1019 | "\n", 1020 | "\n", 1021 | "\n", 1022 | "\n", 1023 | "\n", 1024 | "\n", 1025 | "\n", 1026 | "\n", 1027 | "\n", 1028 | "\n", 1029 | "\n", 1030 | "\n", 1031 | "\n", 1032 | "\n", 1033 | "\n", 1034 | "\n", 1035 | "\n", 1036 | "\n", 1037 | "\n", 1038 | "\n", 1039 | "\n", 1040 | "\n", 1041 | "\n", 1042 | "\n", 1043 | "\n", 1044 | "\n", 1045 | "\n", 1046 | "\n", 1047 | "\n", 1048 | "O\n", 1049 | "OH\n", 1050 | "" 1051 | ], 1052 | "text/plain": [ 1053 | "" 1054 | ] 1055 | }, 1056 | "metadata": {}, 1057 | "output_type": "display_data" 1058 | } 1059 | ], 1060 | "source": [ 1061 | "# --- BayesGrad ---\n", 1062 | "# `train=True` enables dropout, which corresponds to BayesGrad\n", 1063 | "saliency_arrays_bayes = gradient_calculator.compute_vanilla(\n", 1064 | " custom_target_dataset, converter=concat_mols, preprocess_fn=preprocess_fun,\n", 1065 | " M=M, train=True)\n", 1066 | "\n", 1067 | "def sv_visualize(i, ratio, lam, method, view):\n", 1068 | " saliency_bayes = gradient_calculator.transform(\n", 1069 | " saliency_arrays_bayes, ch_axis=3, method=method, lam=lam)\n", 1070 | " \n", 1071 | " if view == 'view':\n", 1072 | " svg_bayes = sv.visualize(saliency_bayes[i], custom_target_smiles[i], visualize_ratio=ratio)\n", 1073 | " display(svg_bayes)\n", 1074 | " elif view == 'save':\n", 1075 | " os.makedirs('results/visualize', exist_ok=True)\n", 1076 | " sv.visualize(saliency_bayes[i], custom_target_smiles[i], visualize_ratio=ratio, save_filepath='results/visualize/srmmp/custom_{}_bayes.png'.format(i))\n", 1077 | " else:\n", 1078 | " print(view, 'not supported')\n", 1079 | "\n", 1080 | "interact(sv_visualize, i=(0, len(custom_target_dataset) - 1, 1), ratio=(0, 1.0, 0.1), lam=(-3.0, 3.1, 0.1), method=['square', 'abs', 'raw'], view=['view', 'save'])" 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "code", 1085 | "execution_count": null, 1086 | "metadata": { 1087 | "collapsed": true 1088 | }, 1089 | "outputs": [], 1090 | "source": [] 1091 | }, 1092 | { 1093 | "cell_type": "code", 1094 | "execution_count": null, 1095 | "metadata": { 1096 | "collapsed": true 1097 | }, 1098 | "outputs": [], 1099 | "source": [] 1100 | } 1101 | ], 1102 | "metadata": { 1103 | "anaconda-cloud": {}, 1104 | "kernelspec": { 1105 | "display_name": "Python [default]", 1106 | "language": "python", 1107 | "name": "python3" 1108 | }, 1109 | "language_info": { 1110 | "codemirror_mode": { 1111 | "name": "ipython", 1112 | "version": 3 1113 | }, 1114 | "file_extension": ".py", 1115 | "mimetype": "text/x-python", 1116 | "name": "python", 1117 | "nbconvert_exporter": "python", 1118 | "pygments_lexer": "ipython3", 1119 | "version": "3.5.2" 1120 | }, 1121 | "widgets": { 1122 | "state": { 1123 | "4ddc337df552443b94c357e8015dae4f": { 1124 | "views": [ 1125 | { 1126 | "cell_index": 21.0 1127 | } 1128 | ] 1129 | }, 1130 | "86b5e4662cd34a549c070dc3e90515b3": { 1131 | "views": [ 1132 | { 1133 | "cell_index": 12.0 1134 | } 1135 | ] 1136 | }, 1137 | "8fa97c75fec14b5abc88618435d810c2": { 1138 | "views": [ 1139 | { 1140 | "cell_index": 17.0 1141 | } 1142 | ] 1143 | }, 1144 | "cb91b59b01c44714911d7b8b48dc954e": { 1145 | "views": [ 1146 | { 1147 | "cell_index": 16.0 1148 | } 1149 | ] 1150 | }, 1151 | "ec91304623f74358bd014a7ff21f88f4": { 1152 | "views": [ 1153 | { 1154 | "cell_index": 4.0 1155 | } 1156 | ] 1157 | } 1158 | }, 1159 | "version": "1.2.0" 1160 | } 1161 | }, 1162 | "nbformat": 4, 1163 | "nbformat_minor": 2 1164 | } 1165 | -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/models/__init__.py -------------------------------------------------------------------------------- /models/ggnn_drop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from chainer chemistry v0.3.0 3 | Only the difference is to add dropout function 4 | """ 5 | import chainer 6 | from chainer import cuda 7 | from chainer import functions 8 | from chainer import links 9 | 10 | import chainer_chemistry 11 | from chainer_chemistry.config import MAX_ATOMIC_NUM 12 | from chainer_chemistry.links import EmbedAtomID 13 | from chainer_chemistry.links import GraphLinear 14 | 15 | 16 | class GGNNDrop(chainer.Chain): 17 | """Gated Graph Neural Networks (GGNN) 18 | 19 | See: Li, Y., Tarlow, D., Brockschmidt, M., & Zemel, R. (2015).\ 20 | Gated graph sequence neural networks. \ 21 | `arXiv:1511.05493 `_ 22 | 23 | Args: 24 | out_dim (int): dimension of output feature vector 25 | hidden_dim (int): dimension of feature vector 26 | associated to each atom 27 | n_layers (int): number of layers 28 | n_atom_types (int): number of types of atoms 29 | concat_hidden (bool): If set to True, readout is executed in each layer 30 | and the result is concatenated 31 | weight_tying (bool): enable weight_tying or not 32 | 33 | """ 34 | NUM_EDGE_TYPE = 4 35 | 36 | def __init__(self, out_dim, hidden_dim=16, 37 | n_layers=4, n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False, 38 | weight_tying=True, dropout_ratio=0): 39 | super(GGNNDrop, self).__init__() 40 | n_readout_layer = n_layers if concat_hidden else 1 41 | n_message_layer = 1 if weight_tying else n_layers 42 | with self.init_scope(): 43 | # Update 44 | self.embed = EmbedAtomID(out_size=hidden_dim, in_size=n_atom_types) 45 | self.message_layers = chainer.ChainList( 46 | *[GraphLinear(hidden_dim, self.NUM_EDGE_TYPE * hidden_dim) 47 | for _ in range(n_message_layer)] 48 | ) 49 | self.update_layer = links.GRU(2 * hidden_dim, hidden_dim) 50 | # Readout 51 | self.i_layers = chainer.ChainList( 52 | *[GraphLinear(2 * hidden_dim, out_dim) 53 | for _ in range(n_readout_layer)] 54 | ) 55 | self.j_layers = chainer.ChainList( 56 | *[GraphLinear(hidden_dim, out_dim) 57 | for _ in range(n_readout_layer)] 58 | ) 59 | self.out_dim = out_dim 60 | self.hidden_dim = hidden_dim 61 | self.n_layers = n_layers 62 | self.concat_hidden = concat_hidden 63 | self.weight_tying = weight_tying 64 | self.dropout_ratio = dropout_ratio 65 | 66 | def update(self, h, adj, step=0): 67 | # --- Message & Update part --- 68 | # (minibatch, atom, ch) 69 | mb, atom, ch = h.shape 70 | out_ch = ch 71 | message_layer_index = 0 if self.weight_tying else step 72 | m = functions.reshape(self.message_layers[message_layer_index](h), 73 | (mb, atom, out_ch, self.NUM_EDGE_TYPE)) 74 | # m: (minibatch, atom, ch, edge_type) 75 | # Transpose 76 | m = functions.transpose(m, (0, 3, 1, 2)) 77 | # m: (minibatch, edge_type, atom, ch) 78 | 79 | adj = functions.reshape(adj, (mb * self.NUM_EDGE_TYPE, atom, atom)) 80 | # (minibatch * edge_type, atom, out_ch) 81 | m = functions.reshape(m, (mb * self.NUM_EDGE_TYPE, atom, out_ch)) 82 | 83 | m = chainer_chemistry.functions.matmul(adj, m) 84 | 85 | # (minibatch * edge_type, atom, out_ch) 86 | m = functions.reshape(m, (mb, self.NUM_EDGE_TYPE, atom, out_ch)) 87 | # Take sum 88 | m = functions.sum(m, axis=1) 89 | # (minibatch, atom, out_ch) 90 | 91 | # --- 
Update part --- 92 | # Contraction 93 | h = functions.reshape(h, (mb * atom, ch)) 94 | 95 | # Contraction 96 | m = functions.reshape(m, (mb * atom, ch)) 97 | 98 | out_h = self.update_layer(functions.concat((h, m), axis=1)) 99 | # Expansion 100 | out_h = functions.reshape(out_h, (mb, atom, ch)) 101 | return out_h 102 | 103 | def readout(self, h, h0, step=0): 104 | # --- Readout part --- 105 | index = step if self.concat_hidden else 0 106 | # h, h0: (minibatch, atom, ch) 107 | g = functions.sigmoid( 108 | self.i_layers[index](functions.concat((h, h0), axis=2))) \ 109 | * self.j_layers[index](h) 110 | g = functions.sum(g, axis=1) # sum along atom's axis 111 | return g 112 | 113 | def __call__(self, atom_array, adj): 114 | """Forward propagation 115 | 116 | Args: 117 | atom_array (numpy.ndarray): minibatch of molecular which is 118 | represented with atom IDs (representing C, O, S, ...) 119 | `atom_array[mol_index, atom_index]` represents `mol_index`-th 120 | molecule's `atom_index`-th atomic number 121 | adj (numpy.ndarray): minibatch of adjancency matrix with edge-type 122 | information 123 | 124 | Returns: 125 | ~chainer.Variable: minibatch of fingerprint 126 | """ 127 | # reset state 128 | self.update_layer.reset_state() 129 | if atom_array.dtype == self.xp.int32: 130 | h = self.embed(atom_array) # (minibatch, max_num_atoms) 131 | else: 132 | h = atom_array 133 | h0 = functions.copy(h, cuda.get_device_from_array(h.data).id) 134 | g_list = [] 135 | for step in range(self.n_layers): 136 | h = chainer.functions.dropout(h, ratio=self.dropout_ratio) 137 | h = self.update(h, adj, step) 138 | if self.concat_hidden: 139 | g = self.readout(h, h0, step) 140 | g_list.append(g) 141 | 142 | if self.concat_hidden: 143 | return functions.concat(g_list, axis=1) 144 | else: 145 | g = self.readout(h, h0, 0) 146 | return g 147 | -------------------------------------------------------------------------------- /models/nfp_drop.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Neural Fingerprint 3 | 4 | Copied from chainer chemistry v0.3.0 5 | Only the difference is to add dropout function 6 | """ 7 | import chainer 8 | from chainer import cuda, Variable 9 | from chainer import functions 10 | import numpy 11 | 12 | import chainer_chemistry 13 | from chainer_chemistry.config import MAX_ATOMIC_NUM 14 | 15 | 16 | class NFPUpdate(chainer.Chain): 17 | """NFP sub module for update part 18 | 19 | Args: 20 | in_channels (int): input channel dimension 21 | out_channels (int): output channel dimension 22 | max_degree (int): max degree of edge 23 | """ 24 | 25 | def __init__(self, in_channels, out_channels, max_degree=6): 26 | super(NFPUpdate, self).__init__() 27 | num_degree_type = max_degree + 1 28 | with self.init_scope(): 29 | self.graph_linears = chainer.ChainList( 30 | *[chainer_chemistry.links.GraphLinear(in_channels, out_channels) 31 | for _ in range(num_degree_type)] 32 | ) 33 | self.max_degree = max_degree 34 | self.in_channels = in_channels 35 | self.out_channels = out_channels 36 | 37 | def __call__(self, h, adj, deg_conds): 38 | # h (minibatch, atom, ch) 39 | # h encodes each atom's info in ch axis of size hidden_dim 40 | # adjs (minibatch, atom, atom) 41 | 42 | # --- Message part --- 43 | # Take sum along adjacent atoms 44 | 45 | # fv (minibatch, atom, ch) 46 | fv = chainer_chemistry.functions.matmul(adj, h) 47 | 48 | # --- Update part --- 49 | # s0, s1, s2 = fv.shape 50 | if self.xp is numpy: 51 | zero_array = numpy.zeros(fv.shape, 
dtype=numpy.float32) 52 | else: 53 | zero_array = self.xp.zeros_like(fv) 54 | 55 | fvds = [functions.where(cond, fv, zero_array) for cond in deg_conds] 56 | 57 | out_h = 0 58 | for graph_linear, fvd in zip(self.graph_linears, fvds): 59 | out_h = out_h + graph_linear(fvd) 60 | 61 | # out_x shape (minibatch, max_num_atoms, hidden_dim) 62 | out_h = functions.sigmoid(out_h) 63 | return out_h 64 | 65 | 66 | class NFPReadout(chainer.Chain): 67 | """NFP sub module for readout part 68 | 69 | Args: 70 | in_channels (int): dimension of feature vector associated to each 71 | atom (node) 72 | out_size (int): output dimension of feature vector associated to each 73 | molecule (graph) 74 | """ 75 | 76 | def __init__(self, in_channels, out_size): 77 | super(NFPReadout, self).__init__() 78 | with self.init_scope(): 79 | self.output_weight = chainer_chemistry.links.GraphLinear( 80 | in_channels, out_size) 81 | self.in_channels = in_channels 82 | self.out_size = out_size 83 | 84 | def __call__(self, h): 85 | # input h shape (minibatch, atom, ch) 86 | # return i shape (minibatch, ch) 87 | 88 | # --- Readout part --- 89 | i = self.output_weight(h) 90 | i = functions.softmax(i, axis=2) # softmax along channel axis 91 | i = functions.sum(i, axis=1) # sum along atom's axis 92 | return i 93 | 94 | 95 | class NFPDrop(chainer.Chain): 96 | 97 | """Neural Finger Print (NFP) 98 | 99 | Args: 100 | out_dim (int): dimension of output feature vector 101 | hidden_dim (int): dimension of feature vector 102 | associated to each atom 103 | max_degree (int): max degree of atoms 104 | when molecules are regarded as graphs 105 | n_atom_types (int): number of types of atoms 106 | n_layer (int): number of layers 107 | """ 108 | 109 | def __init__(self, out_dim, hidden_dim=16, n_layers=4, max_degree=6, 110 | n_atom_types=MAX_ATOMIC_NUM, concat_hidden=False, 111 | dropout_ratio=0): 112 | super(NFPDrop, self).__init__() 113 | num_degree_type = max_degree + 1 114 | with self.init_scope(): 115 | self.embed = chainer_chemistry.links.EmbedAtomID( 116 | in_size=n_atom_types, out_size=hidden_dim) 117 | self.layers = chainer.ChainList( 118 | *[NFPUpdate(hidden_dim, hidden_dim, max_degree=max_degree) 119 | for _ in range(n_layers)]) 120 | self.read_out_layers = chainer.ChainList( 121 | *[NFPReadout(hidden_dim, out_dim) 122 | for _ in range(n_layers)]) 123 | self.out_dim = out_dim 124 | self.hidden_dim = hidden_dim 125 | self.max_degree = max_degree 126 | self.num_degree_type = num_degree_type 127 | self.n_layers = n_layers 128 | self.concat_hidden = concat_hidden 129 | self.dropout_ratio = dropout_ratio 130 | 131 | def __call__(self, atom_array, adj): 132 | """Forward propagation 133 | 134 | Args: 135 | atom_array (numpy.ndarray): minibatch of molecular which is 136 | represented with atom IDs (representing C, O, S, ...) 
137 | `atom_array[mol_index, atom_index]` represents `mol_index`-th 138 | molecule's `atom_index`-th atomic number 139 | adj (numpy.ndarray): minibatch of adjancency matrix 140 | `adj[mol_index]` represents `mol_index`-th molecule's 141 | adjacency matrix 142 | 143 | Returns: 144 | ~chainer.Variable: minibatch of fingerprint 145 | """ 146 | if atom_array.dtype == self.xp.int32: 147 | # atom_array: (minibatch, atom) 148 | h = self.embed(atom_array) 149 | else: 150 | h = atom_array 151 | # h: (minibatch, atom, ch) 152 | g = 0 153 | 154 | # --- NFP update & readout --- 155 | # degree_mat: (minibatch, max_num_atoms) 156 | if isinstance(adj, Variable): 157 | adj_array = adj.data 158 | else: 159 | adj_array = adj 160 | degree_mat = self.xp.sum(adj_array, axis=1) 161 | # deg_condst: (minibatch, atom, ch) 162 | deg_conds = [self.xp.broadcast_to( 163 | ((degree_mat - degree) == 0)[:, :, None], h.shape) 164 | for degree in range(1, self.num_degree_type + 1)] 165 | g_list = [] 166 | for update, readout in zip(self.layers, self.read_out_layers): 167 | h = chainer.functions.dropout(h, ratio=self.dropout_ratio) 168 | h = update(h, adj, deg_conds) 169 | dg = readout(h) 170 | g = g + dg 171 | if self.concat_hidden: 172 | g_list.append(g) 173 | 174 | if self.concat_hidden: 175 | return functions.concat(g_list, axis=2) 176 | else: 177 | return g 178 | -------------------------------------------------------------------------------- /models/predictor.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer import functions as F 3 | import chainer.links as L 4 | import sys 5 | import os 6 | 7 | from chainer_chemistry.models import GGNN 8 | from chainer_chemistry.models import NFP 9 | from chainer_chemistry.models import SchNet 10 | from chainer_chemistry.models import WeaveNet 11 | 12 | sys.path.append(os.path.dirname(__file__)) 13 | from models.nfp_drop import NFPDrop 14 | from models.ggnn_drop import GGNNDrop 15 | 16 | 17 | class MLPDrop(chainer.Chain): 18 | """Basic implementation for MLP with dropout""" 19 | # def __init__(self, hidden_dim, out_dim, n_layers=2, activation=F.relu): 20 | def __init__(self, out_dim, hidden_dim, n_layers=1, activation=F.relu, 21 | dropout_ratio=0.25): 22 | super(MLPDrop, self).__init__() 23 | if n_layers <= 0: 24 | raise ValueError('n_layers must be positive integer, but set {}' 25 | .format(n_layers)) 26 | layers = [L.Linear(None, hidden_dim) for i in range(n_layers - 1)] 27 | with self.init_scope(): 28 | self.layers = chainer.ChainList(*layers) 29 | self.l_out = L.Linear(None, out_dim) 30 | self.activation = activation 31 | self.dropout_ratio = dropout_ratio 32 | 33 | def __call__(self, x): 34 | h = F.dropout(x, ratio=self.dropout_ratio) 35 | for l in self.layers: 36 | h = F.dropout(self.activation(l(h)), ratio=self.dropout_ratio) 37 | h = self.l_out(h) 38 | return h 39 | 40 | 41 | def build_predictor(method, n_unit, conv_layers, class_num, 42 | dropout_ratio=0.25, n_layers=1): 43 | print('dropout_ratio, n_layers', dropout_ratio, n_layers) 44 | mlp_class = MLPDrop 45 | if method == 'nfp': 46 | print('Use NFP predictor...') 47 | predictor = GraphConvPredictor( 48 | NFP(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers), 49 | mlp_class(out_dim=class_num, hidden_dim=n_unit, dropout_ratio=dropout_ratio, 50 | n_layers=n_layers)) 51 | elif method == 'nfpdrop': 52 | print('Use NFPDrop predictor...') 53 | predictor = GraphConvPredictor( 54 | NFPDrop(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers, 55 | 
dropout_ratio=dropout_ratio), 56 | mlp_class(out_dim=class_num, hidden_dim=n_unit, 57 | dropout_ratio=dropout_ratio, 58 | n_layers=n_layers)) 59 | elif method == 'ggnn': 60 | print('Use GGNN predictor...') 61 | predictor = GraphConvPredictor( 62 | GGNN(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers), 63 | mlp_class(out_dim=class_num, hidden_dim=n_unit, 64 | dropout_ratio=dropout_ratio, n_layers=n_layers)) 65 | elif method == 'ggnndrop': 66 | print('Use GGNNDrop predictor...') 67 | predictor = GraphConvPredictor( 68 | GGNNDrop(out_dim=n_unit, hidden_dim=n_unit, n_layers=conv_layers, 69 | dropout_ratio=dropout_ratio), 70 | mlp_class(out_dim=class_num, hidden_dim=n_unit, 71 | dropout_ratio=dropout_ratio, n_layers=n_layers)) 72 | elif method == 'schnet': 73 | print('Use SchNet predictor...') 74 | predictor = SchNet(out_dim=class_num, hidden_dim=n_unit, 75 | n_layers=conv_layers, readout_hidden_dim=n_unit) 76 | elif method == 'weavenet': 77 | print('Use WeaveNet predictor...') 78 | n_atom = 20 79 | n_sub_layer = 1 80 | weave_channels = [50] * conv_layers 81 | predictor = GraphConvPredictor( 82 | WeaveNet(weave_channels=weave_channels, hidden_dim=n_unit, 83 | n_sub_layer=n_sub_layer, n_atom=n_atom), 84 | mlp_class(out_dim=class_num, hidden_dim=n_unit, 85 | dropout_ratio=dropout_ratio, n_layers=n_layers)) 86 | else: 87 | raise ValueError('[ERROR] Invalid predictor: method={}'.format(method)) 88 | return predictor 89 | 90 | 91 | class GraphConvPredictor(chainer.Chain): 92 | """Wrapper class that combines a graph convolution and MLP.""" 93 | 94 | def __init__(self, graph_conv, mlp): 95 | """Constructor 96 | 97 | Args: 98 | graph_conv: graph convolution network to obtain molecule feature 99 | representation 100 | mlp: multi layer perceptron, used as final connected layer 101 | """ 102 | 103 | super(GraphConvPredictor, self).__init__() 104 | with self.init_scope(): 105 | self.graph_conv = graph_conv 106 | self.mlp = mlp 107 | 108 | def __call__(self, atoms, adjs): 109 | x = self.graph_conv(atoms, adjs) 110 | x = self.mlp(x) 111 | return x 112 | 113 | def predict(self, atoms, adjs): 114 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 115 | x = self.__call__(atoms, adjs) 116 | return F.sigmoid(x) 117 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chainer==4.2.0 2 | chainer-chemistry==0.4.0 3 | matplotlib==2.2.2 4 | future==0.16.0 5 | cairosvg==2.1.3 6 | ipython==5.1.0 7 | -------------------------------------------------------------------------------- /saliency/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/saliency/__init__.py -------------------------------------------------------------------------------- /saliency/calculator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/saliency/calculator/__init__.py -------------------------------------------------------------------------------- /saliency/calculator/base_calculator.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABCMeta 3 | from abc import abstractmethod 4 | 5 | import chainer 6 | from chainer import cuda 7 | 
from chainer.dataset.convert import concat_examples, \ 8 | _concat_arrays_with_padding 9 | from chainer.iterators import SerialIterator 10 | 11 | from future.utils import with_metaclass 12 | import numpy 13 | 14 | 15 | _sampling_axis = 0 16 | 17 | 18 | def _to_tuple(x): 19 | if not isinstance(x, tuple): 20 | x = (x,) 21 | return x 22 | 23 | 24 | def _to_variable(x): 25 | if not isinstance(x, chainer.Variable): 26 | x = chainer.Variable(x) 27 | return x 28 | 29 | 30 | def _extract_numpy(x): 31 | if isinstance(x, chainer.Variable): 32 | x = x.data 33 | return cuda.to_cpu(x) 34 | 35 | 36 | class BaseCalculator(with_metaclass(ABCMeta, object)): 37 | 38 | def __init__(self, model): 39 | self.model = model # type: chainer.Chain 40 | self._device = cuda.get_device_from_array(*model.params()).id 41 | # print('device', self._device) 42 | 43 | def compute( 44 | self, data, M=1, method='vanilla', batchsize=16, 45 | converter=concat_examples, retain_inputs=False, preprocess_fn=None, 46 | postprocess_fn=None, train=False): 47 | method_dict = { 48 | 'vanilla': self.compute_vanilla, 49 | 'smooth': self.compute_smooth, 50 | 'bayes': self.compute_bayes, 51 | } 52 | return method_dict[method]( 53 | data, batchsize=batchsize, M=M, converter=converter, 54 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn, 55 | postprocess_fn=postprocess_fn, train=train) 56 | 57 | def compute_vanilla(self, data, batchsize=16, M=1, 58 | converter=concat_examples, retain_inputs=False, 59 | preprocess_fn=None, postprocess_fn=None, train=False): 60 | """VanillaGrad""" 61 | saliency_list = [] 62 | for _ in range(M): 63 | with chainer.using_config('train', train): 64 | saliency = self._forward( 65 | data, fn=self._compute_core, batchsize=batchsize, 66 | converter=converter, 67 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn, 68 | postprocess_fn=postprocess_fn) 69 | saliency_list.append(cuda.to_cpu(saliency)) 70 | return numpy.stack(saliency_list, axis=_sampling_axis) 71 | 72 | def compute_smooth(self, data, M=10, batchsize=16, 73 | converter=concat_examples, retain_inputs=False, 74 | preprocess_fn=None, postprocess_fn=None, train=False, 75 | scale=0.15, mode='relative'): 76 | """SmoothGrad 77 | Reference 78 | https://github.com/PAIR-code/saliency/blob/master/saliency/base.py#L54 79 | """ 80 | 81 | def smooth_fn(*inputs): 82 | #TODO: support cupy input 83 | target_array = inputs[self.target_key].data 84 | xp = cuda.get_array_module(target_array) 85 | 86 | noise = xp.random.normal( 87 | 0, scale, inputs[self.target_key].data.shape) 88 | if mode == 'absolute': 89 | # `scale` is used as is 90 | pass 91 | elif mode == 'relative': 92 | # `scale_axis` is used to calculate `max` and `min` of target_array 93 | # As default, all axes except batch axis are treated as `scale_axis`. 
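                # (note) i.e. the effective noise std is
                # scale * (max - min of each sample's target array),
                # computed per example over the non-batch axes.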
94 | scale_axis = tuple(range(1, target_array.ndim)) 95 | noise = noise * (xp.max(target_array, axis=scale_axis, keepdims=True) 96 | - xp.min(target_array, axis=scale_axis, keepdims=True)) 97 | # print('[DEBUG] noise', noise.shape) 98 | else: 99 | raise ValueError("[ERROR] Unexpected value mode={}" 100 | .format(mode)) 101 | inputs[self.target_key].data += noise 102 | return self._compute_core(*inputs) 103 | 104 | saliency_list = [] 105 | for _ in range(M): 106 | with chainer.using_config('train', train): 107 | saliency = self._forward( 108 | data, fn=smooth_fn, batchsize=batchsize, 109 | converter=converter, 110 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn, 111 | postprocess_fn=postprocess_fn) 112 | saliency_array = cuda.to_cpu(saliency) 113 | saliency_list.append(saliency_array) 114 | return numpy.stack(saliency_list, axis=_sampling_axis) 115 | 116 | def compute_bayes(self, data, M=10, batchsize=16, 117 | converter=concat_examples, retain_inputs=False, 118 | preprocess_fn=None, postprocess_fn=None, train=True): 119 | """BayesGrad""" 120 | warnings.warn('`compute_bayes` method may be deleted in the future. ' 121 | 'Please use `compute_vanilla` with train=True instead.') 122 | assert train 123 | # This is actually just an alias of `compute_vanilla` with `train=True`. 124 | # It may be deleted in the future. 125 | return self.compute_vanilla( 126 | data, M=M, batchsize=batchsize, converter=converter, 127 | retain_inputs=retain_inputs, preprocess_fn=preprocess_fn, 128 | postprocess_fn=postprocess_fn, train=True) 129 | 130 | def transform(self, saliency_arrays, method='raw', lam=0, ch_axis=2): 131 | if method == 'raw': 132 | h = numpy.sum(saliency_arrays, axis=ch_axis) 133 | elif method == 'abs': 134 | h = numpy.sum(numpy.abs(saliency_arrays), axis=ch_axis) 135 | elif method == 'square': 136 | h = numpy.sum(saliency_arrays ** 2, axis=ch_axis) 137 | else: 138 | raise ValueError('[ERROR] Unexpected value method={}'.format(method)) 139 | 140 | sampling_axis = _sampling_axis 141 | if lam == 0: 142 | return numpy.mean(h, axis=sampling_axis) 143 | else: 144 | if h.shape[sampling_axis] == 1: 145 | # VanillaGrad does not support LCB/UCB calculation 146 | raise ValueError( 147 | 'saliency_arrays.shape[{}] must be larger than 1'.format(sampling_axis)) 148 | return numpy.mean(h, axis=sampling_axis) + lam * numpy.std( 149 | h, axis=sampling_axis) 150 | 151 | @abstractmethod 152 | def _compute_core(self, *inputs): 153 | raise NotImplementedError 154 | 155 | def _forward(self, data, fn=None, batchsize=16, 156 | converter=concat_examples, retain_inputs=False, 157 | preprocess_fn=None, postprocess_fn=None): 158 | """Forward `data` by iterating over mini-batches 159 | 160 | Args: 161 | data: "train_x array" or "chainer dataset" 162 | fn (Callable): Main function to forward. Its input arguments are 163 | either Variable, cupy.ndarray or numpy.ndarray, and it returns 164 | a Variable. 165 | batchsize (int): batch size 166 | converter (Callable): converts `data` into `inputs` 167 | retain_inputs (bool): If True, this instance keeps the inputs in 168 | `self.inputs`. 169 | preprocess_fn (Callable): Its input is numpy.ndarray or 170 | cupy.ndarray, and it can return either Variable, cupy.ndarray or 171 | numpy.ndarray 172 | postprocess_fn (Callable): Its input argument is Variable, 173 | and it may return either Variable, cupy.ndarray or 174 | numpy.ndarray. 
175 | 176 | Returns (tuple or numpy.ndarray): forward result 177 | 178 | """ 179 | input_list = None 180 | output_list = None 181 | it = SerialIterator(data, batch_size=batchsize, repeat=False, 182 | shuffle=False) 183 | for batch in it: 184 | inputs = converter(batch, self._device) 185 | inputs = _to_tuple(inputs) 186 | 187 | if preprocess_fn: 188 | inputs = preprocess_fn(*inputs) 189 | inputs = _to_tuple(inputs) 190 | 191 | inputs = tuple(_to_variable(x) for x in inputs) 192 | 193 | outputs = fn(*inputs) 194 | 195 | # Init 196 | if retain_inputs: 197 | if input_list is None: 198 | input_list = [[] for _ in range(len(inputs))] 199 | for j, input in enumerate(inputs): 200 | input_list[j].append(cuda.to_cpu(input.data)) 201 | if output_list is None: 202 | output_list = [[] for _ in range(len(outputs))] 203 | 204 | if postprocess_fn: 205 | outputs = postprocess_fn(*outputs) 206 | outputs = _to_tuple(outputs) 207 | for j, output in enumerate(outputs): 208 | output_list[j].append(_extract_numpy(output)) 209 | 210 | if retain_inputs: 211 | self.inputs = [numpy.concatenate( 212 | in_array) for in_array in input_list] 213 | 214 | result = [_concat(output) for output in output_list] 215 | 216 | # result = [numpy.concatenate(output) for output in output_list] 217 | if len(result) == 1: 218 | return result[0] 219 | else: 220 | return result 221 | 222 | 223 | def _concat(batch_list): 224 | try: 225 | return numpy.concatenate(batch_list) 226 | except Exception: 227 | # There are cases where the inputs have different shapes, so they 228 | # cannot be concatenated into a single array; pad and concatenate instead. 229 | 230 | elem_list = [elem for batch in batch_list for elem in batch] 231 | return _concat_arrays_with_padding(elem_list, padding=0) 232 | -------------------------------------------------------------------------------- /saliency/calculator/gradient_calculator.py: -------------------------------------------------------------------------------- 1 | from saliency.calculator.base_calculator import BaseCalculator 2 | 3 | 4 | class GradientCalculator(BaseCalculator): 5 | 6 | def __init__(self, model, eval_fun=None, eval_key=None, target_key=0, 7 | multiply_target=False): 8 | super(GradientCalculator, self).__init__(model) 9 | # self.model = model 10 | # self._device = cuda.get_array_module(model) 11 | self.eval_fun = eval_fun 12 | self.eval_key = eval_key 13 | self.target_key = target_key 14 | 15 | self.multiply_target = multiply_target 16 | 17 | def _compute_core(self, *inputs): 18 | # outputs = fn(*inputs) 19 | # outputs = _to_tuple(outputs) 20 | self.model.cleargrads() 21 | result = self.eval_fun(*inputs) 22 | if self.eval_key is None: 23 | eval_var = result 24 | elif isinstance(self.eval_key, str): 25 | eval_var = result[self.eval_key] 26 | else: 27 | raise TypeError('Unexpected type {} for eval_key' 28 | .format(type(self.eval_key))) 29 | # TODO: Consider how to deal with the case when eval_var is not a scalar, 30 | # 1. take sum 31 | # 2. raise error (default behavior) 32 | # I think option 1 "take sum" is better, since each gradient is then 33 | # calculated automatically and independently in that case. 
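# (Illustrative note, not part of the original source: option 1 above could be
#  realized by inserting `eval_var = chainer.functions.sum(eval_var)` at this
#  point, so that the backward call on the next line always starts from a scalar.)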
34 | eval_var.backward(retain_grad=True) 35 | 36 | if self.target_key is None: 37 | target_var = inputs 38 | elif isinstance(self.target_key, int): 39 | target_var = inputs[self.target_key] 40 | else: 41 | raise TypeError('Unexpected type {} for target_key' 42 | .format(type(self.target_key))) 43 | saliency = target_var.grad 44 | if self.multiply_target: 45 | saliency *= target_var.data 46 | outputs = (saliency,) 47 | return outputs 48 | -------------------------------------------------------------------------------- /saliency/calculator/integrated_gradients_calculator.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from chainer import Variable 3 | 4 | from saliency.calculator.base_calculator import BaseCalculator 5 | from saliency.calculator.gradient_calculator import GradientCalculator 6 | 7 | 8 | class IntegratedGradientsCalculator(GradientCalculator): 9 | 10 | def __init__(self, model, eval_fun=None, eval_key=None, target_key=0, 11 | baseline=None, steps=25): 12 | super(IntegratedGradientsCalculator, self).__init__( 13 | model, eval_fun=eval_fun, eval_key=eval_key, target_key=target_key, 14 | multiply_target=False 15 | ) 16 | # self.target_key = target_key 17 | self.baseline = baseline or 0. 18 | self.steps = steps 19 | 20 | def get_target_var(self, inputs): 21 | if self.target_key is None: 22 | target_var = inputs 23 | elif isinstance(self.target_key, int): 24 | target_var = inputs[self.target_key] 25 | else: 26 | raise TypeError('Unexpected type {} for target_key' 27 | .format(type(self.target_key))) 28 | return target_var 29 | 30 | def _compute_core(self, *inputs): 31 | 32 | total_grads = 0. 33 | target_var = self.get_target_var(inputs) 34 | base = self.baseline 35 | diff = target_var.data - base 36 | for alpha in numpy.linspace(0., 1., self.steps): 37 | # TODO: consider case target_key=None 38 | interpolated_inputs = ( 39 | Variable(base + alpha * diff) if self.target_key == i else elem 40 | for i, elem in enumerate(inputs)) 41 | total_grads += super( 42 | IntegratedGradientsCalculator, self)._compute_core( 43 | *interpolated_inputs)[0] 44 | saliency = total_grads * diff 45 | return saliency, 46 | -------------------------------------------------------------------------------- /saliency/calculator/occlusion_calculator.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy 4 | 5 | import chainer 6 | import six 7 | 8 | from saliency.calculator.base_calculator import BaseCalculator 9 | 10 | 11 | def _to_tuple(x): 12 | if isinstance(x, int): 13 | x = (x,) 14 | elif isinstance(x, (list, tuple)): 15 | x = tuple(x) 16 | else: 17 | raise TypeError('Unexpected type {}'.format(type(x))) 18 | return x 19 | 20 | 21 | class OcclusionCalculator(BaseCalculator): 22 | 23 | def __init__(self, model, eval_fun=None, eval_key=None, 24 | enable_backprop=False, size=1, slide_axis=(2, 3), 25 | target_key=0): 26 | super(OcclusionCalculator, self).__init__(model) 27 | # self.model = model 28 | # self._device = cuda.get_array_module(model) 29 | self.eval_fun = eval_fun 30 | self.eval_key = eval_key 31 | self.enable_backprop = enable_backprop 32 | self.slide_axis = _to_tuple(slide_axis) 33 | size = _to_tuple(size) 34 | if len(self.slide_axis) != len(size): 35 | size = size * len(self.slide_axis) 36 | self.size = size 37 | print('slide_axis', self.slide_axis, 'size', self.size) 38 | self.target_key = target_key 39 | 40 | def _compute_core(self, *inputs): 41 | if self.target_key is None: 
42 | target_var = inputs 43 | elif isinstance(self.target_key, int): 44 | target_var = inputs[self.target_key] 45 | else: 46 | raise TypeError('Unexpected type {} for target_key' 47 | .format(type(self.target_key))) 48 | 49 | def _extract_score(result): 50 | if self.eval_key is None: 51 | score = result 52 | elif isinstance(self.eval_key, str): 53 | score = result[self.eval_key] 54 | else: 55 | raise TypeError('Unexpected type {} for eval_key' 56 | .format(type(self.eval_key))) 57 | return score 58 | 59 | # Usually, backward() is not necessary for calculating occlusion 60 | with chainer.using_config('enable_backprop', self.enable_backprop): 61 | original_result = self.eval_fun(*inputs) 62 | original_score = _extract_score(original_result) 63 | 64 | # TODO: xp and value assign dynamically 65 | xp = numpy 66 | value = 0. 67 | 68 | # fill with `value` 69 | target_dim = target_var.ndim 70 | batch_size = target_var.shape[0] 71 | occlusion_window_shape = [1] * target_dim 72 | occlusion_window_shape[0] = batch_size 73 | for axis, size in zip(self.slide_axis, self.size): 74 | occlusion_window_shape[axis] = size 75 | occlusion_scores_shape = [1] * target_dim 76 | occlusion_scores_shape[0] = batch_size 77 | for axis, size in zip(self.slide_axis, self.size): 78 | occlusion_scores_shape[axis] = target_var.shape[axis] 79 | # print('[DEBUG] occlusion_shape', occlusion_window_shape) 80 | occlusion_window = xp.ones(occlusion_window_shape, dtype=target_var.dtype) * value 81 | # print('[DEBUG] occlusion_window.shape', occlusion_window.shape) 82 | occlusion_scores = xp.zeros(occlusion_scores_shape, dtype=xp.float32) 83 | # print('[DEBUG] occlusion_scores.shape', occlusion_scores.shape) 84 | 85 | def _extract_index(slide_axis, size, start_indices): 86 | colon = slice(None) 87 | index = [colon] * target_dim 88 | for axis, size, start in zip(slide_axis, size, start_indices): 89 | index[axis] = slice(start, start + size, 1) 90 | return index 91 | 92 | end_list = [target_var.data.shape[axis] - size for axis, size 93 | in zip(self.slide_axis, self.size)] 94 | 95 | for start in itertools.product(*[six.moves.range(end) for end in end_list]): 96 | target_var_occluded = target_var.data.copy() 97 | occlude_index = _extract_index(self.slide_axis, self.size, start) 98 | target_var_occluded[occlude_index] = occlusion_window 99 | 100 | # Usually, backward() is not necessary for calculating occlusion 101 | with chainer.using_config('enable_backprop', self.enable_backprop): 102 | # TODO: consider case target_key=None 103 | occluded_inputs = (target_var_occluded if self.target_key == i else elem 104 | for i, elem in enumerate(inputs)) 105 | occluded_result = self.eval_fun(*occluded_inputs) 106 | occluded_score = _extract_score(occluded_result) 107 | score_diff_var = original_score - occluded_score 108 | # TODO: expand_dim dynamically 109 | score_diff = xp.broadcast_to(score_diff_var.data[:, :, None], occlusion_window_shape) 110 | occlusion_scores[occlude_index] += score_diff 111 | outputs = (occlusion_scores,) 112 | return outputs 113 | 114 | 115 | if __name__ == '__main__': 116 | # TODO: test 117 | raise NotImplementedError() 118 | oc = OcclusionCalculator(model) 119 | saliency_array = oc.compute_vanilla() 120 | saliency = oc.transform(saliency_array) 121 | -------------------------------------------------------------------------------- /saliency/visualizer/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pfnet-research/bayesgrad/5db613391777b20b7a367c274804f0b736991b0a/saliency/visualizer/__init__.py -------------------------------------------------------------------------------- /saliency/visualizer/base_visualizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from abc import abstractmethod 3 | 4 | from future.utils import with_metaclass 5 | 6 | 7 | class BaseVisualizer(with_metaclass(ABCMeta, object)): 8 | 9 | @abstractmethod 10 | def visualize(self, *args, **kwargs): 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /saliency/visualizer/smiles_visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from rdkit import Chem 4 | from rdkit.Chem import rdDepictor 5 | from rdkit.Chem.Draw import rdMolDraw2D 6 | 7 | 8 | from saliency.visualizer.base_visualizer import BaseVisualizer 9 | 10 | 11 | def _convert_to_2d(axes, nrows, ncols): 12 | if nrows == 1 and ncols == 1: 13 | axes = numpy.array([[axes]]) 14 | elif nrows == 1: 15 | axes = axes[None, :] 16 | elif ncols == 1: 17 | axes = axes[:, None] 18 | else: 19 | pass 20 | assert axes.ndim == 2 21 | return axes 22 | 23 | 24 | def is_visible(begin, end): 25 | if begin <= 0 or end <= 0: 26 | return 0 27 | elif begin >= 1 or end >= 1: 28 | return 1 29 | else: 30 | return (begin + end) * 0.5 31 | 32 | 33 | # Default color function 34 | def red(x): 35 | # return in RGB order 36 | # x=0 -> 1, 1, 1 (white) 37 | # x=1 -> 1, 0, 0 (red) 38 | return 1., 1. - x, 1. - x 39 | 40 | 41 | def min_max_scaler(saliency): 42 | """Normalize saliency to value 0-1""" 43 | maxv = numpy.max(saliency) 44 | minv = numpy.min(saliency) 45 | if maxv == minv: 46 | saliency = numpy.zeros_like(saliency) 47 | else: 48 | saliency = (saliency - minv) / (maxv - minv) 49 | return saliency 50 | 51 | 52 | class SmilesVisualizer(BaseVisualizer): 53 | 54 | def visualize(self, saliency, smiles, save_filepath=None, 55 | visualize_ratio=1.0, color_fn=red, scaler=min_max_scaler, legend=''): 56 | mol = Chem.MolFromSmiles(smiles) 57 | num_atoms = mol.GetNumAtoms() 58 | rdDepictor.Compute2DCoords(mol) 59 | Chem.SanitizeMol(mol) 60 | Chem.Kekulize(mol) 61 | n_atoms = mol.GetNumAtoms() 62 | # highlight = list(range(n_atoms)) 63 | 64 | # --- type check --- 65 | assert saliency.ndim == 1 66 | # Cut saliency array for unnecessary tail part 67 | saliency = saliency[:num_atoms] 68 | # Normalize to [0, 1] 69 | saliency = scaler(saliency) 70 | # normed_saliency = copy.deepcopy(saliency) 71 | 72 | if visualize_ratio < 1.0: 73 | threshold_index = int(n_atoms * visualize_ratio) 74 | idx = numpy.argsort(saliency) 75 | idx = numpy.flip(idx, axis=0) 76 | # set threshold to top `visualize_ratio` saliency 77 | threshold = saliency[idx[threshold_index]] 78 | saliency = numpy.where(saliency < threshold, 0., saliency) 79 | else: 80 | threshold = numpy.min(saliency) 81 | 82 | highlight_atoms = list(map(lambda g: g.__int__(), numpy.where(saliency >= threshold)[0])) 83 | atom_colors = {i: color_fn(e) for i, e in enumerate(saliency)} 84 | bondlist = [bond.GetIdx() for bond in mol.GetBonds()] 85 | 86 | def color_bond(bond): 87 | begin = saliency[bond.GetBeginAtomIdx()] 88 | end = saliency[bond.GetEndAtomIdx()] 89 | return color_fn(is_visible(begin, end)) 90 | bondcolorlist = {i: color_bond(bond) for i, bond in enumerate(mol.GetBonds())} 91 | drawer = rdMolDraw2D.MolDraw2DSVG(500, 375) 
92 | drawer.DrawMolecule( 93 | mol, highlightAtoms=highlight_atoms, 94 | highlightAtomColors=atom_colors, highlightBonds=bondlist, 95 | highlightBondColors=bondcolorlist, legend=legend) 96 | drawer.FinishDrawing() 97 | svg = drawer.GetDrawingText() 98 | if save_filepath: 99 | extension = save_filepath.split('.')[-1] 100 | if extension == 'svg': 101 | print('saving svg to {}'.format(save_filepath)) 102 | with open(save_filepath, 'w') as f: 103 | f.write(svg) 104 | elif extension == 'png': 105 | import cairosvg 106 | print('saving png to {}'.format(save_filepath)) 107 | # cairosvg.svg2png( 108 | # url=svg_save_filepath, write_to=save_filepath) 109 | # print('svg type', type(svg)) 110 | cairosvg.svg2png(bytestring=svg, write_to=save_filepath) 111 | else: 112 | raise ValueError( 113 | 'Unsupported extension {} for save_filepath {}' 114 | .format(extension, save_filepath)) 115 | else: 116 | from IPython.core.display import SVG 117 | return SVG(svg.replace('svg:', '')) 118 | --------------------------------------------------------------------------------
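The files above expose the saliency API but the dump contains no end-to-end driver, so the following is a minimal usage sketch, not code from the repository: `ToyModel` and `x` are placeholders standing in for the dropout-enabled graph-convolution predictors in `models/` and their molecular inputs, and the snippet only illustrates the intended call sequence of `GradientCalculator.compute` and `transform`.

import chainer
import chainer.functions as F
import chainer.links as L
import numpy

from saliency.calculator.gradient_calculator import GradientCalculator


class ToyModel(chainer.Chain):
    """Placeholder model with dropout, so repeated forward passes differ."""

    def __init__(self):
        super(ToyModel, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(10, 16)
            self.l2 = L.Linear(16, 1)

    def __call__(self, x):
        return self.l2(F.relu(F.dropout(self.l1(x))))


model = ToyModel()
x = numpy.random.rand(8, 10).astype(numpy.float32)  # 8 examples, 10 features

calculator = GradientCalculator(
    model,
    # Sum the (batch, 1) output so that backward() receives a scalar.
    eval_fun=lambda v: F.sum(model(v)),
    target_key=0)  # take the gradient w.r.t. the first input array

# 30 stochastic (dropout-enabled) passes -> BayesGrad samples, shape (30, 8, 10).
saliency_arrays = calculator.compute(x, M=30, method='bayes', train=True)
# Sum absolute gradients over the feature axis and average the 30 samples;
# lam > 0 would add lam * std for an UCB/LCB-style score instead.
saliency = calculator.transform(saliency_arrays, method='abs', ch_axis=2)
print(saliency.shape)  # -> (8,)

In the repository itself the analogous sequence is presumably driven with the NFP/GGNN dropout predictors, and the resulting per-atom saliency vector is handed to `SmilesVisualizer.visualize` to produce the highlighted molecule images; the notebooks under `experiments/tox21/` show that workflow.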