├── .gitignore ├── README.md ├── conda.yaml ├── config ├── count_nodes │ ├── dataset.yaml │ ├── full.yaml │ ├── minimal.yaml │ └── train.yaml ├── infection │ ├── datasets.yaml │ └── train.yaml └── solubility │ └── train.yaml ├── data └── delaney-processed.csv ├── models ├── infection │ ├── lr.001_nodes.1_count1_wd.0001_MPUKHT │ │ ├── events.out.tfevents.1555228025 │ │ ├── experiment.latest.yaml │ │ ├── model.latest.pt │ │ └── optimizer.latest.pt │ ├── lr.001_nodes.1_count1_wd0_QYNATD │ │ ├── events.out.tfevents.1555269841 │ │ ├── experiment.latest.yaml │ │ ├── model.latest.pt │ │ └── optimizer.latest.pt │ ├── max_bias │ │ ├── events.out.tfevents.1555604341 │ │ ├── experiment.latest.yaml │ │ ├── model.latest.pt │ │ └── optimizer.latest.pt │ ├── max_nobias │ │ ├── events.out.tfevents.1555604167 │ │ ├── experiment.latest.yaml │ │ ├── model.latest.pt │ │ └── optimizer.latest.pt │ ├── sum_bias │ │ ├── events.out.tfevents.1555603943 │ │ ├── experiment.latest.yaml │ │ ├── model.latest.pt │ │ └── optimizer.latest.pt │ └── sum_nobias │ │ ├── events.out.tfevents.1555603768 │ │ ├── experiment.latest.yaml │ │ ├── model.latest.pt │ │ └── optimizer.latest.pt └── solubility │ └── layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG │ ├── events.out.tfevents.1555363822 │ ├── experiment.latest.yaml │ ├── model.latest.pt │ └── optimizer.latest.pt ├── notebooks ├── Hijacking-Autograd-for-LRP-Graph.ipynb ├── Hijacking-Autograd-for-LRP.ipynb ├── Infection-Explanation-BigGraph.ipynb ├── Infection-LRP-SimpleExamples.ipynb ├── LRP-index-scatter.ipynb ├── LRP-linear-ReLU.ipynb ├── NxGraphs.ipynb ├── Solubility-GraphFeatures.ipynb ├── Solubility-LRP.ipynb └── biggraph.pt ├── resources ├── sucrose-atoms.png ├── sucrose-bonds.png └── sucrose.png ├── scripts.sh ├── setup.py └── src ├── __init__.py ├── config.py ├── count_nodes ├── __init__.py ├── dataset.py ├── layout.py ├── networks.py ├── notes.md └── train.py ├── guidedbackprop ├── __init__.py ├── autograd_tricks.py └── graphs.py ├── infection ├── __init__.py ├── dataset.py ├── layout.py ├── networks.py ├── notes.md ├── predict.py └── train.py ├── relevance ├── __init__.py ├── autograd_tricks.py ├── graphs.py ├── oldies.py └── patch.py ├── relevance_regression.py ├── saver.py ├── solubility ├── __init__.py ├── dataset.py ├── layout.py ├── networks.py ├── notes.md ├── predict.py └── train.py ├── test_edge_linear.py ├── test_global_linear.py ├── test_node_linear.py ├── utils.py └── yaml_ext.py /.gitignore: -------------------------------------------------------------------------------- 1 | notebooks/ 2 | data 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Jupyter Notebook 57 | .ipynb_checkpoints 58 | 59 | # IPython 60 | profile_default/ 61 | ipython_config.py 62 | 63 | # Environments 64 | .env 65 | .venv 66 | env/ 67 | venv/ 68 | ENV/ 69 | env.bak/ 70 | venv.bak/ 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Explainability Techniques for Graph Convolutional Networks 2 | 3 | Code and notebooks for the paper ["Explainability Techniques for Graph Convolutional Networks"](https://arxiv.org/abs/1905.13686) 4 | accepted at the ICML 2019 Workshop ["Learning and Reasoning with Graph-Structured Data"](https://graphreason.github.io/). 5 | 6 | ## Overview 7 | A Graph Network trained to predict the solubility of organic molecules is applied to _sucrose_; 8 | the prediction is explained using [Layer-wise Relevance Propagation](http://heatmapping.org), which assigns 9 | positive and negative relevance to the nodes and edges of the molecular graph: 10 | 11 | ![Sucrose solubility LRP](resources/sucrose.png) 12 | 13 | The predicted solubility can be broken down into the individual features of the atoms and their bonds: 14 | 15 | ![Sucrose solubility LRP nodes](resources/sucrose-atoms.png) 16 | ![Sucrose solubility LRP edges](resources/sucrose-bonds.png) 17 | 18 | ## Code structure 19 | - `src`, `config`, `data` contain code, configuration files and data for the experiments 20 | - `infection`, `solubility` contain the code for the two experiments in the paper 21 | - `torchgraphs` contains the core graph network library 22 | - `guidedbackprop`, `relevance` contain the code to run Guided Backpropagation and Layer-wise Relevance Propagation on top of PyTorch's `autograd` 23 | - `notebooks`, `models` contain visualizations of the datasets, the trained models and the results of our experiments 24 | - `test` contains unit tests for the `torchgraphs` module (core GN library) 25 | - `conda.yaml` contains the conda environment for the project 26 | 27 | ## Setup 28 | The project is built on top of Python 3.7, PyTorch 1.1+, 29 | [torchgraphs](https://github.com/baldassarreFe/torchgraphs) 0.0.1 and many other open source projects.
30 | 31 | A [Conda](https://conda.io) environment for the project can be installed as: 32 | ```bash 33 | conda env create -n gn-exp -f conda.yaml 34 | conda activate gn-exp 35 | python setup.py develop 36 | pytest 37 | ``` 38 | 39 | ## Training 40 | Detailed instructions for data processing, training and hyperparameter search can be found in the respective subfolders: 41 | - Infection: [infection/notes.md](./src/infection/notes.md) 42 | - Solubility: [solubility/notes.md](./src/solubility/notes.md) 43 | 44 | ## Experimental results 45 | The results of our experiments are visualized through the notebooks in [`notebooks`](./notebooks): 46 | ```bash 47 | conda activate gn-exp 48 | cd notebooks 49 | jupyter lab 50 | ``` 51 | -------------------------------------------------------------------------------- /conda.yaml: -------------------------------------------------------------------------------- 1 | name: tg-experiments 2 | channels: 3 | - rdkit 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - attrs=19.1.0=py37_1 8 | - backcall=0.1.0=py37_0 9 | - blas=1.0=mkl 10 | - bleach=3.1.0=py37_0 11 | - bzip2=1.0.6=h14c3975_5 12 | - ca-certificates=2019.1.23=0 13 | - cairo=1.14.12=h8948797_3 14 | - certifi=2019.3.9=py37_0 15 | - cffi=1.12.2=py37h2e261b9_1 16 | - cudatoolkit=10.0.130=0 17 | - cudnn=7.3.1=cuda10.0_0 18 | - cycler=0.10.0=py37_0 19 | - dbus=1.13.6=h746ee38_0 20 | - decorator=4.4.0=py37_1 21 | - defusedxml=0.5.0=py37_1 22 | - entrypoints=0.3=py37_0 23 | - expat=2.2.6=he6710b0_0 24 | - fontconfig=2.13.0=h9420a91_0 25 | - freetype=2.9.1=h8a8886c_1 26 | - glib=2.56.2=hd408876_0 27 | - gmp=6.1.2=h6c8ec71_1 28 | - gst-plugins-base=1.14.0=hbbd80ab_1 29 | - gstreamer=1.14.0=hb453b48_1 30 | - icu=58.2=h9c2bf20_1 31 | - intel-openmp=2019.3=199 32 | - ipykernel=5.1.0=py37h39e3cac_0 33 | - ipython=7.4.0=py37h39e3cac_0 34 | - ipython_genutils=0.2.0=py37_0 35 | - jedi=0.13.3=py37_0 36 | - jinja2=2.10.1=py37_0 37 | - jpeg=9b=h024ee3a_2 38 | - jsonschema=3.0.1=py37_0 39 | - jupyter_client=5.2.4=py37_0 40 | - jupyter_console=6.0.0=py37_0 41 | - jupyter_core=4.4.0=py37_0 42 | - jupyterlab=0.35.4=py37hf63ae98_0 43 | - jupyterlab_server=0.2.0=py37_0 44 | - kiwisolver=1.0.1=py37hf484d3e_0 45 | - libboost=1.67.0=h46d08c1_4 46 | - libedit=3.1.20181209=hc058e9b_0 47 | - libffi=3.2.1=hd88cf55_4 48 | - libgcc-ng=8.2.0=hdf63c60_1 49 | - libgfortran-ng=7.3.0=hdf63c60_0 50 | - libpng=1.6.36=hbc83047_0 51 | - libsodium=1.0.16=h1bed415_0 52 | - libstdcxx-ng=8.2.0=hdf63c60_1 53 | - libtiff=4.0.10=h2733197_2 54 | - libuuid=1.0.3=h1bed415_2 55 | - libxcb=1.13=h1bed415_1 56 | - libxml2=2.9.9=he19cac6_0 57 | - markupsafe=1.1.1=py37h7b6447c_0 58 | - matplotlib=3.0.3=py37h5429711_0 59 | - mistune=0.8.4=py37h7b6447c_0 60 | - mkl=2019.3=199 61 | - mkl_fft=1.0.12=py37ha843d7b_0 62 | - mkl_random=1.0.2=py37hd81dba3_0 63 | - munch=2.3.2=py37_0 64 | - nbconvert=5.4.1=py37_3 65 | - nbformat=4.4.0=py37_0 66 | - ncurses=6.1=he6710b0_1 67 | - ninja=1.9.0=py37hfd86e86_0 68 | - notebook=5.7.8=py37_0 69 | - numpy-base=1.16.4=py37hde5b4d6_0 70 | - olefile=0.46=py37_0 71 | - openssl=1.1.1c=h7b6447c_1 72 | - pandas=0.24.2=py37he6710b0_0 73 | - pandoc=2.2.3.2=0 74 | - pandocfilters=1.4.2=py37_1 75 | - parso=0.3.4=py37_0 76 | - pcre=8.43=he6710b0_0 77 | - pexpect=4.6.0=py37_0 78 | - pickleshare=0.7.5=py37_0 79 | - pillow=5.4.1=py37h34e0f95_0 80 | - pip=19.0.3=py37_0 81 | - pixman=0.38.0=h7b6447c_0 82 | - prometheus_client=0.6.0=py37_0 83 | - prompt_toolkit=2.0.9=py37_0 84 | - ptyprocess=0.6.0=py37_0 85 | - py-boost=1.67.0=py37h04863e7_4 86 | 
- pyaml=18.11.0=py37_0 87 | - pycparser=2.19=py37_0 88 | - pygments=2.3.1=py37_0 89 | - pyparsing=2.4.0=py_0 90 | - pyqt=5.9.2=py37h05f1152_2 91 | - pyrsistent=0.14.11=py37h7b6447c_0 92 | - python=3.7.3=h0371630_0 93 | - python-dateutil=2.8.0=py37_0 94 | - pytorch=1.1.0=py3.7_cuda10.0.130_cudnn7.5.1_0 95 | - pytz=2018.9=py37_0 96 | - pyyaml=5.1=py37h7b6447c_0 97 | - pyzmq=18.0.0=py37he6710b0_0 98 | - qt=5.9.7=h5867ecd_1 99 | - rdkit=2019.03.1.0=py37hc20afe1_1 100 | - readline=7.0=h7b6447c_5 101 | - scikit-learn=0.20.3=py37hd81dba3_0 102 | - scipy=1.2.1=py37h7c811a0_0 103 | - send2trash=1.5.0=py37_0 104 | - setuptools=41.0.0=py37_0 105 | - sip=4.19.8=py37hf484d3e_0 106 | - six=1.12.0=py37_0 107 | - sqlite=3.27.2=h7b6447c_0 108 | - terminado=0.8.1=py37_1 109 | - testpath=0.4.2=py37_0 110 | - tk=8.6.8=hbc83047_0 111 | - tornado=6.0.2=py37h7b6447c_0 112 | - tqdm=4.31.1=py37_1 113 | - traitlets=4.3.2=py37_0 114 | - wcwidth=0.1.7=py37_0 115 | - webencodings=0.5.1=py37_1 116 | - wheel=0.33.1=py37_0 117 | - xz=5.2.4=h14c3975_4 118 | - yaml=0.1.7=had09818_2 119 | - zeromq=4.3.1=he6710b0_3 120 | - zlib=1.2.11=h7b6447c_3 121 | - zstd=1.3.7=h0b5b093_0 122 | - pip: 123 | - networkx==2.3 124 | - numpy==1.16.4 125 | - protobuf==3.7.1 126 | - tensorboardx==1.6 127 | - torch==1.1.0 128 | - torch-scatter==1.2.0 129 | - git+https://github.com/baldassarreFe/torchgraphs@v0.0.1#egg=torchgraphs 130 | -------------------------------------------------------------------------------- /config/count_nodes/dataset.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | edge_features_shape: 2 3 | node_features_shape: 4 4 | global_features_shape: 2 5 | informative_features: 2 6 | _train_: 7 | min_nodes: 0 8 | max_nodes: 20 9 | num_samples: 100_000 10 | _val_: 11 | min_nodes: 20 12 | max_nodes: 60 13 | num_samples: 50_000 14 | _test_: 15 | min_nodes: 0 16 | max_nodes: 60 17 | num_samples: 100_000 18 | 19 | opts: 20 | seed: _auto_ 21 | folder: ~/experiments/count-nodes -------------------------------------------------------------------------------- /config/count_nodes/full.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _class_: count_nodes.networks.FullGN 3 | in_edge_features_shape: 2 4 | in_node_features_shape: 4 5 | in_global_features_shape: 2 6 | out_edge_features_shape: 2 7 | out_node_features_shape: 1 8 | out_global_features_shape: 1 9 | -------------------------------------------------------------------------------- /config/count_nodes/minimal.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | _class_: count_nodes.networks.MinimalGN 3 | in_node_features_shape: 4 4 | out_node_features_shape: 1 5 | out_global_features_shape: 1 6 | -------------------------------------------------------------------------------- /config/count_nodes/train.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 1_000 3 | epochs: 20 4 | save_every: 0 5 | restore: False 6 | l1: 0 7 | 8 | opts: 9 | log: True 10 | folder: ~/experiments/count-nodes 11 | session: _auto_ 12 | cpus: _auto_ 13 | seed: _auto_ 14 | device: _auto_ 15 | 16 | optimizer: 17 | _class_: torch.optim.Adam 18 | -------------------------------------------------------------------------------- /config/infection/datasets.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | train: 3 | min_nodes: 10 4 | max_nodes: 30 5 
| max_percent_sick: .1 6 | max_percent_immune: .3 7 | max_percent_virtual: .3 8 | num_samples: 100_000 9 | val: 10 | min_nodes: 10 11 | max_nodes: 60 12 | max_percent_sick: .4 13 | max_percent_immune: .6 14 | max_percent_virtual: .5 15 | num_samples: 20_000 16 | 17 | folder: ~/experiments/infection/data -------------------------------------------------------------------------------- /config/infection/train.yaml: -------------------------------------------------------------------------------- 1 | name: infection 2 | 3 | model: 4 | fn: infection.networks.InfectionGN 5 | kwargs: 6 | aggregation: max 7 | bias: yes 8 | 9 | optimizer: 10 | fn: torch.optim.Adam 11 | kwargs: 12 | lr: .001 13 | 14 | session: 15 | epochs: 20 16 | batch_size: 1_000 17 | losses: 18 | nodes: 1 19 | count: 1 20 | l1: 0 21 | data: 22 | folder: ~/experiments/{name}/data 23 | log: 24 | folder: ~/experiments/{name}/runs/{tags}_{rand} 25 | when: 26 | - every batch 27 | checkpoint: 28 | folder: ~/experiments/{name}/runs/{tags}_{rand} 29 | when: 30 | #- every epoch 31 | - last epoch 32 | -------------------------------------------------------------------------------- /config/solubility/train.yaml: -------------------------------------------------------------------------------- 1 | name: solubility 2 | 3 | model: 4 | fn: solubility.networks.SolubilityGN 5 | kwargs: 6 | num_layers: 2 7 | hidden_bias: yes 8 | hidden_node: 16 9 | aggregation: mean 10 | 11 | optimizer: 12 | fn: torch.optim.Adam 13 | kwargs: 14 | lr: .001 15 | 16 | session: 17 | epochs: 20 18 | batch_size: 50 19 | losses: 20 | solubility: 1 21 | l1: 0 22 | data: 23 | path: ~/experiments/{name}/data/delaney-processed.csv 24 | train: .7 25 | val: .3 26 | log: 27 | folder: ~/experiments/{name}/runs/{tags}_{rand} 28 | when: 29 | - every batch 30 | checkpoint: 31 | folder: ~/experiments/{name}/runs/{tags}_{rand} 32 | when: 33 | - last epoch 34 | -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/events.out.tfevents.1555228025: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/events.out.tfevents.1555228025 -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/experiment.latest.yaml: -------------------------------------------------------------------------------- 1 | name: infection 2 | tags: 3 | - lr.001 4 | - nodes.1 5 | - count1 6 | - wd.0001 7 | epoch: 20 8 | samples: 2000000 9 | model: 10 | fn: infection.networks.InfectionGN 11 | args: [] 12 | kwargs: {} 13 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT/checkpoints/model.e0020.pt 14 | optimizer: 15 | fn: torch.optim.Adam 16 | args: [] 17 | kwargs: 18 | lr: 0.001 19 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT/checkpoints/optimizer.e0020.pt 20 | sessions: 21 | - epochs: 20 22 | batch_size: 1000 23 | losses: 24 | nodes: 0.1 25 | count: 1 26 | l1: 0.0001 27 | seed: 60 28 | cpus: 31 29 | device: cuda 30 | status: DONE 31 | datetime_started: 2019-04-14 07:47:39.901626 32 | datetime_completed: 2019-04-14 07:51:05.397617 33 | data: 34 | folder: /experiments/infection/data 35 | log: 36 | when: 37 | - every batch 38 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT 39 | 
checkpoint: 40 | when: 41 | - last epoch 42 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT 43 | cuda: 44 | driver: '396.37' 45 | gpus: 46 | - model: GeForce GTX 1080 Ti 47 | utilization: 18 % 48 | memory_used: 1639 MiB 49 | memory_total: 11178 MiB 50 | - model: GeForce GTX 1080 Ti 51 | utilization: 0 % 52 | memory_used: 0 MiB 53 | memory_total: 11178 MiB 54 | - model: GeForce GTX 1080 Ti 55 | utilization: 0 % 56 | memory_used: 0 MiB 57 | memory_total: 11178 MiB 58 | - model: GeForce GTX 1080 Ti 59 | utilization: 0 % 60 | memory_used: 0 MiB 61 | memory_total: 11178 MiB 62 | - model: GeForce GTX 1080 Ti 63 | utilization: 0 % 64 | memory_used: 0 MiB 65 | memory_total: 11178 MiB 66 | - model: GeForce GTX 1080 Ti 67 | utilization: 0 % 68 | memory_used: 0 MiB 69 | memory_total: 11178 MiB 70 | -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/model.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/model.latest.pt -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/optimizer.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/optimizer.latest.pt -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd0_QYNATD/events.out.tfevents.1555269841: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/events.out.tfevents.1555269841 -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd0_QYNATD/experiment.latest.yaml: -------------------------------------------------------------------------------- 1 | name: infection 2 | tags: 3 | - lr.001 4 | - nodes.1 5 | - count1 6 | - wd0 7 | epoch: 20 8 | samples: 2000000 9 | model: 10 | fn: infection.networks.InfectionGN 11 | args: [] 12 | kwargs: {} 13 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD/checkpoints/model.e0020.pt 14 | optimizer: 15 | fn: torch.optim.Adam 16 | args: [] 17 | kwargs: 18 | lr: 0.001 19 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD/checkpoints/optimizer.e0020.pt 20 | sessions: 21 | - epochs: 20 22 | batch_size: 1000 23 | losses: 24 | nodes: 0.1 25 | count: 1 26 | l1: 0 27 | seed: 76 28 | cpus: 31 29 | device: cuda 30 | status: DONE 31 | datetime_started: 2019-04-14 19:24:34.801354 32 | datetime_completed: 2019-04-14 19:28:47.920841 33 | data: 34 | folder: /experiments/infection/data 35 | log: 36 | when: 37 | - every batch 38 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD 39 | checkpoint: 40 | when: 41 | - last epoch 42 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD 43 | cuda: 44 | driver: '396.37' 45 | gpus: 46 | - model: GeForce GTX 1080 Ti 47 | utilization: 1 % 48 | memory_used: 3086 MiB 49 | 
memory_total: 11178 MiB 50 | - model: GeForce GTX 1080 Ti 51 | utilization: 0 % 52 | memory_used: 0 MiB 53 | memory_total: 11178 MiB 54 | - model: GeForce GTX 1080 Ti 55 | utilization: 0 % 56 | memory_used: 0 MiB 57 | memory_total: 11178 MiB 58 | - model: GeForce GTX 1080 Ti 59 | utilization: 0 % 60 | memory_used: 0 MiB 61 | memory_total: 11178 MiB 62 | - model: GeForce GTX 1080 Ti 63 | utilization: 0 % 64 | memory_used: 0 MiB 65 | memory_total: 11178 MiB 66 | - model: GeForce GTX 1080 Ti 67 | utilization: 0 % 68 | memory_used: 10 MiB 69 | memory_total: 11178 MiB 70 | -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd0_QYNATD/model.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/model.latest.pt -------------------------------------------------------------------------------- /models/infection/lr.001_nodes.1_count1_wd0_QYNATD/optimizer.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/optimizer.latest.pt -------------------------------------------------------------------------------- /models/infection/max_bias/events.out.tfevents.1555604341: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_bias/events.out.tfevents.1555604341 -------------------------------------------------------------------------------- /models/infection/max_bias/experiment.latest.yaml: -------------------------------------------------------------------------------- 1 | name: infection 2 | tags: [] 3 | epoch: 20 4 | samples: 2000000 5 | model: 6 | fn: infection.networks.InfectionGN 7 | args: [] 8 | kwargs: 9 | aggregation: max 10 | bias: true 11 | state_dict: /experiments/infection/runs/_YAPSEY/checkpoints/model.e0020.pt 12 | optimizer: 13 | fn: torch.optim.Adam 14 | args: [] 15 | kwargs: 16 | lr: 0.001 17 | state_dict: /experiments/infection/runs/_YAPSEY/checkpoints/optimizer.e0020.pt 18 | sessions: 19 | - epochs: 20 20 | batch_size: 1000 21 | losses: 22 | nodes: 1 23 | count: 0 24 | l1: 0.001 25 | seed: 6 26 | cpus: 11 27 | device: cuda 28 | status: DONE 29 | datetime_started: 2019-04-18 16:19:27.499119 30 | datetime_completed: 2019-04-18 16:21:42.604716 31 | data: 32 | folder: /experiments/infection/data 33 | log: 34 | when: 35 | - every batch 36 | folder: /experiments/infection/runs/_YAPSEY 37 | checkpoint: 38 | when: 39 | - last epoch 40 | folder: /experiments/infection/runs/_YAPSEY 41 | cuda: 42 | driver: '418.43' 43 | gpus: 44 | - model: GeForce GTX 1050 Ti with Max-Q Design 45 | utilization: 0 % 46 | memory_used: 10 MiB 47 | memory_total: 4042 MiB 48 | -------------------------------------------------------------------------------- /models/infection/max_bias/model.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_bias/model.latest.pt 
-------------------------------------------------------------------------------- /models/infection/max_bias/optimizer.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_bias/optimizer.latest.pt -------------------------------------------------------------------------------- /models/infection/max_nobias/events.out.tfevents.1555604167: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_nobias/events.out.tfevents.1555604167 -------------------------------------------------------------------------------- /models/infection/max_nobias/experiment.latest.yaml: -------------------------------------------------------------------------------- 1 | name: infection 2 | tags: [] 3 | epoch: 20 4 | samples: 2000000 5 | model: 6 | fn: infection.networks.InfectionGN 7 | args: [] 8 | kwargs: 9 | aggregation: max 10 | bias: false 11 | state_dict: /experiments/infection/runs/_VFJYGG/checkpoints/model.e0020.pt 12 | optimizer: 13 | fn: torch.optim.Adam 14 | args: [] 15 | kwargs: 16 | lr: 0.001 17 | state_dict: /experiments/infection/runs/_VFJYGG/checkpoints/optimizer.e0020.pt 18 | sessions: 19 | - epochs: 20 20 | batch_size: 1000 21 | losses: 22 | nodes: 1 23 | count: 0 24 | l1: 0.001 25 | seed: 44 26 | cpus: 11 27 | device: cuda 28 | status: DONE 29 | datetime_started: 2019-04-18 16:16:34.420894 30 | datetime_completed: 2019-04-18 16:18:45.621812 31 | data: 32 | folder: /experiments/infection/data 33 | log: 34 | when: 35 | - every batch 36 | folder: /experiments/infection/runs/_VFJYGG 37 | checkpoint: 38 | when: 39 | - last epoch 40 | folder: /experiments/infection/runs/_VFJYGG 41 | cuda: 42 | driver: '418.43' 43 | gpus: 44 | - model: GeForce GTX 1050 Ti with Max-Q Design 45 | utilization: 0 % 46 | memory_used: 10 MiB 47 | memory_total: 4042 MiB 48 | -------------------------------------------------------------------------------- /models/infection/max_nobias/model.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_nobias/model.latest.pt -------------------------------------------------------------------------------- /models/infection/max_nobias/optimizer.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_nobias/optimizer.latest.pt -------------------------------------------------------------------------------- /models/infection/sum_bias/events.out.tfevents.1555603943: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_bias/events.out.tfevents.1555603943 -------------------------------------------------------------------------------- /models/infection/sum_bias/experiment.latest.yaml: -------------------------------------------------------------------------------- 1 | name: infection 2 | tags: [] 3 | epoch: 20 4 | samples: 2000000 5 | model: 6 | fn: 
infection.networks.InfectionGN 7 | args: [] 8 | kwargs: 9 | aggregation: sum 10 | bias: true 11 | state_dict: /experiments/infection/runs/_EHRTVG/checkpoints/model.e0020.pt 12 | optimizer: 13 | fn: torch.optim.Adam 14 | args: [] 15 | kwargs: 16 | lr: 0.001 17 | state_dict: /experiments/infection/runs/_EHRTVG/checkpoints/optimizer.e0020.pt 18 | sessions: 19 | - epochs: 20 20 | batch_size: 1000 21 | losses: 22 | nodes: 1 23 | count: 0 24 | l1: 0.001 25 | seed: 89 26 | cpus: 11 27 | device: cuda 28 | status: DONE 29 | datetime_started: 2019-04-18 16:12:49.619048 30 | datetime_completed: 2019-04-18 16:15:03.664174 31 | data: 32 | folder: /experiments/infection/data 33 | log: 34 | when: 35 | - every batch 36 | folder: /experiments/infection/runs/_EHRTVG 37 | checkpoint: 38 | when: 39 | - last epoch 40 | folder: /experiments/infection/runs/_EHRTVG 41 | cuda: 42 | driver: '418.43' 43 | gpus: 44 | - model: GeForce GTX 1050 Ti with Max-Q Design 45 | utilization: 0 % 46 | memory_used: 10 MiB 47 | memory_total: 4042 MiB 48 | -------------------------------------------------------------------------------- /models/infection/sum_bias/model.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_bias/model.latest.pt -------------------------------------------------------------------------------- /models/infection/sum_bias/optimizer.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_bias/optimizer.latest.pt -------------------------------------------------------------------------------- /models/infection/sum_nobias/events.out.tfevents.1555603768: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_nobias/events.out.tfevents.1555603768 -------------------------------------------------------------------------------- /models/infection/sum_nobias/experiment.latest.yaml: -------------------------------------------------------------------------------- 1 | name: infection 2 | tags: [] 3 | epoch: 20 4 | samples: 2000000 5 | model: 6 | fn: infection.networks.InfectionGN 7 | args: [] 8 | kwargs: 9 | aggregation: sum 10 | bias: false 11 | state_dict: /experiments/infection/runs/_QDOVMP/checkpoints/model.e0020.pt 12 | optimizer: 13 | fn: torch.optim.Adam 14 | args: [] 15 | kwargs: 16 | lr: 0.001 17 | state_dict: /experiments/infection/runs/_QDOVMP/checkpoints/optimizer.e0020.pt 18 | sessions: 19 | - epochs: 20 20 | batch_size: 1000 21 | losses: 22 | nodes: 1 23 | count: 0 24 | l1: 0.001 25 | seed: 30 26 | cpus: 11 27 | device: cuda 28 | status: DONE 29 | datetime_started: 2019-04-18 16:09:54.821299 30 | datetime_completed: 2019-04-18 16:12:04.285842 31 | data: 32 | folder: /experiments/infection/data 33 | log: 34 | when: 35 | - every batch 36 | folder: /experiments/infection/runs/_QDOVMP 37 | checkpoint: 38 | when: 39 | - last epoch 40 | folder: /experiments/infection/runs/_QDOVMP 41 | cuda: 42 | driver: '418.43' 43 | gpus: 44 | - model: GeForce GTX 1050 Ti with Max-Q Design 45 | utilization: 0 % 46 | memory_used: 10 MiB 47 | memory_total: 4042 MiB 48 | 
-------------------------------------------------------------------------------- /models/infection/sum_nobias/model.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_nobias/model.latest.pt -------------------------------------------------------------------------------- /models/infection/sum_nobias/optimizer.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_nobias/optimizer.latest.pt -------------------------------------------------------------------------------- /models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/events.out.tfevents.1555363822: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/events.out.tfevents.1555363822 -------------------------------------------------------------------------------- /models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/experiment.latest.yaml: -------------------------------------------------------------------------------- 1 | name: solubility 2 | tags: 3 | - layers3 4 | - lr.01 5 | - biasyes 6 | - size64 7 | - wd.001 8 | - dryes 9 | - e50 10 | - sum 11 | epoch: 50 12 | samples: 39450 13 | model: 14 | fn: solubility.networks.SolubilityGN 15 | args: [] 16 | kwargs: 17 | num_layers: 3 18 | hidden_bias: true 19 | hidden_node: 64 20 | aggregation: sum 21 | dropout: 50 22 | state_dict: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/checkpoints/model.e0050.pt 23 | optimizer: 24 | fn: torch.optim.Adam 25 | args: [] 26 | kwargs: 27 | lr: 0.01 28 | state_dict: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/checkpoints/optimizer.e0050.pt 29 | sessions: 30 | - epochs: 50 31 | batch_size: 50 32 | losses: 33 | solubility: 1 34 | l1: 0.001 35 | seed: 94 36 | cpus: 31 37 | device: cuda 38 | status: DONE 39 | datetime_started: 2019-04-15 21:30:22.449869 40 | datetime_completed: 2019-04-15 21:42:14.467594 41 | data: 42 | path: /experiments/solubility/data/delaney-processed.csv 43 | train: 0.7 44 | val: 0.3 45 | log: 46 | when: 47 | - every batch 48 | folder: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG 49 | checkpoint: 50 | when: 51 | - last epoch 52 | folder: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG 53 | cuda: 54 | driver: '396.37' 55 | gpus: 56 | - model: GeForce GTX 1080 Ti 57 | utilization: 5 % 58 | memory_used: 2432 MiB 59 | memory_total: 11178 MiB 60 | - model: GeForce GTX 1080 Ti 61 | utilization: 0 % 62 | memory_used: 0 MiB 63 | memory_total: 11178 MiB 64 | - model: GeForce GTX 1080 Ti 65 | utilization: 0 % 66 | memory_used: 0 MiB 67 | memory_total: 11178 MiB 68 | - model: GeForce GTX 1080 Ti 69 | utilization: 0 % 70 | memory_used: 0 MiB 71 | memory_total: 11178 MiB 72 | - model: GeForce GTX 1080 Ti 73 | utilization: 0 % 74 | memory_used: 0 MiB 75 | memory_total: 11178 MiB 76 | - model: GeForce GTX 1080 Ti 77 | utilization: 0 % 78 | memory_used: 0 MiB 79 | memory_total: 11178 MiB 
80 | loss_sol_train: 0.5124920198672624 81 | loss_sol_val: 0.5948061519316165 82 | -------------------------------------------------------------------------------- /models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/model.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/model.latest.pt -------------------------------------------------------------------------------- /models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/optimizer.latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/optimizer.latest.pt -------------------------------------------------------------------------------- /notebooks/Solubility-GraphFeatures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Solubility" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 181, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "import numpy as np\n", 18 | "import pandas as pd\n", 19 | "from pandas.api.types import CategoricalDtype\n", 20 | "\n", 21 | "from rdkit import Chem\n", 22 | "from rdkit.Chem import AllChem\n", 23 | "\n", 24 | "import torchgraphs as tg" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 182, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "data": { 34 | "text/html": [ 35 | "
\n", 36 | "\n", 49 | "\n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | "
Compound IDESOL predicted log solubility in mols per litreMinimum DegreeMolecular WeightNumber of H-Bond DonorsNumber of RingsNumber of Rotatable BondsPolar Surface Areameasured log solubility in mols per litresmiles
0Amigdalin-0.9741457.432737202.32-0.77OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...
1Fenfuram-2.8851201.22512242.24-3.30Cc1occc1C(=O)Nc2ccccc2
2citral-2.5791152.23700417.07-2.06CC(C)=CCCC(C)=CC(=O)
3Picene-6.6182278.3540500.00-7.87c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43
4Thiophene-2.232284.1430100.00-1.33c1ccsc1
\n", 133 | "
" 134 | ], 135 | "text/plain": [ 136 | " Compound ID ESOL predicted log solubility in mols per litre \\\n", 137 | "0 Amigdalin -0.974 \n", 138 | "1 Fenfuram -2.885 \n", 139 | "2 citral -2.579 \n", 140 | "3 Picene -6.618 \n", 141 | "4 Thiophene -2.232 \n", 142 | "\n", 143 | " Minimum Degree Molecular Weight Number of H-Bond Donors Number of Rings \\\n", 144 | "0 1 457.432 7 3 \n", 145 | "1 1 201.225 1 2 \n", 146 | "2 1 152.237 0 0 \n", 147 | "3 2 278.354 0 5 \n", 148 | "4 2 84.143 0 1 \n", 149 | "\n", 150 | " Number of Rotatable Bonds Polar Surface Area \\\n", 151 | "0 7 202.32 \n", 152 | "1 2 42.24 \n", 153 | "2 4 17.07 \n", 154 | "3 0 0.00 \n", 155 | "4 0 0.00 \n", 156 | "\n", 157 | " measured log solubility in mols per litre \\\n", 158 | "0 -0.77 \n", 159 | "1 -3.30 \n", 160 | "2 -2.06 \n", 161 | "3 -7.87 \n", 162 | "4 -1.33 \n", 163 | "\n", 164 | " smiles \n", 165 | "0 OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)... \n", 166 | "1 Cc1occc1C(=O)Nc2ccccc2 \n", 167 | "2 CC(C)=CCCC(C)=CC(=O) \n", 168 | "3 c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43 \n", 169 | "4 c1ccsc1 " 170 | ] 171 | }, 172 | "execution_count": 182, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "df = pd.read_csv('../data/delaney-processed.csv')\n", 179 | "df.head()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 183, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "molecule = Chem.MolFromSmiles(df.smiles[0])" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "## Atom features" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 184, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "symbols = CategoricalDtype([\n", 205 | " 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',\n", 206 | " 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',\n", 207 | " 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', # H?\n", 208 | " 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',\n", 209 | " 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown'\n", 210 | "], ordered=True)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 185, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "text/html": [ 221 | "
\n", 222 | "\n", 235 | "\n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | "
degreehydrogensimpl_valencesymbol
index
0111O
1222C
2311C
3200O
4311C
\n", 290 | "
" 291 | ], 292 | "text/plain": [ 293 | " degree hydrogens impl_valence symbol\n", 294 | "index \n", 295 | "0 1 1 1 O\n", 296 | "1 2 2 2 C\n", 297 | "2 3 1 1 C\n", 298 | "3 2 0 0 O\n", 299 | "4 3 1 1 C" 300 | ] 301 | }, 302 | "execution_count": 185, 303 | "metadata": {}, 304 | "output_type": "execute_result" 305 | } 306 | ], 307 | "source": [ 308 | "atoms_df = []\n", 309 | "for i in range(molecule.GetNumAtoms()):\n", 310 | " atom = molecule.GetAtomWithIdx(i)\n", 311 | " atoms_df.append({\n", 312 | " 'index': i,\n", 313 | " 'symbol': atom.GetSymbol(),\n", 314 | " 'degree': atom.GetDegree(),\n", 315 | " 'hydrogens': atom.GetTotalNumHs(),\n", 316 | " 'impl_valence': atom.GetImplicitValence(),\n", 317 | " })\n", 318 | "atoms_df = pd.DataFrame.from_records(atoms_df, index='index')\n", 319 | "#atoms_df.degree.cat.set_categories([0, 1, 2, 3, 4, 5])\n", 320 | "#atoms_df.hydrogens.cat.set_categories([0, 1, 2, 3, 4])\n", 321 | "#atoms_df.impl_valence.cat.set_categories([0, 1, 2, 3, 4, 5])\n", 322 | "atoms_df.symbol = atoms_df.symbol.astype(symbols)\n", 323 | "atoms_df.head()" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 186, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "node_features = torch.tensor(pd.get_dummies(atoms_df, columns=['symbol']).values, dtype=torch.float)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "## Bond features" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 187, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "bonds = CategoricalDtype([\n", 349 | " 'SINGLE',\n", 350 | " 'DOUBLE',\n", 351 | " 'TRIPLE',\n", 352 | " 'AROMATIC'\n", 353 | "], ordered=True)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 188, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "data": { 363 | "text/html": [ 364 | "
\n", 365 | "\n", 378 | "\n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | "
conjringtype
senderreceiver
01-1.0-1.0SINGLE
10-1.0-1.0SINGLE
2-1.0-1.0SINGLE
21-1.0-1.0SINGLE
3-1.01.0SINGLE
\n", 431 | "
" 432 | ], 433 | "text/plain": [ 434 | " conj ring type\n", 435 | "sender receiver \n", 436 | "0 1 -1.0 -1.0 SINGLE\n", 437 | "1 0 -1.0 -1.0 SINGLE\n", 438 | " 2 -1.0 -1.0 SINGLE\n", 439 | "2 1 -1.0 -1.0 SINGLE\n", 440 | " 3 -1.0 1.0 SINGLE" 441 | ] 442 | }, 443 | "execution_count": 188, 444 | "metadata": {}, 445 | "output_type": "execute_result" 446 | } 447 | ], 448 | "source": [ 449 | "bonds_df = []\n", 450 | "for bond in molecule.GetBonds():\n", 451 | " bonds_df.append({\n", 452 | " 'sender': bond.GetBeginAtomIdx(),\n", 453 | " 'receiver': bond.GetEndAtomIdx(),\n", 454 | " 'type': bond.GetBondType().name,\n", 455 | " 'conj': bond.GetIsConjugated(),\n", 456 | " 'ring': bond.IsInRing()\n", 457 | " })\n", 458 | " bonds_df.append({\n", 459 | " 'receiver': bond.GetBeginAtomIdx(),\n", 460 | " 'sender': bond.GetEndAtomIdx(),\n", 461 | " 'type': bond.GetBondType().name,\n", 462 | " 'conj': bond.GetIsConjugated(),\n", 463 | " 'ring': bond.IsInRing()\n", 464 | " })\n", 465 | "bonds_df = pd.DataFrame.from_records(bonds_df, index=['sender', 'receiver'])\n", 466 | "bonds_df.conj = bonds_df.conj * 2. - 1\n", 467 | "bonds_df.ring = bonds_df.ring * 2. - 1\n", 468 | "bonds_df.type = bonds_df.type.astype(bonds)\n", 469 | "bonds_df.head()" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 189, 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [ 478 | "edge_features = torch.tensor(pd.get_dummies(bonds_df, columns=['type']).values, dtype=torch.float)\n", 479 | "senders = torch.tensor(bonds_df.index.get_level_values('sender'))\n", 480 | "receivers = torch.tensor(bonds_df.index.get_level_values('receiver'))" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 191, 486 | "metadata": {}, 487 | "outputs": [ 488 | { 489 | "data": { 490 | "text/plain": [ 491 | "Graph(n=32, e=68, n_shape=torch.Size([47]), e_shape=torch.Size([6]), g_shape=None)" 492 | ] 493 | }, 494 | "execution_count": 191, 495 | "metadata": {}, 496 | "output_type": "execute_result" 497 | } 498 | ], 499 | "source": [ 500 | "def smiles_to_graph(smiles: str) -> tg.Graph:\n", 501 | " molecule = Chem.MolFromSmiles(df.smiles[0])\n", 502 | " \n", 503 | " atoms_df = []\n", 504 | " for i in range(molecule.GetNumAtoms()):\n", 505 | " atom = molecule.GetAtomWithIdx(i)\n", 506 | " atoms_df.append({\n", 507 | " 'index': i,\n", 508 | " 'symbol': atom.GetSymbol(),\n", 509 | " 'degree': atom.GetDegree(),\n", 510 | " 'hydrogens': atom.GetTotalNumHs(),\n", 511 | " 'impl_valence': atom.GetImplicitValence(),\n", 512 | " })\n", 513 | " atoms_df = pd.DataFrame.from_records(atoms_df, index='index')\n", 514 | " atoms_df.symbol = atoms_df.symbol.astype(symbols)\n", 515 | " \n", 516 | " node_features = torch.tensor(pd.get_dummies(atoms_df, columns=['symbol']).values, dtype=torch.float)\n", 517 | " \n", 518 | " bonds_df = []\n", 519 | " for bond in molecule.GetBonds():\n", 520 | " bonds_df.append({\n", 521 | " 'sender': bond.GetBeginAtomIdx(),\n", 522 | " 'receiver': bond.GetEndAtomIdx(),\n", 523 | " 'type': bond.GetBondType().name,\n", 524 | " 'conj': bond.GetIsConjugated(),\n", 525 | " 'ring': bond.IsInRing()\n", 526 | " })\n", 527 | " bonds_df.append({\n", 528 | " 'receiver': bond.GetBeginAtomIdx(),\n", 529 | " 'sender': bond.GetEndAtomIdx(),\n", 530 | " 'type': bond.GetBondType().name,\n", 531 | " 'conj': bond.GetIsConjugated(),\n", 532 | " 'ring': bond.IsInRing()\n", 533 | " })\n", 534 | " bonds_df = pd.DataFrame.from_records(bonds_df, index=['sender', 'receiver'])\n", 535 | " bonds_df.conj = bonds_df.conj * 
2. - 1\n", 536 | " bonds_df.ring = bonds_df.ring * 2. - 1\n", 537 | " bonds_df.type = bonds_df.type.astype(bonds)\n", 538 | " \n", 539 | " edge_features = torch.tensor(pd.get_dummies(bonds_df, columns=['type']).values, dtype=torch.float)\n", 540 | " senders = torch.tensor(bonds_df.index.get_level_values('sender'))\n", 541 | " receivers = torch.tensor(bonds_df.index.get_level_values('receiver'))\n", 542 | " \n", 543 | " return tg.Graph(\n", 544 | " num_nodes=molecule.GetNumAtoms(),\n", 545 | " num_edges=molecule.GetNumBonds() * 2,\n", 546 | " node_features=node_features,\n", 547 | " edge_features=edge_features,\n", 548 | " senders=senders,\n", 549 | " receivers=receivers\n", 550 | " )\n", 551 | "\n", 552 | "smiles_to_graph(df.smiles[0])" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": 219, 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "class SolubilityDataset(torch.utils.data.Dataset):\n", 562 | " def __init__(self, path):\n", 563 | " self.df = pd.read_csv(path)\n", 564 | " self.df['molecules'] = self.df.smiles.apply(smiles_to_graph)\n", 565 | "\n", 566 | " def __len__(self):\n", 567 | " return len(self.df)\n", 568 | "\n", 569 | " def __getitem__(self, item):\n", 570 | " mol = self.df['molecules'].iloc[item]\n", 571 | " target = self.df['measured log solubility in mols per litre'].iloc[item]\n", 572 | " return mol, target\n", 573 | " \n", 574 | "sd = SolubilityDataset('../data/delaney-processed.csv')" 575 | ] 576 | } 577 | ], 578 | "metadata": { 579 | "kernelspec": { 580 | "display_name": "Python 3", 581 | "language": "python", 582 | "name": "python3" 583 | }, 584 | "language_info": { 585 | "codemirror_mode": { 586 | "name": "ipython", 587 | "version": 3 588 | }, 589 | "file_extension": ".py", 590 | "mimetype": "text/x-python", 591 | "name": "python", 592 | "nbconvert_exporter": "python", 593 | "pygments_lexer": "ipython3", 594 | "version": "3.7.1" 595 | } 596 | }, 597 | "nbformat": 4, 598 | "nbformat_minor": 2 599 | } 600 | -------------------------------------------------------------------------------- /notebooks/biggraph.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/notebooks/biggraph.pt -------------------------------------------------------------------------------- /resources/sucrose-atoms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/resources/sucrose-atoms.png -------------------------------------------------------------------------------- /resources/sucrose-bonds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/resources/sucrose-bonds.png -------------------------------------------------------------------------------- /resources/sucrose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/resources/sucrose.png -------------------------------------------------------------------------------- /scripts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Conda 
environment 4 | conda env export | sed '/prefix: .*/ d' > conda.yaml 5 | conda env create -n gn-exp -f conda.yaml 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='graph-network-explainability', 5 | version='0.0.1', 6 | packages=find_packages(where='src'), 7 | package_dir={"": "src"}, 8 | ) 9 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/__init__.py -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | from collections import Mapping, defaultdict 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import yaml 6 | import munch 7 | 8 | 9 | class Config(munch.Munch): 10 | 11 | def __setattr__(self, key, value): 12 | if isinstance(value, Mapping): 13 | value = Config.fromDict(value) 14 | super(Config, self).__setattr__(key, value) 15 | 16 | @staticmethod 17 | def build(*new_configs, **cfg_args): 18 | c = Config() 19 | for new_config in new_configs: 20 | if isinstance(new_config, Mapping): 21 | Config._update_rec(c, new_config) 22 | elif isinstance(new_config, str) and '=' in new_config: 23 | Config._update_rec(c, Config.from_dotted(new_config)) 24 | elif Path(new_config).suffix in {'.yml', '.yaml'}: 25 | Config._update_rec(c, Config.from_yaml(new_config)) 26 | elif Path(new_config).suffix == '.json': 27 | Config._update_rec(c, Config.from_json(new_config)) 28 | Config._update_rec(c, cfg_args) 29 | return c 30 | 31 | @staticmethod 32 | def from_yaml(file: Union[str, Path]): 33 | with open(file, 'r') as f: 34 | return Config.fromDict(yaml.safe_load(f)) 35 | 36 | @staticmethod 37 | def from_json(file: Union[str, Path]): 38 | import json 39 | with open(file, 'r') as f: 40 | return Config.fromDict(json.load(f)) 41 | 42 | @staticmethod 43 | def from_dotted(dotted_str: str): 44 | """Parse a string of named arguments that use dots to indicate hierarchy, e.g. 
`name=test opts.cpus=4` 45 | """ 46 | 47 | def recursively_defaultdict(): 48 | return defaultdict(recursively_defaultdict) 49 | 50 | config = recursively_defaultdict() 51 | 52 | for name_dotted, value in (pair.split('=') for pair in dotted_str.split(' ')): 53 | c = config 54 | name_head, *name_rest = name_dotted.lstrip('-').split('.') 55 | while len(name_rest) > 0: 56 | c = c[name_head] 57 | name_head, *name_rest = name_rest 58 | c[name_head] = yaml.safe_load(value) 59 | return Config.fromDict(config) 60 | 61 | @staticmethod 62 | def _update_rec(old_config, new_config): 63 | for k in new_config.keys(): 64 | if k in old_config and isinstance(old_config[k], Mapping) and isinstance(new_config[k], Mapping): 65 | Config._update_rec(old_config[k], new_config[k]) 66 | else: 67 | setattr(old_config, k, new_config[k]) 68 | -------------------------------------------------------------------------------- /src/count_nodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/count_nodes/__init__.py -------------------------------------------------------------------------------- /src/count_nodes/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | 4 | import torch 5 | from torch.utils import data 6 | 7 | import torchgraphs as tg 8 | 9 | 10 | class NodeCountDataset(data.Dataset): 11 | def __init__(self, min_nodes, max_nodes, num_samples, informative_features, 12 | edge_features_shape, node_features_shape, global_features_shape): 13 | self.num_samples = num_samples 14 | self.min_nodes = min_nodes 15 | self.max_nodes = max_nodes 16 | self.node_features_shape = node_features_shape 17 | self.edge_features_shape = edge_features_shape 18 | self.global_features_shape = global_features_shape 19 | self.informative_features = informative_features 20 | self.samples = [None] * self.num_samples 21 | 22 | def __len__(self): 23 | return self.num_samples 24 | 25 | def __getitem__(self, item): 26 | self.samples[item] = self.samples[item] if self.samples[item] is not None else self._random_sample() 27 | return self.samples[item] 28 | 29 | def _random_sample(self): 30 | num_nodes = np.random.randint(self.min_nodes, self.max_nodes) 31 | num_edges = np.random.randint(0, num_nodes * (num_nodes - 1) + 1) 32 | g_nx = nx.gnm_random_graph(num_nodes, num_edges, directed=True) 33 | g = tg.Graph.from_networkx(g_nx) 34 | g = g.evolve( 35 | node_features=torch.empty(num_nodes, self.node_features_shape).uniform_(-1, 1), 36 | edge_features=torch.empty(num_edges, self.edge_features_shape).uniform_(-1, 1), 37 | global_features=torch.empty(self.global_features_shape).uniform_(-1, 1) 38 | ) 39 | 40 | if self.informative_features > 0: 41 | feats = np.random.rand(num_nodes, self.informative_features) > .5 42 | target = np.logical_and(*feats.transpose()).sum() 43 | g.node_features[:, :self.informative_features] = torch.from_numpy(feats.astype(np.float32)) * 2 - 1 44 | else: 45 | target = g.num_nodes 46 | 47 | return g, target 48 | 49 | 50 | def create_and_save(): 51 | import random 52 | import argparse 53 | from tqdm import tqdm 54 | from pathlib import Path 55 | from config import Config 56 | 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--yaml', nargs='+', default=[]) 59 | parser.add_argument('--dry-run', action='store_true') 60 | args, rest = parser.parse_known_args() 
61 | 62 | config = Config() 63 | for y in args.yaml: 64 | config.update_from_yaml(y) 65 | if len(rest) > 0: 66 | if rest[0] != '--': 67 | rest = ' '.join(rest) 68 | print(f"Error: additional config must be separated by '--', got:\n{rest}") 69 | exit(1) 70 | config.update_from_cli(' '.join(rest[1:])) 71 | 72 | print(config.toYAML()) 73 | if args.dry_run: 74 | print('Dry run, exiting.') 75 | exit(0) 76 | del args, rest 77 | 78 | random.seed(config.opts.seed) 79 | np.random.seed(config.opts.seed) 80 | torch.random.manual_seed(config.opts.seed) 81 | 82 | folder = Path(config.opts.folder).expanduser().resolve() / 'data' 83 | folder.mkdir(parents=True, exist_ok=True) 84 | 85 | datasets_common = {k: v for k, v in config.datasets.items() if k not in {'_train_', '_val_', '_test_'}} 86 | for name in ['train', 'val', 'test']: 87 | path = folder / f'{name}.pt' 88 | dataset = NodeCountDataset( 89 | **datasets_common, 90 | **{k: v for k, v in config.datasets.get(f'_{name}_', {}).items()} 91 | ) 92 | samples = len([s for s in tqdm(dataset, desc=name.capitalize(), unit='samples', leave=True)]) 93 | torch.save(dataset, path) 94 | tqdm.write(f'{name.capitalize()}: saved {samples} samples in\t{path}') 95 | 96 | 97 | if __name__ == '__main__': 98 | create_and_save() 99 | -------------------------------------------------------------------------------- /src/count_nodes/layout.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | 6 | from tensorboard import summary as summary_lib 7 | from tensorboard.plugins.custom_scalar import layout_pb2 8 | 9 | layout_summary = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=[ 10 | layout_pb2.Category( 11 | title='Losses', 12 | chart=[ 13 | layout_pb2.Chart( 14 | title='Train', multiline=layout_pb2.MultilineChartContent(tag=['loss/train/mse', 'loss/train/l1'])), 15 | layout_pb2.Chart( 16 | title='Val', multiline=layout_pb2.MultilineChartContent(tag=['loss/train/mse', 'loss/val/mse'])), 17 | ]) 18 | ])) 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('folder', help='The log folder to place the layout in') 22 | args = parser.parse_args() 23 | 24 | folder = (Path(args.folder) / 'layout').expanduser().resolve() 25 | with tf.summary.FileWriter(folder) as writer: 26 | writer.add_summary(layout_summary) 27 | 28 | print('Layout saved to', folder) 29 | -------------------------------------------------------------------------------- /src/count_nodes/networks.py: -------------------------------------------------------------------------------- 1 | import torch_scatter 2 | 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | import torchgraphs as tg 7 | from torchgraphs.utils import segment_lengths_to_ids 8 | 9 | 10 | class FullGN(nn.Module): 11 | def __init__( 12 | self, 13 | in_edge_features_shape, in_node_features_shape, in_global_features_shape, 14 | out_edge_features_shape, out_node_features_shape, out_global_features_shape 15 | ): 16 | super().__init__() 17 | self.f_e = nn.Linear(in_edge_features_shape, out_edge_features_shape) 18 | self.f_s = nn.Linear(in_node_features_shape, out_edge_features_shape) 19 | self.f_r = nn.Linear(in_node_features_shape, out_edge_features_shape) 20 | self.f_u = nn.Linear(in_global_features_shape, out_edge_features_shape) 21 | 22 | self.g_n = nn.Linear(in_node_features_shape, out_node_features_shape) 23 | self.g_in = nn.Linear(out_edge_features_shape, out_node_features_shape) 24 | self.g_out = 
nn.Linear(out_edge_features_shape, out_node_features_shape) 25 | self.g_u = nn.Linear(in_global_features_shape, out_node_features_shape) 26 | 27 | self.h_n = nn.Linear(out_node_features_shape, out_global_features_shape) 28 | self.h_e = nn.Linear(out_edge_features_shape, out_global_features_shape) 29 | self.h_u = nn.Linear(in_global_features_shape, out_global_features_shape) 30 | 31 | def forward(self, graphs: tg.GraphBatch): 32 | edges = F.relu( 33 | self.f_e(graphs.edge_features) + 34 | self.f_s(graphs.node_features).index_select(dim=0, index=graphs.senders) + 35 | self.f_r(graphs.node_features).index_select(dim=0, index=graphs.receivers) + 36 | tg.utils.repeat_tensor(self.f_u(graphs.global_features), graphs.num_edges_by_graph) 37 | ) 38 | nodes = F.relu( 39 | self.g_n(graphs.node_features) + 40 | self.g_in(torch_scatter.scatter_add(edges, graphs.receivers, dim=0, dim_size=graphs.num_nodes)) + 41 | self.g_out(torch_scatter.scatter_add(edges, graphs.senders, dim=0, dim_size=graphs.num_nodes)) + 42 | tg.utils.repeat_tensor(self.g_u(graphs.global_features), graphs.num_nodes_by_graph) 43 | ) 44 | globals = ( 45 | self.h_e(torch_scatter.scatter_add( 46 | edges, segment_lengths_to_ids(graphs.num_edges_by_graph), dim=0, dim_size=graphs.num_graphs)) + 47 | self.h_n(torch_scatter.scatter_add( 48 | nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph), dim=0, dim_size=graphs.num_graphs)) + 49 | self.h_u(graphs.global_features) 50 | ) 51 | return graphs.evolve( 52 | edge_features=edges, 53 | node_features=nodes, 54 | global_features=globals, 55 | ) 56 | 57 | 58 | class MinimalGN(nn.Module): 59 | def __init__(self, in_node_features_shape, out_node_features_shape, out_global_features_shape): 60 | super().__init__() 61 | self.g_n = nn.Linear(in_node_features_shape, out_node_features_shape) 62 | self.h_n = nn.Linear(out_node_features_shape, out_global_features_shape) 63 | 64 | def forward(self, graphs: tg.GraphBatch): 65 | nodes = F.relu(self.g_n(graphs.node_features)) 66 | globals = self.h_n(torch_scatter.scatter_add( 67 | nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph), dim=0, dim_size=graphs.num_graphs)) 68 | return graphs.evolve( 69 | num_edges=0, 70 | edge_features=None, 71 | node_features=None, 72 | global_features=globals, 73 | senders=None, 74 | receivers=None 75 | ) 76 | -------------------------------------------------------------------------------- /src/count_nodes/notes.md: -------------------------------------------------------------------------------- 1 | # Count Nodes 2 | 3 | ## Task 4 | Count the number of nodes in the graph. It's a case of simple aggregation from nodes to global. 5 | 6 | Cases: 7 | 1. Node features are all random, just count the number of nodes 8 | 2. Node features are a combination of informative features and random features. Only some of them are considered 9 | active and therefore counted. Specifically, if values of the informative features of a node are **all** 1, 10 | that node should be counted, otherwise is skipped. 11 | 12 | ## Generalization strategy 13 | During training the network only sees _small_ graphs, i.e. graphs whose node count is below a threshold. 14 | During testing the network is presented larger graphs, with a number of nodes outside the range used for training. 
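As a concrete illustration of case 2 above (a node is counted only if *all* of its informative features are 1), the target of a graph can be computed as in the minimal sketch below; the sizes and variable names are chosen here for illustration and are not taken from the configuration files.

```python
import numpy as np

num_nodes, num_informative = 6, 2                        # illustrative sizes
feats = np.random.rand(num_nodes, num_informative) > .5  # random binary informative features
active = np.logical_and.reduce(feats, axis=1)            # a node is active only if all its informative features are 1
target = active.sum()                                    # case 2 target: number of active nodes
```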
15 | 16 | ## Weighting function 17 | Weighting is only necessary for case 2: 18 | - The number of nodes is selected at random as `n~Uniform(0, max_nodes)` 19 | - Out of `K` features, `I` are considered independently informative and set to 0 or 1 with prob 0.5 20 | - The `I` informative features are evaluated through a logical AND, i.e. all must be 1 for the node to be counted 21 | - For this reason, the target values for the dataset will follow a binomial distribution 22 | 23 | ```t~Binomial(max_nodes, 0.5^I)``` 24 | 25 | with mean `max_nodes * 0.5^I` and variance `max_nodes * 0.5^I (1 - 0.5^I)` 26 | - Graphs that have a very small or very large number of active nodes are rare and should be weighted more. 27 | - A suitable weighting function is the negative log-likelihood of the target value: 28 | 29 | ```- ln Binomial(target | max_nodes, 0.5^I)``` 30 | 31 | ## Full network 32 | The full version of the graph network operates as follows: 33 | ``` 34 | e_{s \to r}^{t+1} = ReLU [ f_e(e_{s \to r}^t) + f_s(n_s^t) + f_r(n_r^t) + f_u(u^t)] 35 | 36 | n_i^{t+1} = ReLU [ g_n(n_i^t) + g_{in}(agg_s(e_{s \to i}^{t+1})) + g_{out}(agg_r(e_{i \to r}^{t+1})) + g_u(u^t)] 37 | 38 | u_i^{t+1} = h_n(agg_i(n_i^{t+1})) + h_e(agg_{ij}(e_{i \to j}^{t+1})) + h_u(u^t) 39 | ``` 40 | 41 | using summation as an aggregation function 42 | 43 | ## Minimal network 44 | For this task, the minimal version of the network should only use: 45 | ``` 46 | n_i^{t+1} = ReLU [ g_n(n_i^t) ] 47 | 48 | u_i^{t+1} = h_n(agg_i(n_i^{t+1})) 49 | ``` 50 | 51 | ## Workflow 52 | 53 | 1. Create base folder 54 | ```bash 55 | COUNT_NODES=~/experiments/count-nodes/ 56 | mkdir -p "$COUNT_NODES/"{runs,data} 57 | ``` 58 | 2. Create dataset 59 | ```bash 60 | python -m count_nodes.dataset --yaml ../config/count_nodes/dataset.yaml 61 | ``` 62 | 3. Create tensorboard layout 63 | ```bash 64 | python -m count_nodes.layout "$COUNT_NODES/runs" 65 | ``` 66 | 4. Launch experiments 67 | ```bash 68 | python -m count_nodes.train --yaml [config files ...] -- [key=value overrides ...] 69 | 70 | conda activate tg-experiments 71 | for i in 1 2 3 4 5; do 72 | for type in full minimal; do 73 | for lr in .01 .001; do 74 | for wd in 0 .01 .001; do 75 | python -m count_nodes.train --yaml ../config/count_nodes/{train,${type}}.yaml -- \ 76 | opts.session=${type}_lr${lr}_wd${wd}_${i} \ 77 | optimizer.lr=${lr} \ 78 | training.l1=${wd} training.epochs=40 79 | done 80 | done 81 | done 82 | done 83 | ``` 84 | 5. 
Visualize 85 | ```bash 86 | tensorboard --logdir "$COUNT_NODES/runs" 87 | ``` -------------------------------------------------------------------------------- /src/count_nodes/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | import textwrap 4 | 5 | import pandas as pd 6 | from pathlib import Path 7 | 8 | import tqdm 9 | import numpy as np 10 | import torchgraphs as tg 11 | from munch import Munch 12 | from scipy import stats 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn, optim 17 | from torch.utils import data 18 | from tensorboardX import SummaryWriter 19 | 20 | from count_nodes.dataset import NodeCountDataset 21 | from saver import Saver 22 | from utils import load_class 23 | from config import Config 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--yaml', nargs='+', default=[]) 27 | parser.add_argument('--dry-run', action='store_true') 28 | args, rest = parser.parse_known_args() 29 | 30 | config = Config() 31 | for y in args.yaml: 32 | config.update_from_yaml(y) 33 | if len(rest) > 0: 34 | if rest[0] != '--': 35 | rest = ' '.join(rest) 36 | print(f"Error: additional config must be separated by '--', got:\n{rest}") 37 | exit(1) 38 | config.update_from_cli(' '.join(rest[1:])) 39 | 40 | print('Config summary:', config.toYAML(), sep='\n') 41 | if args.dry_run: 42 | print('Dry run, exiting.') 43 | exit(0) 44 | del args, rest 45 | 46 | random.seed(config.opts.seed) 47 | np.random.seed(config.opts.seed) 48 | torch.random.manual_seed(config.opts.seed) 49 | 50 | folder_base = (Path(config.opts.folder)).expanduser().resolve() 51 | folder_data = folder_base / 'data' 52 | folder_run = folder_base / 'runs' / config.opts.session 53 | saver = Saver(folder_run) 54 | logger = SummaryWriter(folder_run.as_posix()) 55 | 56 | ModelClass = load_class(config.model._class_) 57 | net: nn.Module = ModelClass(**{k: v for k, v in config.model.items() if k != '_class_'}) 58 | net.to(config.opts.device) 59 | 60 | OptimizerClass = load_class(config.optimizer._class_) 61 | optimizer: optim.Optimizer = OptimizerClass(params=net.parameters(), 62 | **{k: v for k, v in config.optimizer.items() if k != '_class_'}) 63 | 64 | if config.training.restore: 65 | train_state = saver.load(model=net, optimizer=optimizer, device=config.training.device) 66 | else: 67 | train_state = Munch(epochs=0, samples=0) 68 | 69 | if config.opts.log: 70 | with open(folder_run / 'config.yml', mode='w') as f: 71 | f.write(config.toYAML()) 72 | logger.add_text( 73 | 'Config', 74 | textwrap.indent(config.toYAML(), ' '), 75 | global_step=train_state.samples) 76 | 77 | 78 | def make_dataloader(dataset, shuffle) -> data.DataLoader: 79 | return data.DataLoader( 80 | dataset, 81 | batch_size=config.training.batch_size, 82 | collate_fn=tg.GraphBatch.collate, 83 | num_workers=config.opts.cpus, 84 | shuffle=shuffle, 85 | pin_memory='cuda' in str(config.opts.device), 86 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)) 87 | ) 88 | 89 | 90 | dataset_train: NodeCountDataset = torch.load(folder_data / 'train.pt') 91 | dataset_val: NodeCountDataset = torch.load(folder_data / 'val.pt') 92 | dataset_test: NodeCountDataset = torch.load(folder_data / 'test.pt') 93 | dataloader_train = make_dataloader(dataset_train, shuffle=True) 94 | dataloader_val = make_dataloader(dataset_val, shuffle=False) 95 | dataloader_test = make_dataloader(dataset_test, shuffle=False) 96 | 97 | if 
dataset_train.informative_features > 0: 98 | binom = stats.binom(n=dataset_train.max_nodes, p=.5 ** dataset_train.informative_features) 99 | 100 | 101 | def weight_fn(targets): 102 | return targets.new_tensor(- binom.logpmf(targets.cpu().numpy())) 103 | else: 104 | def weight_fn(targets): 105 | return torch.ones_like(targets) 106 | 107 | epoch_bar_postfix = {} 108 | epoch_start = train_state.epochs + 1 109 | epoch_end = train_state.epochs + 1 + config.training.epochs 110 | epoch_bar = tqdm.trange(epoch_start, epoch_end, desc='Training', unit='e', leave=True) 111 | for epoch in epoch_bar: 112 | # Training loop 113 | net.train() 114 | loss_mse_train = 0 115 | train_bar_postfix = {} 116 | with tqdm.tqdm(desc=f'Train {epoch}', total=dataloader_train.dataset.num_samples, unit='g') as train_bar: 117 | for graphs, targets in dataloader_train: 118 | graphs = graphs.to(config.opts.device) 119 | targets = targets.float().to(config.opts.device) 120 | weights = weight_fn(targets) 121 | 122 | loss_total = 0 123 | results = net(graphs).global_features.squeeze() 124 | losses_mse = F.mse_loss(results, targets, reduction='none') 125 | losses_mse_weighted = losses_mse * weights 126 | loss_total += losses_mse_weighted.mean() 127 | 128 | if config.training.l1 > 0: 129 | loss_l1 = sum([p.abs().sum() for p in net.parameters()]) * config.training.l1 130 | loss_total += loss_l1 131 | train_bar_postfix['L1'] = f'{loss_l1.item():.5f}' 132 | if config.opts.log: 133 | logger.add_scalar('loss/train/l1', loss_l1.item(), global_step=train_state.samples) 134 | 135 | optimizer.zero_grad() 136 | loss_total.backward() 137 | optimizer.step(closure=None) 138 | 139 | train_state.samples += graphs.num_graphs 140 | loss_mse_train += losses_mse.sum().item() 141 | 142 | train_bar.update(graphs.num_graphs) 143 | train_bar_postfix['MSE'] = f'{losses_mse.mean().item():.5f}' 144 | train_bar.set_postfix(train_bar_postfix) 145 | if config.opts.log: 146 | logger.add_scalar('loss/train/all', loss_total.item(), global_step=train_state.samples) 147 | logger.add_scalar('loss/train/mse', losses_mse.mean().item(), global_step=train_state.samples) 148 | epoch_bar_postfix['train/mse'] = f'{loss_mse_train / dataloader_train.dataset.num_samples:.5f}' 149 | epoch_bar.set_postfix(epoch_bar_postfix) 150 | 151 | # Saving 152 | train_state.epochs += 1 153 | if config.training.save_every > 0 and epoch % config.training.save_every == 0: 154 | saver.save(name=epoch, model=net, optimizer=optimizer, **train_state) 155 | 156 | # Validation loop 157 | net.eval() 158 | loss_mse_val = 0 159 | with torch.no_grad(): 160 | with tqdm.tqdm(desc=f'Val {epoch}', total=dataloader_val.dataset.num_samples, unit='g') as val_bar: 161 | for graphs, targets in dataloader_val: 162 | graphs = graphs.to(config.opts.device) 163 | targets = targets.float().to(config.opts.device) 164 | 165 | results = net(graphs).global_features.squeeze() 166 | losses_mse = F.mse_loss(results, targets, reduction='none') 167 | 168 | loss_mse_val += losses_mse.sum().item() 169 | val_bar.update(graphs.num_graphs) 170 | val_bar.set_postfix_str(f'MSE: {losses_mse.mean().item():.5f}') 171 | if config.opts.log: 172 | logger.add_scalar( 173 | 'loss/val/mse', loss_mse_val / dataloader_val.dataset.num_samples, global_step=train_state.samples) 174 | epoch_bar_postfix['val/mse'] = f'{loss_mse_val / dataloader_val.dataset.num_samples:.5f}' 175 | epoch_bar.set_postfix(epoch_bar_postfix) 176 | epoch_bar.close() 177 | 178 | net.eval() 179 | df = {k: [] for k in ['Loss', 'Nodes', 'Edges', 'Predict', 'Target']} 180 
| with torch.no_grad(): 181 | with tqdm.tqdm(desc='Test', total=dataloader_test.dataset.num_samples, leave=True, unit='g') as test_bar: 182 | for graphs, targets in dataloader_test: 183 | graphs = graphs.to(config.opts.device) 184 | targets = targets.float().to(config.opts.device) 185 | 186 | results = net(graphs).global_features.squeeze() 187 | losses_mse = F.mse_loss(results, targets, reduction='none') 188 | 189 | df['Loss'].append(losses_mse.cpu().numpy()) 190 | df['Nodes'].append(graphs.num_nodes_by_graph.cpu().numpy()) 191 | df['Edges'].append(graphs.num_edges_by_graph.cpu().numpy()) 192 | df['Target'].append(targets.int().cpu().numpy()) 193 | df['Predict'].append(results.cpu().numpy()) 194 | 195 | test_bar.update(graphs.num_graphs) 196 | test_bar.set_postfix_str(f'MSE: {losses_mse.mean().item():.5f}') 197 | df = pd.DataFrame({k: np.concatenate(v) for k, v in df.items()}).rename_axis('GraphId').reset_index() 198 | 199 | # Split the results based on whether the number of nodes was present in the training set or not 200 | df_train_test = df \ 201 | .groupby(np.where(df.Nodes < dataset_train.max_nodes, 202 | f'Train [{dataset_train.min_nodes}, {dataset_train.max_nodes - 1})', 203 | f'Test [{dataset_train.max_nodes}, {dataset_test.max_nodes - 1})')) \ 204 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'Loss': 'mean'}) \ 205 | .sort_index(ascending=False) \ 206 | .rename_axis(index='Dataset') \ 207 | .rename(str.capitalize, axis='columns', level=1) 208 | 209 | # Split the results in ranges based on the number of nodes and compute the average loss per range 210 | df_losses_by_node_range = df \ 211 | .groupby(df.Nodes // 10) \ 212 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'Loss': 'mean'}) \ 213 | .rename_axis(index='NodeRange') \ 214 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index') \ 215 | .rename(str.capitalize, axis='columns', level=1) 216 | 217 | # Split the results in ranges based on the number of nodes and compute the average loss per range 218 | df_worst_graphs_by_node_range = df \ 219 | .groupby(df.Nodes // 10) \ 220 | .apply(lambda df_gr: df_gr.nlargest(5, 'Loss').set_index('GraphId')) \ 221 | .rename_axis(index={'Nodes': 'NodeRange'}) \ 222 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0) 223 | 224 | print( 225 | df_train_test.to_string(float_format=lambda x: f'{x:.2f}'), 226 | df_losses_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'), 227 | df_worst_graphs_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'), 228 | sep='\n\n') 229 | if config.opts.log: 230 | logger.add_text( 231 | 'Generalization', 232 | textwrap.indent(df_train_test.to_string(float_format=lambda x: f'{x:.2f}'), ' '), 233 | global_step=train_state.samples) 234 | logger.add_text( 235 | 'Loss by range', 236 | textwrap.indent(df_losses_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'), ' '), 237 | global_step=train_state.samples) 238 | logger.add_text( 239 | 'Samples', 240 | textwrap.indent(df_worst_graphs_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'), ' '), 241 | global_step=train_state.samples) 242 | 243 | params = [f'{name}:\n{param.data.cpu().numpy().round(3)}' for name, param in net.named_parameters()] 244 | print('Parameters:', *params, sep='\n\n') 245 | if config.opts.log: 246 | logger.add_text('Parameters', textwrap.indent('\n\n'.join(params), ' '), global_step=train_state.samples) 247 | 248 | logger.close() 249 | 
-------------------------------------------------------------------------------- /src/guidedbackprop/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO split these into separate modules 2 | from .autograd_tricks import add, sum, \ 3 | cat, index_select, repeat_tensor, \ 4 | scatter_add, scatter_mean, scatter_max, \ 5 | linear, relu, get_aggregation 6 | from .graphs import EdgeLinearGuidedBP, NodeLinearGuidedBP, GlobalLinearGuidedBP, \ 7 | EdgeReLUGuidedBP, NodeReLUGuidedBP, GlobalReLUGuidedBP 8 | -------------------------------------------------------------------------------- /src/guidedbackprop/autograd_tricks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch_scatter 5 | 6 | 7 | def hook(grad): 8 | return grad.clamp(min=0) 9 | 10 | 11 | def add(a, b): 12 | out = torch.add(a, b) 13 | out.register_hook(hook) 14 | return out 15 | 16 | 17 | def sum(tensor, dim=None, keepdim=False): 18 | out = torch.sum(tensor, dim=dim, keepdim=keepdim) 19 | out.register_hook(hook) 20 | return out 21 | 22 | 23 | def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0): 24 | out = torch_scatter.scatter_add(src, index, dim, out, dim_size, fill_value) 25 | out.register_hook(hook) 26 | return out 27 | 28 | 29 | def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0): 30 | out = torch_scatter.scatter_mean(src, index, dim, out, dim_size, fill_value) 31 | out.register_hook(hook) 32 | return out 33 | 34 | 35 | def scatter_max(src, index, dim=-1, dim_size=None, fill_value=0): 36 | out = torch_scatter.scatter_max(src, index, dim, None, dim_size, fill_value) 37 | out[0].register_hook(hook) 38 | return out 39 | 40 | 41 | def linear(input, weight, bias=None): 42 | out = torch.nn.functional.linear(input, weight, bias) 43 | out.register_hook(hook) 44 | return out 45 | 46 | 47 | def index_select(src, dim, index): 48 | out = torch.index_select(src, dim, index) 49 | out.register_hook(hook) 50 | return out 51 | 52 | 53 | def cat(tensors, dim=0): 54 | out = torch.cat(tensors, dim) 55 | out.register_hook(hook) 56 | return out 57 | 58 | 59 | def repeat_tensor(src, repeats, dim=0): 60 | idx = src.new_tensor(np.arange(len(repeats)).repeat(repeats.cpu().numpy()), dtype=torch.long) 61 | return index_select(src, dim, idx) 62 | 63 | 64 | def relu(input): 65 | out = torch.nn.functional.relu(input) 66 | out.register_hook(hook) 67 | return out 68 | 69 | 70 | def get_aggregation(name): 71 | if name in ('add', 'sum'): 72 | return scatter_add 73 | elif name in ('mean', 'avg'): 74 | return scatter_mean 75 | elif name == 'max': 76 | from functools import wraps 77 | 78 | @wraps(scatter_max) 79 | def wrapper(*args, **kwargs): 80 | return scatter_max(*args, **kwargs)[0] 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /src/guidedbackprop/graphs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchgraphs as tg 3 | 4 | from . 
import autograd_tricks as guidedbp 5 | 6 | 7 | class EdgeLinearGuidedBP(tg.EdgeLinear): 8 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch: 9 | new_edges = torch.tensor(0) 10 | 11 | if self.W_edge is not None: 12 | new_edges = guidedbp.add(new_edges, guidedbp.linear(graphs.edge_features, self.W_edge)) 13 | if self.W_sender is not None: 14 | new_edges = guidedbp.add( 15 | new_edges, 16 | guidedbp.index_select(guidedbp.linear(graphs.node_features, self.W_sender), 17 | dim=0, index=graphs.senders) 18 | ) 19 | if self.W_receiver is not None: 20 | new_edges = guidedbp.add( 21 | new_edges, 22 | guidedbp.index_select(guidedbp.linear(graphs.node_features, self.W_receiver), 23 | dim=0, index=graphs.receivers) 24 | ) 25 | if self.W_global is not None: 26 | new_edges = guidedbp.add( 27 | new_edges, 28 | guidedbp.repeat_tensor(guidedbp.linear(graphs.global_features, self.W_global), 29 | dim=0, repeats=graphs.num_edges_by_graph) 30 | ) 31 | if self.bias is not None: 32 | new_edges = guidedbp.add(new_edges, self.bias) 33 | 34 | return graphs.evolve(edge_features=new_edges) 35 | 36 | 37 | class NodeLinearGuidedBP(tg.NodeLinear): 38 | def __init__(self, out_features, node_features=None, incoming_features=None, outgoing_features=None, 39 | global_features=None, aggregation=None, bias=True): 40 | super(NodeLinearGuidedBP, self).__init__(out_features, node_features, incoming_features, 41 | outgoing_features, global_features, 42 | guidedbp.get_aggregation(aggregation), 43 | bias) 44 | 45 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch: 46 | new_nodes = torch.tensor(0) 47 | 48 | if self.W_node is not None: 49 | new_nodes = guidedbp.add( 50 | new_nodes, 51 | guidedbp.linear(graphs.node_features, self.W_node) 52 | ) 53 | if self.W_incoming is not None: 54 | new_nodes = guidedbp.add( 55 | new_nodes, 56 | guidedbp.linear( 57 | self.aggregation(graphs.edge_features, dim=0, index=graphs.receivers, dim_size=graphs.num_nodes), 58 | self.W_incoming) 59 | ) 60 | if self.W_outgoing is not None: 61 | new_nodes = guidedbp.add( 62 | new_nodes, 63 | guidedbp.linear( 64 | self.aggregation(graphs.edge_features, dim=0, index=graphs.senders, dim_size=graphs.num_nodes), 65 | self.W_outgoing) 66 | ) 67 | if self.W_global is not None: 68 | new_nodes = guidedbp.add( 69 | new_nodes, 70 | guidedbp.repeat_tensor(guidedbp.linear(graphs.global_features, self.W_global), dim=0, 71 | repeats=graphs.num_nodes_by_graph) 72 | ) 73 | if self.bias is not None: 74 | new_nodes = guidedbp.add(new_nodes, self.bias) 75 | 76 | return graphs.evolve(node_features=new_nodes) 77 | 78 | 79 | class GlobalLinearGuidedBP(tg.GlobalLinear): 80 | def __init__(self, out_features, node_features=None, edge_features=None, global_features=None, 81 | aggregation=None, bias=True): 82 | super(GlobalLinearGuidedBP, self).__init__(out_features, node_features, edge_features, 83 | global_features, guidedbp.get_aggregation(aggregation), bias) 84 | 85 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch: 86 | new_globals = torch.tensor(0) 87 | 88 | if self.W_node is not None: 89 | index = tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph) 90 | new_globals = guidedbp.add( 91 | new_globals, 92 | guidedbp.linear(self.aggregation(graphs.node_features, dim=0, index=index, dim_size=graphs.num_graphs), 93 | self.W_node) 94 | ) 95 | if self.W_edges is not None: 96 | index = tg.utils.segment_lengths_to_ids(graphs.num_edges_by_graph) 97 | new_globals = guidedbp.add( 98 | new_globals, 99 | guidedbp.linear(self.aggregation(graphs.edge_features, 
dim=0, index=index, dim_size=graphs.num_graphs), 100 | self.W_edges) 101 | ) 102 | if self.W_global is not None: 103 | new_globals = guidedbp.add( 104 | new_globals, 105 | guidedbp.linear(graphs.global_features, self.W_global) 106 | ) 107 | if self.bias is not None: 108 | new_globals = guidedbp.add(new_globals, self.bias) 109 | 110 | return graphs.evolve(global_features=new_globals) 111 | 112 | 113 | class EdgeReLUGuidedBP(tg.EdgeFunction): 114 | def __init__(self): 115 | super(EdgeReLUGuidedBP, self).__init__(guidedbp.relu) 116 | 117 | 118 | class NodeReLUGuidedBP(tg.NodeFunction): 119 | def __init__(self): 120 | super(NodeReLUGuidedBP, self).__init__(guidedbp.relu) 121 | 122 | 123 | class GlobalReLUGuidedBP(tg.GlobalFunction): 124 | def __init__(self): 125 | super(GlobalReLUGuidedBP, self).__init__(guidedbp.relu) 126 | -------------------------------------------------------------------------------- /src/infection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/infection/__init__.py -------------------------------------------------------------------------------- /src/infection/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Tuple 3 | 4 | import torch 5 | import torch.utils.data 6 | import numpy as np 7 | import networkx as nx 8 | import torchgraphs as tg 9 | 10 | 11 | class InfectionDataset(torch.utils.data.Dataset): 12 | def __init__(self, max_percent_immune, max_percent_sick, max_percent_virtual, min_nodes, max_nodes): 13 | if max_percent_sick + max_percent_immune > 1: 14 | raise ValueError(f"Cannot have a population with `max_percent_sick`={max_percent_sick}" 15 | f"and `max_percent_immune`={max_percent_immune}") 16 | self.min_nodes = min_nodes 17 | self.max_nodes = max_nodes 18 | self.max_percent_immune = max_percent_immune 19 | self.max_percent_sick = max_percent_sick 20 | self.max_percent_virtual = max_percent_virtual 21 | self.node_features_shape = 4 22 | self.edge_features_shape = 2 23 | self.samples: List[Tuple[tg.Graph, tg.Graph]] = [] 24 | 25 | def __len__(self): 26 | return len(self.samples) 27 | 28 | def __getitem__(self, item): 29 | return self.samples[item] 30 | 31 | def random_sample(self): 32 | num_nodes = np.random.randint(self.min_nodes, self.max_nodes) 33 | g_nx = nx.barabasi_albert_graph(num_nodes, 2).to_directed() 34 | 35 | # Remove some edges 36 | num_edges = int(.7 * g_nx.number_of_edges()) 37 | edges_to_remove = np.random.choice(g_nx.number_of_edges(), 38 | size=g_nx.number_of_edges() - num_edges, replace=False) 39 | edges_to_remove = [list(g_nx.edges)[i] for i in edges_to_remove] 40 | g_nx.remove_edges_from(edges_to_remove) 41 | 42 | # Create node features: sick (at least one), immune and at risk 43 | num_sick = np.random.randint(1, max(1, int(num_nodes * self.max_percent_sick)) + 1) 44 | num_immune = np.random.randint(0, int(num_nodes * self.max_percent_immune) + 1) 45 | sick, immune, atrisk = np.split(g_nx.nodes, [num_sick, num_sick + num_immune]) 46 | 47 | node_features = torch.empty(num_nodes, self.node_features_shape) 48 | node_features[:, :2] = -1 49 | node_features[sick, 0] = 1 50 | node_features[immune, 1] = 1 51 | node_features[:, 2:].uniform_(-1, 1) 52 | 53 | # Create edge features: in person and virtual 54 | num_virtual = np.random.randint(0, int(num_edges * self.max_percent_virtual) + 
1) 55 | virtual = np.random.randint(0, num_edges, size=num_virtual) 56 | edge_features = torch.empty(num_edges, self.edge_features_shape) 57 | edge_features[:, 0] = -1 58 | edge_features[virtual, 0] = 1 59 | edge_features[:, 1:].uniform_(-1, 1) 60 | 61 | g = tg.Graph.from_networkx(g_nx).evolve(node_features=node_features, edge_features=edge_features) 62 | 63 | # Create target by spreading the infection to the non-virtual neighbors who are not immune 64 | virtual = {list(g_nx.edges)[i] for i in virtual} 65 | infected = list({ 66 | infection_target for infection_src in sick for infection_target in g_nx.neighbors(infection_src) 67 | if infection_target not in immune and (infection_src, infection_target) not in virtual 68 | }) 69 | 70 | target = torch.zeros((num_nodes, 1), dtype=torch.int8) 71 | target[sick] = 1 72 | target[infected] = 1 73 | target = tg.Graph(node_features=target, global_features=target.sum(dim=0)) 74 | 75 | return g, target 76 | 77 | 78 | def generate(cfg): 79 | from tqdm import trange 80 | from utils import set_seeds 81 | 82 | cfg.setdefault('seed', 0) 83 | set_seeds(cfg.seed) 84 | print(f'Random seed: {cfg.seed}') 85 | 86 | folder = Path(cfg.folder).expanduser().resolve() 87 | folder.mkdir(parents=True, exist_ok=True) 88 | print(f'Saving datasets in: {folder}') 89 | 90 | with open(folder / 'datasets.yaml', 'w') as f: 91 | f.write(cfg.toYAML()) 92 | for p, params in cfg.datasets.items(): 93 | dataset = InfectionDataset(**{k: v for k, v in params.items() if k != 'num_samples'}) 94 | dataset.samples = [dataset.random_sample() for _ in 95 | trange(params.num_samples, desc=p.capitalize(), unit='samples', leave=True)] 96 | path = folder.joinpath(p).with_suffix('.pt') 97 | torch.save(dataset, path) 98 | print(f'{p.capitalize()}: saved {len(dataset)} samples in: {path}') 99 | 100 | 101 | def describe(cfg): 102 | import pandas as pd 103 | target = Path(cfg.target).expanduser().resolve() 104 | if target.is_dir(): 105 | paths = target.glob('*.pt') 106 | else: 107 | paths = [target] 108 | for p in paths: 109 | print(f"Loading dataset from: {p}") 110 | dataset = torch.load(p) 111 | if not isinstance(dataset, InfectionDataset): 112 | raise ValueError(f'Not an InfectionDataset: {p}') 113 | print(f"{p.with_suffix('').name.capitalize()} contains:\n" 114 | f"min_nodes: {dataset.min_nodes}\n" 115 | f"max_nodes: {dataset.max_nodes}\n" 116 | f"max_percent_immune: {dataset.max_percent_immune}\n" 117 | f"max_percent_sick: {dataset.max_percent_sick}\n" 118 | f"node_features_shape: {dataset.node_features_shape}\n" 119 | f"edge_features_shape: {dataset.edge_features_shape}\n" 120 | f"samples: {len(dataset)}") 121 | df = pd.DataFrame.from_records( 122 | { 123 | 'num_nodes': g.num_nodes, 124 | 'num_edges': g.num_edges, 125 | 'degree': g.degree.float().mean().item(), 126 | 'infected': g.node_features[:, 0].sum().item() / g.num_nodes, 127 | 'immune': g.node_features[:, 1].sum().item() / g.num_nodes, 128 | 'infected_post': t.node_features[:, 0].sum().item() / t.num_nodes, 129 | } for g, t in dataset.samples) 130 | print(f'\n{df.describe()}') 131 | 132 | 133 | def main(): 134 | from argparse import ArgumentParser 135 | from config import Config 136 | 137 | parser = ArgumentParser() 138 | subparsers = parser.add_subparsers() 139 | 140 | sp_print = subparsers.add_parser('print', help='Print parsed configuration') 141 | sp_print.add_argument('config', nargs='*') 142 | sp_print.set_defaults(command=lambda c: print(c.toYAML())) 143 | 144 | sp_generate = subparsers.add_parser('generate', help='Generate new 
datasets') 145 | sp_generate.add_argument('config', nargs='*') 146 | sp_generate.set_defaults(command=generate) 147 | 148 | sp_describe = subparsers.add_parser('describe', help='Describe existing datasets') 149 | sp_describe.add_argument('config', nargs='*') 150 | sp_describe.set_defaults(command=describe) 151 | 152 | args = parser.parse_args() 153 | cfg = Config.build(*args.config) 154 | args.command(cfg) 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /src/infection/layout.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | 6 | from tensorboard import summary as summary_lib 7 | from tensorboard.plugins.custom_scalar import layout_pb2 8 | 9 | layout_summary = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=[ 10 | layout_pb2.Category( 11 | title='losses', 12 | chart=[ 13 | # Chart 'losses' (include all losses, exclude upper and lower bounds) 14 | layout_pb2.Chart( 15 | title='losses', 16 | multiline=layout_pb2.MultilineChartContent( 17 | tag=[ 18 | r'loss(?!.*bound.*)' 19 | ] 20 | ) 21 | ), 22 | ]) 23 | ])) 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('folder', help='The log folder to place the layout in') 27 | args = parser.parse_args() 28 | 29 | folder = (Path(args.folder) / 'layout').expanduser().resolve() 30 | with tf.summary.FileWriter(folder) as writer: 31 | writer.add_summary(layout_summary) 32 | 33 | print('Layout saved to', folder) 34 | -------------------------------------------------------------------------------- /src/infection/networks.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import torchgraphs as tg 8 | 9 | 10 | class InfectionGN(nn.Module): 11 | def __init__(self, aggregation, bias): 12 | super().__init__() 13 | self.encoder = nn.Sequential(OrderedDict({ 14 | 'edge': tg.EdgeLinear(4, edge_features=2, bias=bias), 15 | 'edge_relu': tg.EdgeReLU(), 16 | 'node': tg.NodeLinear(8, node_features=4, bias=bias), 17 | 'node_relu': tg.NodeReLU(), 18 | })) 19 | self.hidden = nn.Sequential(OrderedDict({ 20 | 'edge': tg.EdgeLinear(8, edge_features=4, sender_features=8, bias=bias), 21 | 'edge_relu': tg.EdgeReLU(), 22 | 'node': tg.NodeLinear(8, node_features=8, incoming_features=8, aggregation=aggregation, bias=bias), 23 | 'node_relu': tg.NodeReLU() 24 | })) 25 | self.readout_nodes = tg.NodeLinear(1, node_features=8, bias=True) 26 | self.readout_globals = tg.GlobalLinear(1, node_features=8, aggregation='sum', bias=bias) 27 | 28 | def forward(self, graphs): 29 | graphs = self.encoder(graphs) 30 | graphs = self.hidden(graphs) 31 | nodes = self.readout_nodes(graphs).node_features 32 | globals = self.readout_globals(graphs).global_features 33 | 34 | return graphs.evolve( 35 | node_features=nodes, 36 | num_edges=0, 37 | num_edges_by_graph=None, 38 | edge_features=None, 39 | global_features=globals, 40 | senders=None, 41 | receivers=None 42 | ) 43 | 44 | 45 | def _reset_parameters(module): 46 | for name, param in module.named_parameters(): 47 | if 'bias' in name: 48 | bound = 1 / math.sqrt(param.numel()) 49 | nn.init.uniform_(param, -bound, bound) 50 | else: 51 | nn.init.kaiming_uniform_(param, a=math.sqrt(5)) 52 | 53 | 54 | def describe(cfg): 55 | from pathlib import Path 56 | from utils import import_ 
57 | klass = import_(cfg.model.klass) 58 | model = klass(*cfg.model.args, **cfg.model.kwargs) 59 | if 'state_dict' in cfg: 60 | model.load_state_dict(torch.load(Path(cfg.state_dict).expanduser().resolve())) 61 | print(model) 62 | print(f'Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}') 63 | 64 | for name, parameter in model.named_parameters(): 65 | print(f'{name} {tuple(parameter.shape)}:') 66 | if 'state_dict' in cfg: 67 | print(parameter.numpy().round()) 68 | print() 69 | 70 | 71 | def main(): 72 | from argparse import ArgumentParser 73 | from config import Config 74 | 75 | parser = ArgumentParser() 76 | subparsers = parser.add_subparsers() 77 | 78 | sp_print = subparsers.add_parser('print', help='Print parsed configuration') 79 | sp_print.add_argument('config', nargs='*') 80 | sp_print.set_defaults(command=lambda c: print(c.toYAML())) 81 | 82 | sp_describe = subparsers.add_parser('describe', help='Describe a model') 83 | sp_describe.add_argument('config', nargs='*') 84 | sp_describe.set_defaults(command=describe) 85 | 86 | args = parser.parse_args() 87 | cfg = Config.build(*args.config) 88 | args.command(cfg) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /src/infection/notes.md: -------------------------------------------------------------------------------- 1 | # Infection 2 | 3 | ## Task 4 | 1. One node of the graph is _infected_ and identified through a 1 in one element of its feature vector. 5 | The remaining features are unused. The infection is spread to the neighbors (directed) of the infected nodes. 6 | 7 | 2. One node of the graph is _infected_ and identified through a 1 in the first element of its feature vector. Some nodes 8 | are immune and identified through a 1 in the second element of their feature vectors. 9 | Again, the infection is spread to the neighbors, but some of them are immune. 10 | 11 | The network should output a prediction of 1 for nodes that are infected and 0 for the others, effectively 12 | identifying the neighbors of the infected node, the connection type and the immunity status. 13 | The network also outputs a graph-level prediction that should correspond to the total number of infected nodes. 14 | 15 | ## Data 16 | 17 | ### Training 18 | 100,000 graphs are generated using the [Barabási–Albert model](https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model), 19 | every graph contains between 10 and 30 nodes. Up to 10% of the nodes are sick and up to 30% are immune. Edges are virtual with a percentage of up to 30%. 20 | 21 | ### Testing 22 | 20,000 graphs for testing are generated in a similar way, but containing between 10 and 60 nodes. 23 | Up to 40% of the nodes are sick and up to 60% are immune. Edges are virtual with a percentage of up to 50%. 24 | 25 | ## Losses 26 | 27 | Three losses are used for training: a node-level loss, a global-level loss and a regularization term. 28 | The three terms are added together in a weighted sum and constitute the final training objective. 29 | 30 | ### Node-level classification 31 | At node-level, the network has to output a single binary prediction of whether the node will be sick or not. 32 | The loss on this prediction is computed as Binary Cross Entropy. 33 | 34 | ### Global-level regression 35 | The network should also output a global-level prediction corresponding to the total number of infected nodes. 
36 | The loss on this prediction is computed as Mean Squared Error 37 | 38 | Since the total number of infected nodes in the training set is not homogeneously distributed, we weight the losses 39 | computed on individual graphs using the negative log frequency of the true value. For example, if 15 is the 40 | ground-truth number of infected nodes after the spread for a given input graph, we weight the MSE of the network's 41 | prediction as `-ln(# of graphs with 15 infected nodes in the training set / # number of graphs in the training set)` 42 | 43 | ### L1 regularization 44 | The weights of the network are regularized with L1 regularization. 45 | 46 | ## Workflow 47 | 48 | 1. Create base folder 49 | ```bash 50 | INFECTION=~/experiments/infection/ 51 | mkdir -p "$INFECTION/"{runs,data} 52 | ``` 53 | 2. Create dataset (plus a small one for debug) 54 | ```bash 55 | python -m infection.dataset generate ../config/infection/datasets.yaml "folder=${INFECTION}/data" 56 | python -m infection.dataset generate ../config/infection/datasets.yaml \ 57 | "folder=${INFECTION}/smalldata" \ 58 | datasets.{train,val}.num_samples=5000 59 | ``` 60 | 4. Launch one experiment: 61 | ```bash 62 | python -m infection.train \ 63 | --experiment ../config/infection/train.yaml \ 64 | --model ../config/infection/minimal.yaml 65 | ``` 66 | Or make a grid search over the hyperparameters: 67 | ```bash 68 | conda activate tg-experiments 69 | function train { 70 | python -m infection.train \ 71 | --experiment ../config/infection/train.yaml \ 72 | "tags=[${1},lr${2},nodes${4},count${5},wd${3}]" \ 73 | --model "../config/infection/${1}.yaml" \ 74 | --optimizer "kwargs.lr=${2}" \ 75 | --session "losses.l1=${3}" "losses.nodes=${4}" "losses.count=${5}" 76 | } 77 | export -f train # use bash otherwise `export -f` won't work 78 | parallel --eta --max-procs 6 --load 80% --noswap 'train {1} {2} {3} {4} {5}' \ 79 | `# Architecture` ::: infectionGN \ 80 | `# Learning rate` ::: .01 .001 \ 81 | `# L1 loss` ::: 0 .0001 \ 82 | `# Infection loss` ::: 1 .1 \ 83 | `# Count loss` ::: 1 .1 84 | ``` 85 | 6. 
Visualize logs: 86 | ```bash 87 | tensorboard --logdir "$INFECTION/runs" 88 | ``` 89 | -------------------------------------------------------------------------------- /src/infection/predict.py: -------------------------------------------------------------------------------- 1 | # TODO this might be incompatible with recent code changes 2 | 3 | import tqdm 4 | import yaml 5 | import pyaml 6 | import multiprocessing 7 | from pathlib import Path 8 | from argparse import ArgumentParser 9 | 10 | import numpy as np 11 | from munch import AutoMunch, Munch 12 | 13 | import torch 14 | import torch.utils.data 15 | import torch.nn.functional as F 16 | import torchgraphs as tg 17 | 18 | from utils import parse_dotted, update_rec, import_ 19 | from .dataset import InfectionDataset 20 | 21 | parser = ArgumentParser() 22 | parser.add_argument('--model', nargs='+', required=True) 23 | parser.add_argument('--data', nargs='+', required=True, default=[]) 24 | parser.add_argument('--options', nargs='+', required=False, default=[]) 25 | parser.add_argument('--output', type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | 30 | # region Collecting phase 31 | 32 | # Defaults 33 | model = Munch(fn=None, args=[], kwargs={}, state_dict=None) 34 | data = [] 35 | options = AutoMunch() 36 | options.cpus = multiprocessing.cpu_count() - 1 37 | options.device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | options.output = args.output 39 | 40 | # Model from --model args 41 | for string in args.model: 42 | if '=' in string: 43 | update = parse_dotted(string) 44 | else: 45 | with open(string, 'r') as f: 46 | update = yaml.safe_load(f) 47 | # If the yaml file contains an entry with key `model` use that one instead 48 | if 'model' in update.keys(): 49 | update = update['model'] 50 | update_rec(model, update) 51 | 52 | # Data from --data args 53 | for path in args.data: 54 | path = Path(path).expanduser().resolve() 55 | if path.is_dir(): 56 | data.extend(path.glob('*.pt')) 57 | elif path.is_file() and path.suffix == '.pt': 58 | data.append(path) 59 | else: 60 | raise ValueError(f'Invalid data: {path}') 61 | 62 | # Options from --options args 63 | for string in args.options: 64 | if '=' in string: 65 | update = parse_dotted(string) 66 | else: 67 | with open(string, 'r') as f: 68 | update = yaml.safe_load(f) 69 | update_rec(options, update) 70 | 71 | # Resolving paths 72 | model.state_dict = Path(model.state_dict).expanduser().resolve() 73 | options.output = Path(options.output).expanduser().resolve() 74 | 75 | # Checks (some missing, others redundant) 76 | if model.fn is None: 77 | raise ValueError('Model constructor function not defined') 78 | if model.state_dict is None: 79 | raise ValueError(f'Model state dict is required to predict') 80 | if len(data) == 0: 81 | raise ValueError(f'No data to predict') 82 | if options.cpus < 0: 83 | raise ValueError(f'Invalid number of cpus: {options.cpus}') 84 | if options.output.exists() and not options.output.is_dir(): 85 | raise ValueError(f'Invalid output path {options.output}') 86 | 87 | 88 | pyaml.pprint({'model': model, 'options': options, 'data': data}, sort_dicts=False, width=200) 89 | # endregion 90 | 91 | # region Building phase 92 | # Model 93 | net: torch.nn.Module = import_(model.fn)(*model.args, **model.kwargs) 94 | net.load_state_dict(torch.load(model.state_dict)) 95 | net.to(options.device) 96 | 97 | # Output folder 98 | options.output.mkdir(parents=True, exist_ok=True) 99 | # endregion 100 | 101 | # region Training 102 | # Dataset and dataloader 
103 | dataset_predict: InfectionDataset = torch.load(data[0]) 104 | dataloader_predict = torch.utils.data.DataLoader( 105 | dataset_predict, 106 | shuffle=False, 107 | num_workers=min(options.cpus, 1) if 'cuda' in options.device else options.cpus, 108 | pin_memory='cuda' in options.device, 109 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)), 110 | batch_size=options.batch_size, 111 | collate_fn=tg.GraphBatch.collate, 112 | ) 113 | 114 | # region Predict 115 | net.eval() 116 | torch.set_grad_enabled(False) 117 | i = 0 118 | with tqdm.tqdm(desc='Predict', total=len(dataloader_predict.dataset), unit='g') as bar: 119 | for graphs, *_ in dataloader_predict: 120 | graphs = graphs.to(options.device) 121 | 122 | results = net(graphs) 123 | results.node_features.sigmoid_() 124 | 125 | for result in results: 126 | torch.save(result.cpu(), options.output / f'output_{i:06d}.pt') 127 | i += 1 128 | 129 | bar.update(graphs.num_graphs) 130 | # endregion 131 | -------------------------------------------------------------------------------- /src/infection/train.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import yaml 3 | import pyaml 4 | import random 5 | import textwrap 6 | import multiprocessing 7 | from pathlib import Path 8 | from datetime import datetime 9 | from argparse import ArgumentParser 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import sklearn.metrics 14 | from munch import AutoMunch 15 | 16 | import torch 17 | import torch.utils.data 18 | import torch.nn.functional as F 19 | import torch_scatter 20 | import torchgraphs as tg 21 | from tensorboardX import SummaryWriter 22 | 23 | from saver import Saver 24 | from utils import git_info, cuda_info, parse_dotted, update_rec, set_seeds, import_, sort_dict, RunningWeightedAverage 25 | from .dataset import InfectionDataset 26 | 27 | parser = ArgumentParser() 28 | parser.add_argument('--experiment', nargs='+', required=True) 29 | parser.add_argument('--model', nargs='+', required=False, default=[]) 30 | parser.add_argument('--optimizer', nargs='+', required=False, default=[]) 31 | parser.add_argument('--session', nargs='+', required=False, default=[]) 32 | 33 | args = parser.parse_args() 34 | 35 | 36 | # region Collecting phase 37 | class Experiment(AutoMunch): 38 | @property 39 | def session(self): 40 | return self.sessions[-1] 41 | 42 | 43 | experiment = Experiment() 44 | 45 | # Experiment defaults 46 | experiment.name = 'experiment' 47 | experiment.tags = [] 48 | experiment.samples = 0 49 | experiment.model = {'fn': None, 'args': [], 'kwargs': {}} 50 | experiment.optimizer = {'fn': None, 'args': [], 'kwargs': {}} 51 | experiment.sessions = [] 52 | 53 | # Session defaults 54 | session = AutoMunch() 55 | session.losses = {'nodes': 0, 'count': 0, 'l1': 0} 56 | session.seed = random.randint(0, 99) 57 | session.cpus = multiprocessing.cpu_count() - 1 58 | session.device = 'cuda' if torch.cuda.is_available() else 'cpu' 59 | session.log = {'when': []} 60 | session.checkpoint = {'when': []} 61 | 62 | # Experiment configuration 63 | for string in args.experiment: 64 | if '=' in string: 65 | update = parse_dotted(string) 66 | else: 67 | with open(string, 'r') as f: 68 | update = yaml.safe_load(f) 69 | # If the current session is defined inside the experiment update the session instead 70 | if 'session' in update: 71 | update_rec(session, update.pop('session')) 72 | update_rec(experiment, update) 73 | 74 | # Model from --model args 75 | for string in 
args.model: 76 | if '=' in string: 77 | update = parse_dotted(string) 78 | else: 79 | with open(string, 'r') as f: 80 | update = yaml.safe_load(f) 81 | # If the yaml document contains a single entry with key `model` use that one instead 82 | if update.keys() == {'model'}: 83 | update = update['model'] 84 | update_rec(experiment.model, update) 85 | del update 86 | 87 | # Optimizer from --optimizer args 88 | for string in args.optimizer: 89 | if '=' in string: 90 | update = parse_dotted(string) 91 | else: 92 | with open(string, 'r') as f: 93 | update = yaml.safe_load(f) 94 | # If the yaml document contains a single entry with key `optimizer` use that one instead 95 | if update.keys() == {'optimizer'}: 96 | update = update['optimizer'] 97 | update_rec(experiment.optimizer, update) 98 | del update 99 | 100 | # Session from --session args 101 | for string in args.session: 102 | if '=' in string: 103 | update = parse_dotted(string) 104 | else: 105 | with open(string, 'r') as f: 106 | update = yaml.safe_load(f) 107 | # If the yaml document contains a single entry with key `session` use that one instead 108 | if update.keys() == {'session'}: 109 | update = update['session'] 110 | update_rec(session, update) 111 | del update 112 | 113 | # Checks (some missing, others redundant) 114 | if experiment.name is None or len(experiment.name) == 0: 115 | raise ValueError(f'Experiment name is empty: {experiment.name}') 116 | if experiment.tags is None: 117 | raise ValueError('Experiment tags is None') 118 | if experiment.model.fn is None: 119 | raise ValueError('Model constructor function not defined') 120 | if experiment.optimizer.fn is None: 121 | raise ValueError('Optimizer constructor function not defined') 122 | if session.cpus < 0: 123 | raise ValueError(f'Invalid number of cpus: {session.cpus}') 124 | if any(l < 0 for l in session.losses.values()) or all(l == 0 for l in session.losses.values()): 125 | raise ValueError(f'Invalid losses: {session.losses}') 126 | if len(experiment.sessions) > 0 and ('state_dict' not in experiment.model or 'state_dict' not in experiment.optimizer): 127 | raise ValueError(f'Model and optimizer state dicts are required to restore training') 128 | 129 | # Experiment computed fields 130 | experiment.epoch = sum((s.epochs for s in experiment.sessions), 0) 131 | 132 | # Session computed fields 133 | session.status = 'NEW' 134 | session.datetime_started = None 135 | session.datetime_completed = None 136 | git = git_info() 137 | if git is not None: 138 | session.git = git 139 | if 'cuda' in session.device: 140 | session.cuda = cuda_info() 141 | 142 | # Resolving paths 143 | rand_id = ''.join(chr(random.randint(ord('A'), ord('Z'))) for _ in range(6)) 144 | session.data.folder = Path(session.data.folder.replace('{name}', experiment.name)).expanduser().resolve().as_posix() 145 | session.log.folder = session.log.folder \ 146 | .replace('{name}', experiment.name) \ 147 | .replace('{tags}', '_'.join(experiment.tags)) \ 148 | .replace('{rand}', rand_id) 149 | if len(session.checkpoint.when) > 0: 150 | if len(session.log.when) > 0: 151 | session.log.folder = Path(session.log.folder).expanduser().resolve().as_posix() 152 | session.checkpoint.folder = session.checkpoint.folder \ 153 | .replace('{name}', experiment.name) \ 154 | .replace('{tags}', '_'.join(experiment.tags)) \ 155 | .replace('{rand}', rand_id) 156 | session.checkpoint.folder = Path(session.checkpoint.folder).expanduser().resolve().as_posix() 157 | if 'state_dict' in experiment.model: 158 | experiment.model.state_dict = 
Path(experiment.model.state_dict).expanduser().resolve().as_posix() 159 | if 'state_dict' in experiment.optimizer: 160 | experiment.optimizer.state_dict = Path(experiment.optimizer.state_dict).expanduser().resolve().as_posix() 161 | 162 | sort_dict(experiment, ['name', 'tags', 'epoch', 'samples', 'model', 'optimizer', 'sessions']) 163 | sort_dict(session, ['epochs', 'batch_size', 'losses', 'seed', 'cpus', 'device', 'samples', 'status', 164 | 'datetime_started', 'datetime_completed', 'data', 'log', 'checkpoint', 'git', 'gpus']) 165 | experiment.sessions.append(session) 166 | pyaml.pprint(experiment, sort_dicts=False, width=200) 167 | del session 168 | # endregion 169 | 170 | # region Building phase 171 | # Seeds (set them after the random run id is generated) 172 | set_seeds(experiment.session.seed) 173 | 174 | # Model 175 | model: torch.nn.Module = import_(experiment.model.fn)(*experiment.model.args, **experiment.model.kwargs) 176 | if 'state_dict' in experiment.model: 177 | model.load_state_dict(torch.load(experiment.model.state_dict)) 178 | model.to(experiment.session.device) 179 | 180 | # Optimizer 181 | optimizer: torch.optim.Optimizer = import_(experiment.optimizer.fn)( 182 | model.parameters(), *experiment.optimizer.args, **experiment.optimizer.kwargs) 183 | if 'state_dict' in experiment.optimizer: 184 | optimizer.load_state_dict(torch.load(experiment.optimizer.state_dict)) 185 | 186 | # Logger 187 | if len(experiment.session.log.when) > 0: 188 | logger = SummaryWriter(experiment.session.log.folder) 189 | logger.add_text( 190 | 'Experiment', textwrap.indent(pyaml.dump(experiment, safe=True, sort_dicts=False), ' '), experiment.samples) 191 | else: 192 | logger = None 193 | 194 | # Saver 195 | if len(experiment.session.checkpoint.when) > 0: 196 | saver = Saver(experiment.session.checkpoint.folder) 197 | if experiment.epoch == 0: 198 | saver.save_experiment(experiment, suffix=f'e{experiment.epoch:04d}') 199 | else: 200 | saver = None 201 | # endregion 202 | 203 | # Datasets and dataloaders 204 | dataloader_kwargs = dict( 205 | num_workers=min(experiment.session.cpus, 1) if 'cuda' in experiment.session.device else experiment.session.cpus, 206 | pin_memory='cuda' in experiment.session.device, 207 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)), 208 | batch_size=experiment.session.batch_size, 209 | collate_fn=tg.GraphBatch.collate, 210 | ) 211 | 212 | dataset_train: InfectionDataset = torch.load(Path(experiment.session.data.folder) / 'train.pt') 213 | dataloader_train = torch.utils.data.DataLoader( 214 | dataset_train, 215 | shuffle=True, 216 | **dataloader_kwargs 217 | ) 218 | count_weights = pd.Series(t.global_features.item() for g, t in dataset_train) \ 219 | .value_counts(normalize=True) \ 220 | .apply(np.log) \ 221 | .apply(np.negative) \ 222 | .astype(np.float32) \ 223 | .sort_index() 224 | 225 | dataset_val: InfectionDataset = torch.load(Path(experiment.session.data.folder) / 'val.pt') 226 | dataloader_val = torch.utils.data.DataLoader( 227 | dataset_val, 228 | shuffle=False, 229 | **dataloader_kwargs 230 | ) 231 | del dataloader_kwargs 232 | 233 | # region Training 234 | # Train and validation loops 235 | experiment.session.status = 'RUNNING' 236 | experiment.session.datetime_started = datetime.utcnow() 237 | 238 | graphs_df = {k: [] for k in ['LossInfection', 'LossCount', 'Nodes', 'Edges', 'InfectedStart', 'InfectedEnd', 239 | 'InfectedSum', 'InfectedCount', 'MeanPercError', 'AvgPrecision', 'AreaROC']} 240 | nodes_df = {k: [] for k in 
['Targets', 'Results']} 241 | 242 | epoch_bar_postfix = {} 243 | epoch_bar = tqdm.trange(1, experiment.session.epochs + 1, desc='Epochs', unit='e', leave=True) 244 | for epoch_idx in epoch_bar: 245 | experiment.epoch += 1 246 | 247 | # region Training loop 248 | model.train() 249 | torch.set_grad_enabled(True) 250 | 251 | train_bar_postfix = {} 252 | metric_mpe_avg = RunningWeightedAverage() 253 | loss_bce_avg = RunningWeightedAverage() 254 | loss_count_avg = RunningWeightedAverage() 255 | loss_l1_avg = RunningWeightedAverage() 256 | loss_total_avg = RunningWeightedAverage() 257 | 258 | train_bar = tqdm.tqdm(desc=f'Train {experiment.epoch}', total=len(dataloader_train.dataset), unit='g') 259 | for graphs, targets in dataloader_train: 260 | graphs = graphs.to(experiment.session.device) 261 | targets = targets.to(experiment.session.device) 262 | results = model(graphs) 263 | 264 | loss_total = torch.tensor(0., device=experiment.session.device) 265 | 266 | if experiment.session.losses.nodes > 0: 267 | loss_bce = F.binary_cross_entropy_with_logits( 268 | results.node_features.squeeze(), targets.node_features.squeeze().float(), reduction='mean') 269 | loss_total += experiment.session.losses.nodes * loss_bce 270 | loss_bce_avg.add(loss_bce.item(), len(graphs)) 271 | train_bar_postfix['Infected'] = f'{loss_bce.item():.5f}' 272 | if 'every batch' in experiment.session.log.when: 273 | logger.add_scalar('loss/train/infected', loss_bce.item(), global_step=experiment.samples) 274 | 275 | if experiment.session.losses.count > 0: 276 | loss_count = F.mse_loss( 277 | results.global_features.squeeze(), targets.global_features.squeeze().float(), reduction='none') 278 | weights = loss_count.new_tensor(count_weights.loc[targets.global_features.squeeze().cpu().numpy()].values) 279 | loss_total += experiment.session.losses.count * torch.mean(loss_count * weights) 280 | loss_count_avg.add(loss_count.mean().item(), len(graphs)) 281 | train_bar_postfix['Count'] = f'{loss_count.mean().item():.5f}' 282 | 283 | metric_mpe = ((results.global_features - targets.global_features.float()).abs() / 284 | targets.global_features.float()).mean() 285 | metric_mpe_avg.add(metric_mpe,len(graphs)) 286 | 287 | if 'every batch' in experiment.session.log.when: 288 | logger.add_scalar('loss/train/count', loss_count.mean().item(), global_step=experiment.samples) 289 | logger.add_scalar('metric/train/mpe', metric_mpe.item(), global_step=experiment.samples) 290 | 291 | if experiment.session.losses.l1 > 0: 292 | loss_l1 = sum([p.abs().sum() for p in model.parameters()]) 293 | loss_total += experiment.session.losses.l1 * loss_l1 294 | loss_l1_avg.add(loss_l1.item(), len(graphs)) 295 | train_bar_postfix['L1'] = f'{loss_l1.item():.5f}' 296 | if 'every batch' in experiment.session.log.when: 297 | logger.add_scalar('loss/train/l1', loss_l1.item(), global_step=experiment.samples) 298 | 299 | loss_total_avg.add(loss_total.item(), len(graphs)) 300 | train_bar_postfix['Total'] = f'{loss_total.item():.5f}' 301 | if 'every batch' in experiment.session.log.when: 302 | logger.add_scalar('loss/train/total', loss_total.item(), global_step=experiment.samples) 303 | 304 | optimizer.zero_grad() 305 | loss_total.backward() 306 | optimizer.step(closure=None) 307 | 308 | experiment.samples += len(graphs) 309 | train_bar.update(len(graphs)) 310 | train_bar.set_postfix(train_bar_postfix) 311 | train_bar.close() 312 | 313 | epoch_bar_postfix['Train'] = f'{loss_total_avg.get():.4f}' 314 | epoch_bar.set_postfix(epoch_bar_postfix) 315 | 316 | if 'every epoch' in 
experiment.session.log.when and 'every batch' not in experiment.session.log.when: 317 | logger.add_scalar('loss/train/total', loss_total_avg.get(), global_step=experiment.samples) 318 | if experiment.session.losses.nodes > 0: 319 | logger.add_scalar('loss/train/infected', loss_bce_avg.get(), global_step=experiment.samples) 320 | if experiment.session.losses.count > 0: 321 | logger.add_scalar('loss/train/count', loss_count_avg.get(), global_step=experiment.samples) 322 | logger.add_scalar('metric/train/mpe', metric_mpe_avg.get(), global_step=experiment.samples) 323 | if experiment.session.losses.l1 > 0: 324 | logger.add_scalar('loss/train/l1', loss_l1_avg.get(), global_step=experiment.samples) 325 | 326 | del train_bar, train_bar_postfix, loss_bce_avg, loss_count_avg, loss_l1_avg, loss_total_avg, metric_mpe_avg 327 | # endregion 328 | 329 | # region Validation loop 330 | model.eval() 331 | torch.set_grad_enabled(False) 332 | 333 | val_bar_postfix = {} 334 | metric_mpe_avg = RunningWeightedAverage() 335 | loss_bce_avg = RunningWeightedAverage() 336 | loss_count_avg = RunningWeightedAverage() 337 | loss_total_avg = RunningWeightedAverage() 338 | loss_l1 = sum([p.abs().sum() for p in model.parameters()]) 339 | val_bar_postfix['L1'] = f'{loss_l1.item():.5f}' 340 | 341 | val_bar = tqdm.tqdm(desc=f'Val {experiment.epoch}', total=len(dataloader_val.dataset), unit='g') 342 | for batch_idx, (graphs, targets) in enumerate(dataloader_val): 343 | graphs = graphs.to(experiment.session.device) 344 | targets = targets.to(experiment.session.device) 345 | results = model(graphs) 346 | 347 | loss_total = torch.tensor(0., device=experiment.session.device) 348 | 349 | if experiment.session.losses.nodes > 0: 350 | loss_bce = F.binary_cross_entropy_with_logits( 351 | results.node_features.squeeze(), targets.node_features.squeeze().float(), reduction='mean') 352 | loss_total += experiment.session.losses.nodes * loss_bce 353 | loss_bce_avg.add(loss_bce.item(), len(graphs)) 354 | val_bar_postfix['Infected'] = f'{loss_bce.item():.5f}' 355 | 356 | if experiment.session.losses.count > 0: 357 | loss_count = F.mse_loss( 358 | results.global_features.squeeze(), targets.global_features.squeeze().float(), reduction='mean') 359 | loss_total += experiment.session.losses.count * loss_count 360 | loss_count_avg.add(loss_count.item(), len(graphs)) 361 | val_bar_postfix['Count'] = f'{loss_count.item():.5f}' 362 | 363 | metric_mpe_avg.add( 364 | ((results.global_features - targets.global_features.float()).abs() / 365 | targets.global_features.float()).mean(), 366 | len(graphs) 367 | ) 368 | 369 | if experiment.session.losses.l1 > 0: 370 | loss_total += experiment.session.losses.l1 * loss_l1 371 | 372 | val_bar_postfix['Total'] = f'{loss_total.item():.5f}' 373 | loss_total_avg.add(loss_total.item(), len(graphs)) 374 | 375 | # region Last epoch 376 | if epoch_idx == experiment.session.epochs: 377 | loss_bce_by_graph = torch_scatter.scatter_mean( 378 | F.binary_cross_entropy_with_logits( 379 | results.node_features.squeeze(), targets.node_features.squeeze().float(), reduction='none'), 380 | index=tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph), 381 | dim=0, 382 | dim_size=graphs.num_graphs 383 | ) 384 | loss_count_by_graph = F.mse_loss( 385 | results.global_features.squeeze(), targets.global_features.squeeze().float(), reduction='none') 386 | mpe_by_graph = ( 387 | (results.global_features - targets.global_features.float()).abs() / 388 | targets.global_features.float() 389 | ).squeeze() 390 | infected_by_graph_start_true 
= torch_scatter.scatter_add( 391 | graphs.node_features[:, 0].int().clamp(min=0), 392 | index=tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph), 393 | dim=0, 394 | dim_size=graphs.num_graphs 395 | ) 396 | infected_by_graph_sum_pred = torch_scatter.scatter_add( 397 | results.node_features.squeeze().sigmoid(), 398 | index=tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph), 399 | dim=0, 400 | dim_size=graphs.num_graphs 401 | ) 402 | avg_prec_by_graph = [] 403 | area_roc_by_graph = [] 404 | for t, r in zip(targets.node_features_by_graph, results.node_features_by_graph): 405 | avg_prec_by_graph.append(sklearn.metrics.average_precision_score( 406 | y_true=t.squeeze().cpu().int(), # numpy does not work with torch.int8 407 | y_score=r.squeeze().sigmoid().cpu()) 408 | ) 409 | try: 410 | area_roc_by_graph.append(sklearn.metrics.roc_auc_score( 411 | y_true=t.squeeze().cpu().int(), # numpy does not work with torch.int8 412 | y_score=r.squeeze().sigmoid().cpu()) 413 | ) 414 | except ValueError: 415 | # ValueError: Only one class present in y_true. ROC AUC score is not defined in that case. 416 | area_roc_by_graph.append(np.nan) 417 | 418 | nodes_df['Targets'].append(targets.node_features.squeeze().cpu().int()) # numpy doesn't convert torch.int8 419 | nodes_df['Results'].append(results.node_features.squeeze().sigmoid().cpu()) 420 | 421 | graphs_df['LossInfection'].append(loss_bce_by_graph.cpu()) 422 | graphs_df['LossCount'].append(loss_count_by_graph.cpu()) 423 | graphs_df['Nodes'].append(graphs.num_nodes_by_graph.cpu()) 424 | graphs_df['Edges'].append(graphs.num_edges_by_graph.cpu()) 425 | graphs_df['InfectedStart'].append(infected_by_graph_start_true.cpu()) 426 | graphs_df['InfectedEnd'].append(targets.global_features.squeeze().cpu()) 427 | graphs_df['InfectedSum'].append(infected_by_graph_sum_pred.cpu()) 428 | graphs_df['InfectedCount'].append(results.global_features.squeeze().cpu()) 429 | graphs_df['MeanPercError'].append(mpe_by_graph.cpu()) 430 | graphs_df['AvgPrecision'].append(np.array(avg_prec_by_graph)) 431 | graphs_df['AreaROC'].append(np.array(area_roc_by_graph)) 432 | # endregion 433 | 434 | val_bar.update(len(graphs)) 435 | val_bar.set_postfix(val_bar_postfix) 436 | val_bar.close() 437 | 438 | epoch_bar_postfix['Val'] = f'{loss_total_avg.get():.4f}' 439 | epoch_bar.set_postfix(epoch_bar_postfix) 440 | 441 | if ( 442 | 'every batch' in experiment.session.log.when or 443 | 'every epoch' in experiment.session.log.when or 444 | 'last epoch' in experiment.session.checkpoint.when and epoch_idx == experiment.session.epochs 445 | ): 446 | logger.add_scalar('loss/val/total', loss_total_avg.get(), global_step=experiment.samples) 447 | if experiment.session.losses.nodes > 0: 448 | logger.add_scalar('loss/val/infected', loss_bce_avg.get(), global_step=experiment.samples) 449 | if experiment.session.losses.count > 0: 450 | logger.add_scalar('loss/val/count', loss_count_avg.get(), global_step=experiment.samples) 451 | logger.add_scalar('metric/val/mpe', metric_mpe_avg.get(), global_step=experiment.samples) 452 | if experiment.session.losses.l1 > 0: 453 | logger.add_scalar('loss/val/l1', loss_l1.item(), global_step=experiment.samples) 454 | 455 | del val_bar, val_bar_postfix, loss_bce_avg, loss_count_avg, loss_l1, loss_total_avg, metric_mpe_avg, batch_idx 456 | # endregion 457 | 458 | # Saving 459 | if epoch_idx == experiment.session.epochs: 460 | experiment.session.status = 'DONE' 461 | experiment.session.datetime_completed = datetime.utcnow() 462 | if ( 463 | 'every batch' in 
experiment.session.checkpoint.when or 464 | 'every epoch' in experiment.session.checkpoint.when or 465 | 'last epoch' in experiment.session.checkpoint.when and epoch_idx == experiment.session.epochs 466 | ): 467 | saver.save(model, experiment, optimizer, suffix=f'e{experiment.epoch:04d}') 468 | epoch_bar.close() 469 | print() 470 | del epoch_bar, epoch_bar_postfix, epoch_idx 471 | # endregion 472 | 473 | # region Final report 474 | pd.options.display.precision = 2 475 | pd.options.display.max_columns = 999 476 | pd.options.display.expand_frame_repr = False 477 | 478 | nodes_df = pd.DataFrame({k: np.concatenate(v) for k, v in nodes_df.items()}) 479 | experiment.average_precision = sklearn.metrics.average_precision_score( 480 | y_true=nodes_df.Targets, y_score=nodes_df.Results) 481 | print('Average precision:', experiment.average_precision) 482 | if logger is not None: 483 | logger.add_scalar('metrics/val/avg_precision', experiment.average_precision, global_step=experiment.samples) 484 | logger.add_pr_curve('infection', labels=nodes_df.Targets.values, 485 | predictions=nodes_df.Results.values, global_step=experiment.samples) 486 | 487 | # noinspection PyUnreachableCode 488 | if False: 489 | import matplotlib.pyplot as plt 490 | 491 | precision, recall, _ = sklearn.metrics.precision_recall_curve( 492 | y_true=nodes_df.Targets, probas_pred=nodes_df.Results) 493 | plt.step(recall, precision, color='b', alpha=0.2, where='post') 494 | plt.fill_between(recall, precision, alpha=0.2, color='b', step='post') 495 | plt.xlabel('Recall') 496 | plt.ylabel('Precision') 497 | plt.ylim([0.0, 1.05]) 498 | plt.xlim([0.0, 1.0]) 499 | plt.title(f'Precision-Recall curve: AP={experiment.average_precision:.2f}') 500 | plt.show() 501 | del nodes_df 502 | 503 | graphs_df = pd.DataFrame({k: np.concatenate(v) for k, v in graphs_df.items()}).rename_axis('GraphId').reset_index() 504 | experiment.loss_count = graphs_df.LossCount.mean() 505 | experiment.loss_infection = graphs_df.LossInfection.mean() 506 | experiment.mpe = graphs_df.MeanPercError.mean() 507 | print('Count MSE:', experiment.loss_count) 508 | print('Infection BCE:', experiment.loss_infection) 509 | print('Count MPE:', experiment.mpe) 510 | 511 | # Split the results based on whether the number of nodes was present in the training set or not 512 | df_train_val = graphs_df \ 513 | .groupby(np.where(graphs_df.Nodes < dataset_train.max_nodes, 514 | f'Train [{dataset_train.min_nodes}, {dataset_train.max_nodes - 1})', 515 | f'Val [{dataset_train.max_nodes}, {dataset_val.max_nodes - 1})')) \ 516 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'LossInfection': 'mean', 'LossCount': 'mean'}) \ 517 | .sort_index(ascending=True) \ 518 | .rename_axis(index='Dataset') \ 519 | .rename(str.capitalize, axis='columns', level=1) 520 | 521 | # Split the results in ranges based on the number of nodes and compute the average loss per range 522 | df_losses_by_node_range = graphs_df \ 523 | .groupby(graphs_df.Nodes // 10) \ 524 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'LossInfection': 'mean', 'LossCount': 'mean'}) \ 525 | .rename_axis(index='NodeRange') \ 526 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index') \ 527 | .rename(str.capitalize, axis='columns', level=1) 528 | 529 | # Split the results in ranges based on the number of nodes and keep the N worst predictions w.r.t.
node-wise loss 530 | df_worst_infection_loss_by_node_range = graphs_df \ 531 | .groupby(graphs_df.Nodes // 10) \ 532 | .apply(lambda df_gr: df_gr.nlargest(5, 'LossInfection').set_index('GraphId')) \ 533 | .rename_axis(index={'Nodes': 'NodeRange'}) \ 534 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0) 535 | 536 | # Split the results in ranges based on the number of nodes and keep the N worst predictions w.r.t. graph-wise loss 537 | df_worst_count_loss_by_node_range = graphs_df \ 538 | .groupby(graphs_df.Nodes // 10) \ 539 | .apply(lambda df_gr: df_gr.nlargest(5, 'LossCount').set_index('GraphId')) \ 540 | .rename_axis(index={'Nodes': 'NodeRange'}) \ 541 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0) 542 | 543 | print(f""" 544 | Generalization: 545 | {df_train_val}\n 546 | Losses by range: 547 | {df_losses_by_node_range}\n 548 | Worst infection predictions: 549 | {df_worst_infection_loss_by_node_range}\n 550 | Worst count predictions: 551 | {df_worst_count_loss_by_node_range} 552 | """) 553 | 554 | if logger is not None: 555 | logger.add_text( 556 | 'Generalization', 557 | textwrap.indent(df_train_val.to_string(), ' '), 558 | global_step=experiment.samples) 559 | logger.add_text( 560 | 'Losses by range', 561 | textwrap.indent(df_losses_by_node_range.to_string(), ' '), 562 | global_step=experiment.samples) 563 | logger.add_text( 564 | 'Worst infection predictions', 565 | textwrap.indent(df_worst_infection_loss_by_node_range.to_string(), ' '), 566 | global_step=experiment.samples), 567 | logger.add_text( 568 | 'Worst count predictions', 569 | textwrap.indent(df_worst_count_loss_by_node_range.to_string(), ' '), 570 | global_step=experiment.samples) 571 | del graphs_df, df_losses_by_node_range, df_worst_infection_loss_by_node_range, df_worst_count_loss_by_node_range 572 | 573 | params = [f'{name}:\n{param.data.cpu().numpy().round(3)}' for name, param in model.named_parameters()] 574 | print('Parameters:', *params, sep='\n\n') 575 | if logger is not None: 576 | logger.add_text('Parameters', textwrap.indent('\n\n'.join(params), ' '), global_step=experiment.samples) 577 | del params 578 | # endregion 579 | 580 | # region Cleanup 581 | if logger is not None: 582 | logger.close() 583 | exit() 584 | # endregion 585 | -------------------------------------------------------------------------------- /src/relevance/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO split these into separate modules 2 | from .autograd_tricks import add, sum, \ 3 | cat, index_select, repeat_tensor, \ 4 | scatter_add, scatter_mean, scatter_max, \ 5 | linear_eps, relu, get_aggregation 6 | from .graphs import EdgeLinearRelevance, NodeLinearRelevance, GlobalLinearRelevance, \ 7 | EdgeReLURelevance, NodeReLURelevance, GlobalReLURelevance 8 | -------------------------------------------------------------------------------- /src/relevance/autograd_tricks.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | 4 | import torch 5 | import torch_scatter 6 | 7 | 8 | class AddRelevance(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, a, b): 11 | out = a + b 12 | ctx.save_for_backward(a, b, out) 13 | return out 14 | 15 | @staticmethod 16 | def backward(ctx, rel_out): 17 | a, b, out = ctx.saved_tensors 18 | if ((out == 0) & (rel_out > 0)).any(): 19 | 
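# LRP rule for addition: relevance flowing into out = a + b is redistributed to the addends
# in proportion to their contributions a/out and b/out; where out is exactly zero that ratio
# is undefined, so the corresponding relevance is dropped after the warning below.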
warnings.warn('Relevance that is propagated back through an output of 0 will be lost') 20 | rel_a = torch.where(out != 0, rel_out * a / out, out.new_tensor(0)) 21 | rel_b = torch.where(out != 0, rel_out * b / out, out.new_tensor(0)) 22 | return rel_a, rel_b 23 | 24 | 25 | def add(a, b): 26 | return AddRelevance.apply(a, b) 27 | 28 | 29 | class SumPooling(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, src, dim, keepdim): 32 | out = torch.sum(src, dim=dim, keepdim=keepdim) 33 | ctx.dim = dim 34 | ctx.keepdim = keepdim 35 | ctx.save_for_backward(src, out) 36 | return out 37 | 38 | @staticmethod 39 | def backward(ctx, rel_out): 40 | src, out = ctx.saved_tensors 41 | if ((out == 0) & (rel_out > 0)).any(): 42 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost') 43 | rel_out = torch.where(out != 0, rel_out / out, out.new_tensor(0)) 44 | if not ctx.keepdim and ctx.dim is not None: 45 | rel_out.unsqueeze_(ctx.dim) 46 | return rel_out * src, None, None 47 | 48 | 49 | def sum(tensor, dim=None, keepdim=False): 50 | return SumPooling.apply(tensor, dim, keepdim) 51 | 52 | 53 | 54 | class ScatterAddRelevance(torch.autograd.Function): 55 | @staticmethod 56 | def forward(ctx, src, out, idx, dim): 57 | torch_scatter.scatter_add(src, idx, dim=dim, out=out) 58 | ctx.dim = dim 59 | ctx.save_for_backward(src, idx, out) 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, rel_out): 64 | src, idx, out = ctx.saved_tensors 65 | if ((out == 0) & (rel_out > 0)).any(): 66 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost') 67 | rel_out = torch.where(out != 0, rel_out / out, out.new_tensor(0)) 68 | rel_src = torch.index_select(rel_out, ctx.dim, idx) * src 69 | return rel_src, None, None, None 70 | 71 | 72 | def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0): 73 | src, out, _, dim = torch_scatter.utils.gen.gen(src, index, dim, out, dim_size, fill_value) 74 | return ScatterAddRelevance.apply(src, out, index, dim) 75 | 76 | 77 | class ScatterMeanRelevance(torch.autograd.Function): 78 | @staticmethod 79 | def forward(ctx, src, idx, dim, dim_size, fill_value): 80 | sums = torch_scatter.scatter_add(src, idx, dim, None, dim_size, fill_value) 81 | count = torch_scatter.scatter_add(torch.ones_like(src), idx, dim, None, dim_size, fill_value=0) 82 | out = sums / count.clamp(min=1) 83 | ctx.dim = dim 84 | ctx.save_for_backward(src, idx, sums) 85 | return out 86 | 87 | @staticmethod 88 | def backward(ctx, rel_out): 89 | src, idx, sums = ctx.saved_tensors 90 | if ((sums == 0) & (rel_out > 0)).any(): 91 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost') 92 | rel_out = torch.where(sums != 0, rel_out / sums, sums.new_tensor(0)) 93 | rel_src = torch.index_select(rel_out, ctx.dim, idx) * src 94 | return rel_src, None, None, None, None 95 | 96 | 97 | def scatter_mean(src, index, dim=-1, dim_size=None, fill_value=0): 98 | return ScatterMeanRelevance.apply(src, index, dim, dim_size, fill_value) 99 | 100 | 101 | class ScatterMaxRelevance(torch.autograd.Function): 102 | @staticmethod 103 | def forward(ctx, src, idx, dim, dim_size, fill_value): 104 | out, idx_maxes = torch_scatter.scatter_max(src, idx, dim=dim, dim_size=dim_size, fill_value=fill_value) 105 | ctx.dim = dim 106 | ctx.dim_size = src.shape[dim] 107 | ctx.save_for_backward(idx, out, idx_maxes) 108 | return out, idx_maxes 109 | 110 | @staticmethod 111 | def backward(ctx, rel_out, rel_idx_maxes): 112 | idx, out, idx_maxes 
= ctx.saved_tensors 113 | if ((out == 0) & (rel_out > 0)).any(): 114 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost') 115 | rel_out = torch.where(out != 0, rel_out, out.new_tensor(0)) 116 | 117 | # Where idx_maxes==-1 set idx=0 so that the indexes are valid for scatter_add 118 | # The corresponding relevance should already be 0, but set it relevance=0 to be sure 119 | rel_out = torch.where(idx_maxes != -1, rel_out, torch.zeros_like(rel_out)) 120 | idx_maxes = torch.where(idx_maxes != -1, idx_maxes, torch.zeros_like(idx_maxes)) 121 | 122 | rel_src = torch_scatter.scatter_add(rel_out, idx_maxes, dim=ctx.dim, dim_size=ctx.dim_size) 123 | return rel_src, None, None, None, None 124 | 125 | 126 | def scatter_max(src, index, dim=-1, dim_size=None, fill_value=0): 127 | return ScatterMaxRelevance.apply(src, index, dim, dim_size, fill_value) 128 | 129 | 130 | class LinearEpsilonRelevance(torch.autograd.Function): 131 | eps = 1e-16 132 | 133 | @staticmethod 134 | def forward(ctx, input, weight, bias): 135 | Z = weight.t()[None, :, :] * input[:, :, None] 136 | Zs = Z.sum(dim=1, keepdim=True) 137 | if bias is not None: 138 | Zs += bias[None, None, :] 139 | ctx.save_for_backward(Z, Zs) 140 | return Zs.squeeze(dim=1) 141 | 142 | @staticmethod 143 | def backward(ctx, rel_out): 144 | Z, Zs = ctx.saved_tensors 145 | eps = rel_out.new_tensor(LinearEpsilonRelevance.eps) 146 | Zs += torch.where(Zs >= 0, eps, -eps) 147 | return (rel_out[:, None, :] * Z / Zs).sum(dim=2), None, None 148 | 149 | 150 | def linear_eps(input, weight, bias=None): 151 | return LinearEpsilonRelevance.apply(input, weight, bias) 152 | 153 | 154 | class IndexSelectRelevance(torch.autograd.Function): 155 | @staticmethod 156 | def forward(ctx, src, dim, idx): 157 | out = torch.index_select(src, dim, idx) 158 | ctx.dim = dim 159 | ctx.dim_size = src.shape[dim] 160 | ctx.save_for_backward(src, idx, out) 161 | return out 162 | 163 | @staticmethod 164 | def backward(ctx, rel_out): 165 | src, idx, out = ctx.saved_tensors 166 | return torch_scatter.scatter_add(rel_out, idx, dim=ctx.dim, dim_size=ctx.dim_size), None, None 167 | 168 | 169 | def index_select(src, dim, index): 170 | return IndexSelectRelevance.apply(src, dim, index) 171 | 172 | 173 | class CatRelevance(torch.autograd.Function): 174 | @staticmethod 175 | def forward(ctx, dim, *tensors): 176 | ctx.dim = dim 177 | ctx.sizes = [t.shape[dim] for t in tensors] 178 | return torch.cat(tensors, dim) 179 | 180 | @staticmethod 181 | def backward(ctx, rel_out): 182 | return (None, *torch.split_with_sizes(rel_out, dim=ctx.dim, split_sizes=ctx.sizes)) 183 | 184 | 185 | def cat(tensors, dim=0): 186 | return CatRelevance.apply(dim, *tensors) 187 | 188 | 189 | def repeat_tensor(src, repeats, dim=0): 190 | idx = src.new_tensor(np.arange(len(repeats)).repeat(repeats.cpu().numpy()), dtype=torch.long) 191 | return torch.index_select(src, dim, idx) 192 | 193 | 194 | class ReLuRelevance(torch.autograd.Function): 195 | @staticmethod 196 | def forward(ctx, input): 197 | out = input.clamp(min=0) 198 | ctx.save_for_backward(out) 199 | return out 200 | 201 | @staticmethod 202 | def backward(ctx, rel_out): 203 | return rel_out 204 | 205 | 206 | def relu(input): 207 | return ReLuRelevance.apply(input) 208 | 209 | 210 | def get_aggregation(name): 211 | if name in ('add', 'sum'): 212 | return scatter_add 213 | elif name in ('mean', 'avg'): 214 | return scatter_mean 215 | elif name == 'max': 216 | from functools import wraps 217 | 218 | @wraps(scatter_max) 219 | def 
wrapper(*args, **kwargs): 220 | return scatter_max(*args, **kwargs)[0] 221 | 222 | return wrapper 223 | -------------------------------------------------------------------------------- /src/relevance/graphs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchgraphs as tg 3 | 4 | from . import autograd_tricks as lrp 5 | 6 | 7 | class EdgeLinearRelevance(tg.EdgeLinear): 8 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch: 9 | new_edges = torch.tensor(0) 10 | 11 | if self.W_edge is not None: 12 | new_edges = lrp.add(new_edges, lrp.linear_eps(graphs.edge_features, self.W_edge)) 13 | if self.W_sender is not None: 14 | new_edges = lrp.add( 15 | new_edges, 16 | lrp.index_select(lrp.linear_eps(graphs.node_features, self.W_sender), 17 | dim=0, index=graphs.senders) 18 | ) 19 | if self.W_receiver is not None: 20 | new_edges = lrp.add( 21 | new_edges, 22 | lrp.index_select(lrp.linear_eps(graphs.node_features, self.W_receiver), 23 | dim=0, index=graphs.receivers) 24 | ) 25 | if self.W_global is not None: 26 | new_edges = lrp.add( 27 | new_edges, 28 | lrp.repeat_tensor(lrp.linear_eps(graphs.global_features, self.W_global), 29 | dim=0, repeats=graphs.num_edges_by_graph) 30 | ) 31 | if self.bias is not None: 32 | new_edges = lrp.add(new_edges, self.bias) 33 | 34 | return graphs.evolve(edge_features=new_edges) 35 | 36 | 37 | class NodeLinearRelevance(tg.NodeLinear): 38 | def __init__(self, out_features, node_features=None, incoming_features=None, outgoing_features=None, 39 | global_features=None, aggregation=None, bias=True): 40 | super(NodeLinearRelevance, self).__init__(out_features, node_features, incoming_features, 41 | outgoing_features, global_features, lrp.get_aggregation(aggregation), 42 | bias) 43 | 44 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch: 45 | new_nodes = torch.tensor(0) 46 | 47 | if self.W_node is not None: 48 | new_nodes = lrp.add( 49 | new_nodes, 50 | lrp.linear_eps(graphs.node_features, self.W_node) 51 | ) 52 | if self.W_incoming is not None: 53 | new_nodes = lrp.add( 54 | new_nodes, 55 | lrp.linear_eps( 56 | self.aggregation(graphs.edge_features, dim=0, index=graphs.receivers, dim_size=graphs.num_nodes), 57 | self.W_incoming) 58 | ) 59 | if self.W_outgoing is not None: 60 | new_nodes = lrp.add( 61 | new_nodes, 62 | lrp.linear_eps( 63 | self.aggregation(graphs.edge_features, dim=0, index=graphs.senders, dim_size=graphs.num_nodes), 64 | self.W_outgoing) 65 | ) 66 | if self.W_global is not None: 67 | new_nodes = lrp.add( 68 | new_nodes, 69 | lrp.repeat_tensor(lrp.linear_eps(graphs.global_features, self.W_global), dim=0, 70 | repeats=graphs.num_nodes_by_graph) 71 | ) 72 | if self.bias is not None: 73 | new_nodes = lrp.add(new_nodes, self.bias) 74 | 75 | return graphs.evolve(node_features=new_nodes) 76 | 77 | 78 | class GlobalLinearRelevance(tg.GlobalLinear): 79 | def __init__(self, out_features, node_features=None, edge_features=None, global_features=None, 80 | aggregation=None, bias=True): 81 | super(GlobalLinearRelevance, self).__init__(out_features, node_features, edge_features, 82 | global_features, lrp.get_aggregation(aggregation), bias) 83 | 84 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch: 85 | new_globals = torch.tensor(0) 86 | 87 | if self.W_node is not None: 88 | index = tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph) 89 | new_globals = lrp.add( 90 | new_globals, 91 | lrp.linear_eps(self.aggregation(graphs.node_features, dim=0, index=index, 
dim_size=graphs.num_graphs), 92 | self.W_node) 93 | ) 94 | if self.W_edges is not None: 95 | index = tg.utils.segment_lengths_to_ids(graphs.num_edges_by_graph) 96 | new_globals = lrp.add( 97 | new_globals, 98 | lrp.linear_eps(self.aggregation(graphs.edge_features, dim=0, index=index, dim_size=graphs.num_graphs), 99 | self.W_edges) 100 | ) 101 | if self.W_global is not None: 102 | new_globals = lrp.add( 103 | new_globals, 104 | lrp.linear_eps(graphs.global_features, self.W_global) 105 | ) 106 | if self.bias is not None: 107 | new_globals = lrp.add(new_globals, self.bias) 108 | 109 | return graphs.evolve(global_features=new_globals) 110 | 111 | 112 | class EdgeReLURelevance(tg.EdgeFunction): 113 | def __init__(self): 114 | super(EdgeReLURelevance, self).__init__(lrp.relu) 115 | 116 | 117 | class NodeReLURelevance(tg.NodeFunction): 118 | def __init__(self): 119 | super(NodeReLURelevance, self).__init__(lrp.relu) 120 | 121 | 122 | class GlobalReLURelevance(tg.GlobalFunction): 123 | def __init__(self): 124 | super(GlobalReLURelevance, self).__init__(lrp.relu) 125 | -------------------------------------------------------------------------------- /src/relevance/oldies.py: -------------------------------------------------------------------------------- 1 | """ 2 | class DenseW2(RelevanceFunction): 3 | @staticmethod 4 | def forward_relevance(module, inputs, ctx): 5 | output = inputs @ module.weight.t() 6 | if module.bias is not None: 7 | output += module.bias 8 | return output 9 | 10 | @staticmethod 11 | def backward_relevance(module, relevance_outputs, ctx): 12 | return relevance_outputs @ (module.weight.pow(2) / (module.weight.pow(2).sum(dim=1, keepdim=True) + 10e-6)) 13 | 14 | 15 | class DenseZPlus(RelevanceFunction): 16 | @staticmethod 17 | def forward_relevance(module, inputs, ctx): 18 | ctx['inputs'] = inputs 19 | output = inputs @ module.weight.t() 20 | if module.bias is not None: 21 | output += module.bias 22 | return output 23 | 24 | @staticmethod 25 | def backward_relevance(module, relevance_outputs, ctx): 26 | inputs = ctx['inputs'] 27 | return inputs * ( 28 | (relevance_outputs / (inputs @ module.weight.clamp(min=0).t() + 10e-6)) @ 29 | module.weight.clamp(min=0) 30 | ) 31 | """ 32 | -------------------------------------------------------------------------------- /src/relevance/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_scatter 3 | import torchgraphs as tg 4 | 5 | import textwrap 6 | 7 | from . 
import autograd_tricks as lrp 8 | 9 | 10 | 11 | def patch(): 12 | torch.add = lrp.add 13 | torch.cat = lrp.cat 14 | torch.index_select = lrp.index_select 15 | 16 | tg.utils.repeat_tensor = lrp.repeat_tensor 17 | 18 | torch_scatter.scatter_add = lrp.scatter_add 19 | torch_scatter.scatter_mean = lrp.scatter_mean 20 | torch_scatter.scatter_max = lrp.scatter_max 21 | 22 | torch.nn.functional.linear = lrp.linear_eps 23 | 24 | 25 | def computational_graph(op): 26 | if op is None: 27 | return 'None' 28 | res = f'{op.__class__.__name__} at {hex(id(op))}:' 29 | if op.__class__.__name__ == 'AccumulateGrad': 30 | res += f'variable at {hex(id(op.variable))}' 31 | for op in op.next_functions: 32 | res += '\n-' + textwrap.indent(computational_graph(op[0]), ' ') 33 | return res 34 | -------------------------------------------------------------------------------- /src/relevance_regression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | 4 | torch.manual_seed(0) 5 | 6 | x = torch.rand(10000, 5) * 2 - 1 7 | y = (3 * x[:, 0] + 8 * torch.cos(3.14 * x[:, 2]) - 5 * torch.pow(x[:, 4], 2)).view(-1, 1) 8 | 9 | ds = torch.utils.data.TensorDataset(x, y) 10 | dl = torch.utils.data.DataLoader(ds, batch_size=1000, shuffle=True, pin_memory=True, num_workers=0) 11 | 12 | model = torch.nn.Sequential( 13 | torch.nn.Linear(5, 16), torch.nn.ReLU(), 14 | torch.nn.Linear(16, 32), torch.nn.ReLU(), 15 | torch.nn.Linear(32, 16), torch.nn.ReLU(), 16 | torch.nn.Linear(16, 4), torch.nn.ReLU(), 17 | torch.nn.Linear(4, 1) 18 | ).to('cuda') 19 | 20 | model.train() 21 | opt = torch.optim.SGD(model.parameters(), lr=.01) 22 | scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.7) 23 | 24 | for e in range(600): 25 | for xs, ys in dl: 26 | xs = xs.to('cuda') 27 | ys = ys.to('cuda') 28 | 29 | out = model(xs) 30 | model.zero_grad() 31 | 32 | loss = torch.nn.functional.mse_loss(out, ys) 33 | loss.backward() 34 | 35 | opt.step() 36 | 37 | scheduler.step() 38 | if e % 50 == 0: 39 | print(e, loss.item(), opt.param_groups[0]['lr']) 40 | 41 | torch.set_grad_enabled(False) 42 | model.cpu() 43 | model.eval() 44 | 45 | print(y[:10].squeeze()) 46 | print(model(x[:10]).squeeze()) 47 | 48 | from relevance import forward_relevance, backward_relevance 49 | 50 | ctx = {} 51 | 52 | y = forward_relevance(model, x[:10], ctx=ctx) 53 | print(y.squeeze()) 54 | print() 55 | print(*ctx.items(), sep='\n\n') 56 | print() 57 | 58 | rel_out = y 59 | print(rel_out) 60 | print() 61 | rel = backward_relevance(model, rel_out, ctx=ctx) 62 | print(rel) 63 | print(rel / rel_out) 64 | print() 65 | print(rel.abs().mean(dim=0)) 66 | 67 | print() 68 | rel_out = torch.ones_like(y) 69 | print(rel_out) 70 | print() 71 | rel = backward_relevance(model, rel_out, ctx=ctx) 72 | print(rel) 73 | print() 74 | print(rel.mean(dim=0)) 75 | 76 | print(torch.allclose(rel_out.sum(dim=1), rel.sum(dim=1))) 77 | -------------------------------------------------------------------------------- /src/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union 4 | 5 | import pyaml 6 | import torch 7 | 8 | 9 | class Saver(object): 10 | def __init__(self, folder: Union[str, Path]): 11 | self.base_folder = Path(folder).expanduser().resolve() 12 | self.checkpoint_folder = self.base_folder / 'checkpoints' 13 | self.checkpoint_folder.mkdir(parents=True, exist_ok=True) 14 | 15 | def save_model(self, model, 
suffix=None, is_best=False): 16 | if isinstance(model, torch.nn.DataParallel): 17 | model = model.module 18 | name = 'model.pt' if suffix is None else f'model.{suffix}.pt' 19 | model_path = self.checkpoint_folder / name 20 | torch.save(model.state_dict(), model_path) 21 | 22 | latest_path = self.base_folder / 'model.latest.pt' 23 | if latest_path.exists(): 24 | os.unlink(latest_path) 25 | os.link(model_path, latest_path) 26 | 27 | if is_best: 28 | best_path = self.base_folder / 'model.best.pt' 29 | if best_path.exists(): 30 | os.unlink(latest_path) 31 | os.link(model_path, best_path) 32 | 33 | return model_path.as_posix() 34 | 35 | def save_optimizer(self, optimizer, suffix=None): 36 | name = 'optimizer.pt' if suffix is None else f'optimizer.{suffix}.pt' 37 | optimizer_path = self.checkpoint_folder / name 38 | torch.save(optimizer.state_dict(), optimizer_path) 39 | 40 | latest_path = self.base_folder / 'optimizer.latest.pt' 41 | if latest_path.exists(): 42 | os.unlink(latest_path) 43 | os.link(optimizer_path, latest_path) 44 | 45 | return optimizer_path.as_posix() 46 | 47 | def save_experiment(self, experiment, suffix=None): 48 | name = 'experiment.yaml' if suffix is None else f'experiment.{suffix}.yaml' 49 | experiment_path = self.checkpoint_folder / name 50 | with open(experiment_path, 'w') as f: 51 | pyaml.dump(experiment, f, safe=True, sort_dicts=False) 52 | 53 | latest_path = self.base_folder / 'experiment.latest.yaml' 54 | if latest_path.exists(): 55 | os.unlink(latest_path) 56 | os.link(experiment_path, latest_path) 57 | 58 | return experiment_path.as_posix() 59 | 60 | def save(self, model, experiment, optimizer, suffix=None, is_best=False): 61 | experiment.model.state_dict = self.save_model(model, suffix=suffix, is_best=is_best) 62 | experiment.optimizer.state_dict = self.save_optimizer(optimizer, suffix=suffix) 63 | return { 64 | 'model': experiment.model.state_dict, 65 | 'optimizer': experiment.optimizer.state_dict, 66 | 'experiment': self.save_experiment(experiment, suffix=suffix) 67 | } 68 | -------------------------------------------------------------------------------- /src/solubility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/solubility/__init__.py -------------------------------------------------------------------------------- /src/solubility/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple 3 | 4 | import torch 5 | import pandas as pd 6 | from pandas.api.types import CategoricalDtype 7 | 8 | from rdkit import Chem 9 | 10 | import torchgraphs as tg 11 | 12 | symbols = CategoricalDtype([ 13 | 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 14 | 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 15 | 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', # H? 
16 | 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 17 | 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown' 18 | ], ordered=True) 19 | 20 | bonds = CategoricalDtype([ 21 | 'SINGLE', 22 | 'DOUBLE', 23 | 'TRIPLE', 24 | 'AROMATIC' 25 | ], ordered=True) 26 | 27 | 28 | def smiles_to_graph(smiles: str) -> tg.Graph: 29 | molecule = Chem.MolFromSmiles(smiles) 30 | 31 | atoms_df = [] 32 | for i in range(molecule.GetNumAtoms()): 33 | atom = molecule.GetAtomWithIdx(i) 34 | atoms_df.append({ 35 | 'index': i, 36 | 'symbol': atom.GetSymbol(), 37 | 'degree': atom.GetDegree(), 38 | 'hydrogens': atom.GetTotalNumHs(), 39 | 'impl_valence': atom.GetImplicitValence(), 40 | }) 41 | atoms_df = pd.DataFrame.from_records(atoms_df, index='index', 42 | columns=['index', 'symbol', 'degree', 'hydrogens', 'impl_valence']) 43 | atoms_df.symbol = atoms_df.symbol.astype(symbols) 44 | 45 | node_features = torch.tensor(pd.get_dummies(atoms_df, columns=['symbol']).values, dtype=torch.float) 46 | 47 | bonds_df = [] 48 | for bond in molecule.GetBonds(): 49 | bonds_df.append({ 50 | 'sender': bond.GetBeginAtomIdx(), 51 | 'receiver': bond.GetEndAtomIdx(), 52 | 'type': bond.GetBondType().name, 53 | 'conj': bond.GetIsConjugated(), 54 | 'ring': bond.IsInRing() 55 | }) 56 | bonds_df.append({ 57 | 'sender': bond.GetEndAtomIdx(), 58 | 'receiver': bond.GetBeginAtomIdx(), 59 | 'type': bond.GetBondType().name, 60 | 'conj': bond.GetIsConjugated(), 61 | 'ring': bond.IsInRing() 62 | }) 63 | bonds_df = pd.DataFrame.from_records(bonds_df, columns=['sender', 'receiver', 'type', 'conj', 'ring'])\ 64 | .set_index(['sender', 'receiver']) 65 | bonds_df.conj = bonds_df.conj * 2. - 1 66 | bonds_df.ring = bonds_df.ring * 2. - 1 67 | bonds_df.type = bonds_df.type.astype(bonds) 68 | 69 | edge_features = torch.tensor(pd.get_dummies(bonds_df, columns=['type']).values.astype(float), dtype=torch.float) 70 | senders = torch.tensor(bonds_df.index.get_level_values('sender'), dtype=torch.long) 71 | receivers = torch.tensor(bonds_df.index.get_level_values('receiver'), dtype=torch.long) 72 | 73 | return tg.Graph( 74 | num_nodes=molecule.GetNumAtoms(), 75 | num_edges=molecule.GetNumBonds() * 2, 76 | node_features=node_features, 77 | edge_features=edge_features, 78 | senders=senders, 79 | receivers=receivers 80 | ) 81 | 82 | 83 | class SolubilityDataset(torch.utils.data.Dataset): 84 | def __init__(self, path): 85 | self.df = pd.read_csv(path) 86 | # self.df['molecules'] = self.df.smiles.apply(smiles_to_graph) 87 | 88 | def __len__(self): 89 | return len(self.df) 90 | 91 | def __getitem__(self, item) -> Tuple[tg.Graph, float]: 92 | mol = smiles_to_graph(self.df['smiles'].iloc[item]) 93 | target = self.df['measured log solubility in mols per litre'].iloc[item] 94 | return mol, torch.tensor(target) 95 | 96 | 97 | def describe(cfg): 98 | pd.options.display.precision = 2 99 | pd.options.display.max_columns = 999 100 | pd.options.display.expand_frame_repr = False 101 | target = Path(cfg.target).expanduser().resolve() 102 | if target.is_dir(): 103 | paths = target.glob('*.pt') 104 | else: 105 | paths = [target] 106 | for p in paths: 107 | print(f"Loading dataset from: {p}") 108 | dataset = SolubilityDataset(p) 109 | print(f"{p.with_suffix('').name.capitalize()} contains:\n" 110 | f"{dataset.df.drop(columns=['molecules']).describe().transpose()}") 111 | 112 | 113 | def main(): 114 | from argparse import ArgumentParser 115 | from config import Config 116 | 117 | parser = ArgumentParser() 118 | subparsers = parser.add_subparsers() 119 | 120 | sp_print = subparsers.add_parser('print', 
help='Print parsed configuration') 121 | sp_print.add_argument('config', nargs='*') 122 | sp_print.set_defaults(command=lambda c: print(c.toYAML())) 123 | 124 | sp_describe = subparsers.add_parser('describe', help='Describe existing datasets') 125 | sp_describe.add_argument('config', nargs='*') 126 | sp_describe.set_defaults(command=describe) 127 | 128 | args = parser.parse_args() 129 | cfg = Config.build(*args.config) 130 | args.command(cfg) 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /src/solubility/layout.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import tensorflow as tf 5 | 6 | from tensorboard import summary as summary_lib 7 | from tensorboard.plugins.custom_scalar import layout_pb2 8 | 9 | layout_summary = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=[ 10 | layout_pb2.Category( 11 | title='losses', 12 | chart=[ 13 | # Chart 'losses' (include all losses, exclude upper and lower bounds) 14 | layout_pb2.Chart( 15 | title='losses', 16 | multiline=layout_pb2.MultilineChartContent( 17 | tag=[ 18 | r'loss(?!.*bound.*)' 19 | ] 20 | ) 21 | ), 22 | ]) 23 | ])) 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('folder', help='The log folder to place the layout in') 27 | args = parser.parse_args() 28 | 29 | folder = (Path(args.folder) / 'layout').expanduser().resolve() 30 | with tf.summary.FileWriter(folder) as writer: 31 | writer.add_summary(layout_summary) 32 | 33 | print('Layout saved to', folder) 34 | -------------------------------------------------------------------------------- /src/solubility/networks.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import torchgraphs as tg 7 | 8 | 9 | def build_network(num_hidden): 10 | return SolubilityGN(num_hidden) 11 | 12 | 13 | class SolubilityGN(nn.Module): 14 | def __init__(self, num_layers, hidden_bias, hidden_node, dropout, aggregation): 15 | super().__init__() 16 | 17 | hidden_edge = hidden_node // 4 18 | hidden_global = hidden_node // 8 19 | 20 | self.encoder = nn.Sequential(OrderedDict({ 21 | 'edge': tg.EdgeLinear(hidden_edge, edge_features=6), 22 | 'edge_relu': tg.EdgeReLU(), 23 | 'node': tg.NodeLinear(hidden_node, node_features=47), 24 | 'node_relu': tg.NodeReLU(), 25 | 'global': tg.GlobalLinear(hidden_global, node_features=hidden_node, 26 | edge_features=hidden_edge, aggregation=aggregation), 27 | 'global_relu': tg.GlobalReLU(), 28 | })) 29 | if dropout: 30 | self.hidden = nn.Sequential(OrderedDict({ 31 | f'hidden_{i}': nn.Sequential(OrderedDict({ 32 | 'edge': tg.EdgeLinear(hidden_edge, edge_features=hidden_edge, 33 | sender_features=hidden_node, bias=hidden_bias), 34 | 'edge_relu': tg.EdgeReLU(), 35 | 'edge_dropout': tg.EdgeDroput(), 36 | 'node': tg.NodeLinear(hidden_node, node_features=hidden_node, incoming_features=hidden_edge, 37 | aggregation=aggregation, bias=hidden_bias), 38 | 'node_relu': tg.NodeReLU(), 39 | 'node_dropout': tg.EdgeDroput(), 40 | 'global': tg.GlobalLinear(hidden_global, node_features=hidden_node, edge_features=hidden_edge, 41 | global_features=hidden_global, aggregation=aggregation, bias=hidden_bias), 42 | 'global_relu': tg.GlobalReLU(), 43 | 'global_dropout': tg.EdgeDroput(), 44 | })) 45 | for i in range(num_layers) 46 | })) 47 | else: 48 | self.hidden = 
nn.Sequential(OrderedDict({ 49 | f'hidden_{i}': nn.Sequential(OrderedDict({ 50 | 'edge': tg.EdgeLinear(hidden_edge, edge_features=hidden_edge, 51 | sender_features=hidden_node, bias=hidden_bias), 52 | 'edge_relu': tg.EdgeReLU(), 53 | 'node': tg.NodeLinear(hidden_node, node_features=hidden_node, incoming_features=hidden_edge, 54 | aggregation=aggregation, bias=hidden_bias), 55 | 'node_relu': tg.NodeReLU(), 56 | 'global': tg.GlobalLinear(hidden_global, node_features=hidden_node, edge_features=hidden_edge, 57 | global_features=hidden_global, aggregation=aggregation, bias=hidden_bias), 58 | 'global_relu': tg.GlobalReLU(), 59 | })) 60 | for i in range(num_layers) 61 | })) 62 | self.readout_globals = tg.GlobalLinear(1, global_features=hidden_global, bias=True) 63 | 64 | def forward(self, graphs): 65 | graphs = self.encoder(graphs) 66 | graphs = self.hidden(graphs) 67 | globals = self.readout_globals(graphs).global_features 68 | 69 | return graphs.evolve( 70 | num_nodes=0, 71 | node_features=None, 72 | num_nodes_by_graph=None, 73 | num_edges=0, 74 | num_edges_by_graph=None, 75 | edge_features=None, 76 | global_features=globals, 77 | senders=None, 78 | receivers=None 79 | ) 80 | 81 | 82 | def describe(cfg): 83 | from pathlib import Path 84 | from utils import import_ 85 | klass = import_(cfg.model.klass) 86 | model = klass(*cfg.model.args, **cfg.model.kwargs) 87 | if 'state_dict' in cfg: 88 | model.load_state_dict(torch.load(Path(cfg.state_dict).expanduser().resolve())) 89 | print(model) 90 | print(f'Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}') 91 | 92 | for name, parameter in model.named_parameters(): 93 | print(f'{name} {tuple(parameter.shape)}:') 94 | if 'state_dict' in cfg: 95 | print(parameter.numpy().round()) 96 | print() 97 | 98 | 99 | def main(): 100 | from argparse import ArgumentParser 101 | from config import Config 102 | 103 | parser = ArgumentParser() 104 | subparsers = parser.add_subparsers() 105 | 106 | sp_print = subparsers.add_parser('print', help='Print parsed configuration') 107 | sp_print.add_argument('config', nargs='*') 108 | sp_print.set_defaults(command=lambda c: print(c.toYAML())) 109 | 110 | sp_describe = subparsers.add_parser('describe', help='Describe a model') 111 | sp_describe.add_argument('config', nargs='*') 112 | sp_describe.set_defaults(command=describe) 113 | 114 | args = parser.parse_args() 115 | cfg = Config.build(*args.config) 116 | args.command(cfg) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /src/solubility/notes.md: -------------------------------------------------------------------------------- 1 | # Solubility 2 | 3 | ## Task 4 | We want to predict the water solubility (in `log mol/L`) of the organic molecules from their molecular structure. 5 | 6 | ## Data 7 | The molecules in [this dataset](../../data/delaney-processed.csv) are loaded and parsed using RDKit. 8 | Train/test split is 70/30. 9 | 10 | ## Losses 11 | 12 | Two losses are used for training: a global-level loss and a regularization term. 13 | The two terms are added together in a weighted sum and constitute the final training objective. 14 | 15 | ### Global-level regression 16 | The network should also output a global-level prediction corresponding to the log solubility of the molecule. 17 | The loss on this prediction is computed as Mean Squared Error 18 | 19 | No weighting is used to account for more/less common values. 
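For reference, a minimal sketch of how this term enters the final training objective (the weighted sum mentioned above). Names such as `results.global_features` mirror the training script in this repository; the weight values below are only illustrative, the real ones come from the `losses` entries of the session configuration.

```python
import torch.nn.functional as F

# Illustrative weights; the training script reads them from session.losses
w_solubility, w_l1 = 1.0, 1e-4

def training_objective(results, targets, model):
    # Global-level regression: MSE between predicted and measured log solubility
    loss_sol = F.mse_loss(results.global_features.squeeze(), targets, reduction='mean')
    # L1 regularization over all network parameters (see the next subsection)
    loss_l1 = sum(p.abs().sum() for p in model.parameters())
    # Weighted sum that constitutes the training objective
    return w_solubility * loss_sol + w_l1 * loss_l1
```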
20 | 21 | ### L1 regularization 22 | The weights of the network are regularized with L1 regularization. 23 | 24 | ## Workflow 25 | 26 | 1. Create base folder 27 | ```bash 28 | SOLUBILITY=~/experiments/solubility/ 29 | mkdir -p "$SOLUBILITY/"{runs,data} 30 | ``` 31 | 2. Launch one experiment (from the root of the repo): 32 | ```bash 33 | python -m solubility.train --experiment config/solubility/train.yaml 34 | ``` 35 | Or make a grid search over the hyperparameters (from the root of the repo): 36 | ```bash 37 | conda activate tg-experiments 38 | function train { 39 | python -m solubility.train \ 40 | --experiment config/solubility/train.yaml \ 41 | "tags=[layers${3},lr${1},bias${4},size${5},wd${2},dr${7},e${6},${8}]" \ 42 | --model "kwargs.num_layers=${3}" "kwargs.hidden_bias=${4}" "kwargs.hidden_node=${5}" "kwargs.dropout=${6}" "kwargs.aggregation=${8}"\ 43 | --optimizer "kwargs.lr=${1}" \ 44 | --session "losses.l1=${2}" "epochs=${6}" 45 | } 46 | export -f train # use bash otherwise `export -f` won't work 47 | parallel --verbose --max-procs 6 --load 200% --delay 1 --noswap \ 48 | 'train {1} {2} {3} {4} {5} {6} {7} {8}' \ 49 | `# Learning rate` ::: .01 .001 \ 50 | `# L1 loss` ::: .0001 .001 .01 \ 51 | `# Hidden layers` ::: 3 4 5 10 \ 52 | `# Hidden bias` ::: yes no \ 53 | `# Hidden node` ::: 32 64 128 256 512 \ 54 | `# Epochs` ::: 50 75 100 \ 55 | `# Dropout` ::: yes no \ 56 | `# Aggregation` ::: mean sum 57 | ``` 58 | 59 | 6. Query logs and visualize 60 | - Tensorboard: `tensorboard --logdir "$SOLUBILITY/runs"` 61 | - Find best model 62 | ```bash 63 | for f in */experiment.latest.yaml; do 64 | echo -e $(grep loss_sol_val < $f) $(dirname $f) 65 | done | sort -k 3 -g -r | cut -f2,3 | tail -n 5 66 | ``` 67 | -------------------------------------------------------------------------------- /src/solubility/predict.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import yaml 3 | import pyaml 4 | import multiprocessing 5 | from pathlib import Path 6 | from argparse import ArgumentParser 7 | 8 | import numpy as np 9 | from munch import AutoMunch, Munch 10 | 11 | import torch 12 | import torch.utils.data 13 | import torch.nn.functional as F 14 | import torchgraphs as tg 15 | 16 | from utils import parse_dotted, update_rec, import_ 17 | from .dataset import InfectionDataset 18 | 19 | parser = ArgumentParser() 20 | parser.add_argument('--model', nargs='+', required=True) 21 | parser.add_argument('--data', nargs='+', required=True, default=[]) 22 | parser.add_argument('--options', nargs='+', required=False, default=[]) 23 | parser.add_argument('--output', type=str, required=True) 24 | 25 | args = parser.parse_args() 26 | 27 | 28 | # region Collecting phase 29 | 30 | # Defaults 31 | model = Munch(fn=None, args=[], kwargs={}, state_dict=None) 32 | data = [] 33 | options = AutoMunch() 34 | options.cpus = multiprocessing.cpu_count() - 1 35 | options.device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | options.output = args.output 37 | 38 | # Model from --model args 39 | for string in args.model: 40 | if '=' in string: 41 | update = parse_dotted(string) 42 | else: 43 | with open(string, 'r') as f: 44 | update = yaml.safe_load(f) 45 | # If the yaml file contains an entry with key `model` use that one instead 46 | if 'model' in update.keys(): 47 | update = update['model'] 48 | update_rec(model, update) 49 | 50 | # Data from --data args 51 | for path in args.data: 52 | path = Path(path).expanduser().resolve() 53 | if path.is_dir(): 54 | 
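# Collect the data to run predictions on: a directory argument expands to every serialized
# dataset (*.pt) it contains, a single .pt file is used as-is, anything else is rejected below.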
data.extend(path.glob('*.pt')) 55 | elif path.is_file() and path.suffix == '.pt': 56 | data.append(path) 57 | else: 58 | raise ValueError(f'Invalid data: {path}') 59 | 60 | # Options from --options args 61 | for string in args.options: 62 | if '=' in string: 63 | update = parse_dotted(string) 64 | else: 65 | with open(string, 'r') as f: 66 | update = yaml.safe_load(f) 67 | update_rec(options, update) 68 | 69 | # Resolving paths 70 | model.state_dict = Path(model.state_dict).expanduser().resolve() 71 | options.output = Path(options.output).expanduser().resolve() 72 | 73 | # Checks (some missing, others redundant) 74 | if model.fn is None: 75 | raise ValueError('Model constructor function not defined') 76 | if model.state_dict is None: 77 | raise ValueError(f'Model state dict is required to predict') 78 | if len(data) == 0: 79 | raise ValueError(f'No data to predict') 80 | if options.cpus < 0: 81 | raise ValueError(f'Invalid number of cpus: {options.cpus}') 82 | if options.output.exists() and not options.output.is_dir(): 83 | raise ValueError(f'Invalid output path {options.output}') 84 | 85 | 86 | pyaml.pprint({'model': model, 'options': options, 'data': data}, sort_dicts=False, width=200) 87 | # endregion 88 | 89 | # region Building phase 90 | # Model 91 | net: torch.nn.Module = import_(model.fn)(*model.args, **model.kwargs) 92 | net.load_state_dict(torch.load(model.state_dict)) 93 | net.to(options.device) 94 | 95 | # Output folder 96 | options.output.mkdir(parents=True, exist_ok=True) 97 | # endregion 98 | 99 | # region Training 100 | # Dataset and dataloader 101 | dataset_predict: InfectionDataset = torch.load(data[0]) 102 | dataloader_predict = torch.utils.data.DataLoader( 103 | dataset_predict, 104 | shuffle=False, 105 | num_workers=min(options.cpus, 1) if 'cuda' in options.device else options.cpus, 106 | pin_memory='cuda' in options.device, 107 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)), 108 | batch_size=options.batch_size, 109 | collate_fn=tg.GraphBatch.collate, 110 | ) 111 | 112 | # region Predict 113 | net.eval() 114 | torch.set_grad_enabled(False) 115 | i = 0 116 | with tqdm.tqdm(desc='Predict', total=len(dataloader_predict.dataset), unit='g') as bar: 117 | for graphs, *_ in dataloader_predict: 118 | graphs = graphs.to(options.device) 119 | 120 | results = net(graphs) 121 | results.node_features.sigmoid_() 122 | 123 | for result in results: 124 | torch.save(result.cpu(), options.output / f'output_{i:06d}.pt') 125 | i += 1 126 | 127 | bar.update(graphs.num_graphs) 128 | # endregion 129 | -------------------------------------------------------------------------------- /src/solubility/train.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import yaml 3 | import pyaml 4 | import random 5 | import textwrap 6 | import multiprocessing 7 | from pathlib import Path 8 | from datetime import datetime 9 | from argparse import ArgumentParser 10 | 11 | import numpy as np 12 | import pandas as pd 13 | from munch import AutoMunch 14 | 15 | import torch 16 | import torch.utils.data 17 | import torch.nn.functional as F 18 | import torchgraphs as tg 19 | from tensorboardX import SummaryWriter 20 | 21 | from saver import Saver 22 | from utils import git_info, cuda_info, parse_dotted, update_rec, set_seeds, import_, sort_dict, RunningWeightedAverage 23 | from .dataset import SolubilityDataset 24 | 25 | parser = ArgumentParser() 26 | parser.add_argument('--experiment', nargs='+', required=True) 27 | 
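# The remaining argument groups accept the same syntax as --experiment (dotted key=value
# overrides or YAML files) and are merged on top of the collected configuration further below.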
parser.add_argument('--model', nargs='+', required=False, default=[]) 28 | parser.add_argument('--optimizer', nargs='+', required=False, default=[]) 29 | parser.add_argument('--session', nargs='+', required=False, default=[]) 30 | 31 | args = parser.parse_args() 32 | 33 | 34 | # region Collecting phase 35 | class Experiment(AutoMunch): 36 | @property 37 | def session(self): 38 | return self.sessions[-1] 39 | 40 | 41 | experiment = Experiment() 42 | 43 | # Experiment defaults 44 | experiment.name = 'experiment' 45 | experiment.tags = [] 46 | experiment.samples = 0 47 | experiment.model = {'fn': None, 'args': [], 'kwargs': {}} 48 | experiment.optimizer = {'fn': None, 'args': [], 'kwargs': {}} 49 | experiment.sessions = [] 50 | 51 | # Session defaults 52 | session = AutoMunch() 53 | session.losses = {'solubility': 0, 'l1': 0} 54 | session.seed = random.randint(0, 99) 55 | session.cpus = multiprocessing.cpu_count() - 1 56 | session.device = 'cuda' if torch.cuda.is_available() else 'cpu' 57 | session.log = {'when': []} 58 | session.checkpoint = {'when': []} 59 | 60 | # Experiment configuration 61 | for string in args.experiment: 62 | if '=' in string: 63 | update = parse_dotted(string) 64 | else: 65 | with open(string, 'r') as f: 66 | update = yaml.safe_load(f) 67 | # If the current session is defined inside the experiment update the session instead 68 | if 'session' in update: 69 | update_rec(session, update.pop('session')) 70 | update_rec(experiment, update) 71 | 72 | # Model from --model args 73 | for string in args.model: 74 | if '=' in string: 75 | update = parse_dotted(string) 76 | else: 77 | with open(string, 'r') as f: 78 | update = yaml.safe_load(f) 79 | # If the yaml document contains a single entry with key `model` use that one instead 80 | if update.keys() == {'model'}: 81 | update = update['model'] 82 | update_rec(experiment.model, update) 83 | del update 84 | 85 | # Optimizer from --optimizer args 86 | for string in args.optimizer: 87 | if '=' in string: 88 | update = parse_dotted(string) 89 | else: 90 | with open(string, 'r') as f: 91 | update = yaml.safe_load(f) 92 | # If the yaml document contains a single entry with key `optimizer` use that one instead 93 | if update.keys() == {'optimizer'}: 94 | update = update['optimizer'] 95 | update_rec(experiment.optimizer, update) 96 | del update 97 | 98 | # Session from --session args 99 | for string in args.session: 100 | if '=' in string: 101 | update = parse_dotted(string) 102 | else: 103 | with open(string, 'r') as f: 104 | update = yaml.safe_load(f) 105 | # If the yaml document contains a single entry with key `session` use that one instead 106 | if update.keys() == {'session'}: 107 | update = update['session'] 108 | update_rec(session, update) 109 | del update 110 | 111 | # Checks (some missing, others redundant) 112 | if experiment.name is None or len(experiment.name) == 0: 113 | raise ValueError(f'Experiment name is empty: {experiment.name}') 114 | if experiment.tags is None: 115 | raise ValueError('Experiment tags is None') 116 | if experiment.model.fn is None: 117 | raise ValueError('Model constructor function not defined') 118 | if experiment.optimizer.fn is None: 119 | raise ValueError('Optimizer constructor function not defined') 120 | if session.cpus < 0: 121 | raise ValueError(f'Invalid number of cpus: {session.cpus}') 122 | if any(l < 0 for l in session.losses.values()) or all(l == 0 for l in session.losses.values()): 123 | raise ValueError(f'Invalid losses: {session.losses}') 124 | if len(experiment.sessions) > 0 and 
('state_dict' not in experiment.model or 'state_dict' not in experiment.optimizer): 125 | raise ValueError(f'Model and optimizer state dicts are required to restore training') 126 | 127 | # Experiment computed fields 128 | experiment.epoch = sum((s.epochs for s in experiment.sessions), 0) 129 | 130 | # Session computed fields 131 | session.status = 'NEW' 132 | session.datetime_started = None 133 | session.datetime_completed = None 134 | git = git_info() 135 | if git is not None: 136 | session.git = git 137 | if 'cuda' in session.device: 138 | session.cuda = cuda_info() 139 | 140 | # Resolving paths 141 | rand_id = ''.join(chr(random.randint(ord('A'), ord('Z'))) for _ in range(6)) 142 | session.data.path = Path(session.data.path.replace('{name}', experiment.name)).expanduser().resolve().as_posix() 143 | session.log.folder = session.log.folder \ 144 | .replace('{name}', experiment.name) \ 145 | .replace('{tags}', '_'.join(experiment.tags)) \ 146 | .replace('{rand}', rand_id) 147 | if len(session.checkpoint.when) > 0: 148 | if len(session.log.when) > 0: 149 | session.log.folder = Path(session.log.folder).expanduser().resolve().as_posix() 150 | session.checkpoint.folder = session.checkpoint.folder \ 151 | .replace('{name}', experiment.name) \ 152 | .replace('{tags}', '_'.join(experiment.tags)) \ 153 | .replace('{rand}', rand_id) 154 | session.checkpoint.folder = Path(session.checkpoint.folder).expanduser().resolve().as_posix() 155 | if 'state_dict' in experiment.model: 156 | experiment.model.state_dict = Path(experiment.model.state_dict).expanduser().resolve().as_posix() 157 | if 'state_dict' in experiment.optimizer: 158 | experiment.optimizer.state_dict = Path(experiment.optimizer.state_dict).expanduser().resolve().as_posix() 159 | 160 | sort_dict(experiment, ['name', 'tags', 'epoch', 'samples', 'model', 'optimizer', 'sessions']) 161 | sort_dict(session, ['epochs', 'batch_size', 'losses', 'seed', 'cpus', 'device', 'samples', 'status', 162 | 'datetime_started', 'datetime_completed', 'data', 'log', 'checkpoint', 'git', 'gpus']) 163 | experiment.sessions.append(session) 164 | pyaml.pprint(experiment, sort_dicts=False, width=200) 165 | del session 166 | # endregion 167 | 168 | # region Building phase 169 | # Seeds (set them after the random run id is generated) 170 | set_seeds(experiment.session.seed) 171 | 172 | # Model 173 | model: torch.nn.Module = import_(experiment.model.fn)(*experiment.model.args, **experiment.model.kwargs) 174 | if 'state_dict' in experiment.model: 175 | model.load_state_dict(torch.load(experiment.model.state_dict)) 176 | model.to(experiment.session.device) 177 | 178 | # Optimizer 179 | optimizer: torch.optim.Optimizer = import_(experiment.optimizer.fn)( 180 | model.parameters(), *experiment.optimizer.args, **experiment.optimizer.kwargs) 181 | if 'state_dict' in experiment.optimizer: 182 | optimizer.load_state_dict(torch.load(experiment.optimizer.state_dict)) 183 | 184 | # Logger 185 | if len(experiment.session.log.when) > 0: 186 | logger = SummaryWriter(experiment.session.log.folder) 187 | logger.add_text( 188 | 'Experiment', textwrap.indent(pyaml.dump(experiment, safe=True, sort_dicts=False), ' '), experiment.samples) 189 | else: 190 | logger = None 191 | 192 | # Saver 193 | if len(experiment.session.checkpoint.when) > 0: 194 | saver = Saver(experiment.session.checkpoint.folder) 195 | if experiment.epoch == 0: 196 | saver.save_experiment(experiment, suffix=f'e{experiment.epoch:04d}') 197 | else: 198 | saver = None 199 | # endregion 200 | 201 | # Datasets and dataloaders 
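# The CSV-backed SolubilityDataset is split below into contiguous train/validation subsets
# according to the experiment.session.data.train fraction; batches are collated into
# tg.GraphBatch objects so the graph network can consume them.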
202 | dataset = SolubilityDataset(experiment.session.data.path) 203 | dataloader_kwargs = dict( 204 | num_workers=min(experiment.session.cpus, 1) if 'cuda' in experiment.session.device else experiment.session.cpus, 205 | pin_memory='cuda' in experiment.session.device, 206 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)), 207 | batch_size=experiment.session.batch_size, 208 | collate_fn=tg.GraphBatch.collate, 209 | ) 210 | 211 | dataset_train = torch.utils.data.Subset( 212 | dataset, indices=np.arange(0, int(np.floor(experiment.session.data.train * len(dataset))))) 213 | dataloader_train = torch.utils.data.DataLoader( 214 | dataset_train, 215 | shuffle=True, 216 | **dataloader_kwargs 217 | ) 218 | 219 | dataset_val = torch.utils.data.Subset( 220 | dataset, indices=np.arange(int(np.floor(experiment.session.data.train * len(dataset))), len(dataset))) 221 | dataloader_val = torch.utils.data.DataLoader( 222 | dataset_val, 223 | shuffle=False, 224 | **dataloader_kwargs 225 | ) 226 | del dataset, dataloader_kwargs 227 | 228 | # region Training 229 | # Train and validation loops 230 | experiment.session.status = 'RUNNING' 231 | experiment.session.datetime_started = datetime.utcnow() 232 | 233 | graphs_df = {k: [] for k in ['LossSolubility', 'Nodes', 'Edges', 'Pred', 'Real']} 234 | 235 | epoch_bar_postfix = {} 236 | epoch_bar = tqdm.trange(1, experiment.session.epochs + 1, desc='Epochs', unit='e', leave=True) 237 | for epoch_idx in epoch_bar: 238 | experiment.epoch += 1 239 | 240 | # region Training loop 241 | model.train() 242 | torch.set_grad_enabled(True) 243 | 244 | train_bar_postfix = {} 245 | loss_sol_avg = RunningWeightedAverage() 246 | loss_l1_avg = RunningWeightedAverage() 247 | loss_total_avg = RunningWeightedAverage() 248 | 249 | train_bar = tqdm.tqdm(desc=f'Train {experiment.epoch}', total=len(dataloader_train.dataset), unit='g') 250 | for graphs, targets in dataloader_train: 251 | graphs = graphs.to(experiment.session.device) 252 | targets = targets.to(experiment.session.device) 253 | results = model(graphs) 254 | 255 | loss_total = torch.tensor(0., device=experiment.session.device) 256 | 257 | if experiment.session.losses.solubility > 0: 258 | loss_sol = F.mse_loss( 259 | results.global_features.squeeze(), targets, reduction='mean') 260 | loss_total += experiment.session.losses.solubility * loss_sol 261 | loss_sol_avg.add(loss_sol.mean().item(), len(graphs)) 262 | train_bar_postfix['Solubility'] = f'{loss_sol.item():.5f}' 263 | 264 | if 'every batch' in experiment.session.log.when: 265 | logger.add_scalar('loss/train/solubility', loss_sol.mean().item(), global_step=experiment.samples) 266 | 267 | if experiment.session.losses.l1 > 0: 268 | loss_l1 = sum([p.abs().sum() for p in model.parameters()]) 269 | loss_total += experiment.session.losses.l1 * loss_l1 270 | loss_l1_avg.add(loss_l1.item(), len(graphs)) 271 | train_bar_postfix['L1'] = f'{loss_l1.item():.5f}' 272 | if 'every batch' in experiment.session.log.when: 273 | logger.add_scalar('loss/train/l1', loss_l1.item(), global_step=experiment.samples) 274 | 275 | loss_total_avg.add(loss_total.item(), len(graphs)) 276 | train_bar_postfix['Total'] = f'{loss_total.item():.5f}' 277 | if 'every batch' in experiment.session.log.when: 278 | logger.add_scalar('loss/train/total', loss_total.item(), global_step=experiment.samples) 279 | 280 | optimizer.zero_grad() 281 | loss_total.backward() 282 | optimizer.step(closure=None) 283 | 284 | experiment.samples += len(graphs) 285 | train_bar.update(len(graphs)) 286 
| train_bar.set_postfix(train_bar_postfix) 287 | train_bar.close() 288 | 289 | experiment.loss_sol_train = loss_sol_avg.get() 290 | epoch_bar_postfix['Train'] = f'{loss_total_avg.get():.4f}' 291 | epoch_bar.set_postfix(epoch_bar_postfix) 292 | 293 | if 'every epoch' in experiment.session.log.when and 'every batch' not in experiment.session.log.when: 294 | logger.add_scalar('loss/train/total', loss_total_avg.get(), global_step=experiment.samples) 295 | if experiment.session.losses.solubility > 0: 296 | logger.add_scalar('loss/train/solubility', loss_sol_avg.get(), global_step=experiment.samples) 297 | if experiment.session.losses.l1 > 0: 298 | logger.add_scalar('loss/train/l1', loss_l1_avg.get(), global_step=experiment.samples) 299 | 300 | del train_bar, train_bar_postfix, loss_sol_avg, loss_l1_avg, loss_total_avg 301 | # endregion 302 | 303 | # region Validation loop 304 | model.eval() 305 | torch.set_grad_enabled(False) 306 | 307 | val_bar_postfix = {} 308 | loss_sol_avg = RunningWeightedAverage() 309 | loss_total_avg = RunningWeightedAverage() 310 | loss_l1 = sum([p.abs().sum() for p in model.parameters()]) 311 | val_bar_postfix['Solubility'] = '' 312 | val_bar_postfix['L1'] = f'{loss_l1.item():.5f}' 313 | 314 | val_bar = tqdm.tqdm(desc=f'Val {experiment.epoch}', total=len(dataloader_val.dataset), unit='g') 315 | for batch_idx, (graphs, targets) in enumerate(dataloader_val): 316 | graphs = graphs.to(experiment.session.device) 317 | targets = targets.to(experiment.session.device) 318 | results = model(graphs) 319 | 320 | loss_total = torch.tensor(0., device=experiment.session.device) 321 | 322 | if experiment.session.losses.solubility > 0: 323 | loss_sol = F.mse_loss( 324 | results.global_features.squeeze(), targets, reduction='mean') 325 | loss_total += experiment.session.losses.solubility * loss_sol 326 | loss_sol_avg.add(loss_sol.item(), len(graphs)) 327 | val_bar_postfix['Solubility'] = f'{loss_sol.item():.5f}' 328 | 329 | if experiment.session.losses.l1 > 0: 330 | loss_total += experiment.session.losses.l1 * loss_l1 331 | 332 | val_bar_postfix['Total'] = f'{loss_total.item():.5f}' 333 | loss_total_avg.add(loss_total.item(), len(graphs)) 334 | 335 | # region Last epoch 336 | if epoch_idx == experiment.session.epochs: 337 | loss_sol_by_graph = F.mse_loss( 338 | results.global_features.squeeze(), targets, reduction='none') 339 | 340 | graphs_df['LossSolubility'].append(loss_sol_by_graph.cpu()) 341 | graphs_df['Nodes'].append(graphs.num_nodes_by_graph.cpu()) 342 | graphs_df['Edges'].append(graphs.num_edges_by_graph.cpu()) 343 | graphs_df['Pred'].append(results.global_features.squeeze().cpu()) 344 | graphs_df['Real'].append(targets.cpu()) 345 | # endregion 346 | 347 | val_bar.update(len(graphs)) 348 | val_bar.set_postfix(val_bar_postfix) 349 | val_bar.close() 350 | 351 | experiment.loss_sol_val = loss_sol_avg.get() 352 | epoch_bar_postfix['Val'] = f'{loss_total_avg.get():.4f}' 353 | epoch_bar.set_postfix(epoch_bar_postfix) 354 | 355 | if ( 356 | 'every batch' in experiment.session.log.when or 357 | 'every epoch' in experiment.session.log.when or 358 | 'last epoch' in experiment.session.checkpoint.when and epoch_idx == experiment.session.epochs 359 | ): 360 | logger.add_scalar('loss/val/total', loss_total_avg.get(), global_step=experiment.samples) 361 | if experiment.session.losses.solubility > 0: 362 | logger.add_scalar('loss/val/solubility', loss_sol_avg.get(), global_step=experiment.samples) 363 | if experiment.session.losses.l1 > 0: 364 | logger.add_scalar('loss/val/l1', 
loss_l1.item(), global_step=experiment.samples) 365 | 366 | del val_bar, val_bar_postfix, loss_sol_avg, loss_l1, loss_total_avg, batch_idx 367 | # endregion 368 | 369 | # Saving 370 | if epoch_idx == experiment.session.epochs: 371 | experiment.session.status = 'DONE' 372 | experiment.session.datetime_completed = datetime.utcnow() 373 | if ( 374 | 'every batch' in experiment.session.checkpoint.when or 375 | 'every epoch' in experiment.session.checkpoint.when or 376 | 'last epoch' in experiment.session.checkpoint.when and epoch_idx == experiment.session.epochs 377 | ): 378 | saver.save(model, experiment, optimizer, suffix=f'e{experiment.epoch:04d}') 379 | epoch_bar.close() 380 | print() 381 | del epoch_bar, epoch_bar_postfix, epoch_idx 382 | # endregion 383 | 384 | # region Final report 385 | pd.options.display.precision = 2 386 | pd.options.display.max_columns = 999 387 | pd.options.display.expand_frame_repr = False 388 | 389 | graphs_df = pd.DataFrame({k: np.concatenate(v) for k, v in graphs_df.items()}).rename_axis('GraphId').reset_index() 390 | experiment.loss_sol = graphs_df.LossSolubility.mean() 391 | print('Solubility MSE:', experiment.loss_sol) 392 | 393 | # Split the results in ranges based on the number of nodes and compute the average loss per range 394 | df_losses_by_node_range = graphs_df \ 395 | .groupby(graphs_df.Nodes // 10) \ 396 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'LossSolubility': 'mean'}) \ 397 | .rename_axis(index='NodeRange') \ 398 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index') \ 399 | .rename(str.capitalize, axis='columns', level=1) 400 | 401 | # Split the results in ranges based on the number of nodes and keep the N worst predictions w.r.t. graph-wise loss 402 | df_worst_solubility_loss_by_node_range = graphs_df \ 403 | .groupby(graphs_df.Nodes // 10) \ 404 | .apply(lambda df_gr: df_gr.nlargest(5, 'LossSolubility').set_index('GraphId')) \ 405 | .rename_axis(index={'Nodes': 'NodeRange'}) \ 406 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0) 407 | 408 | print(f""" 409 | Losses by range: 410 | {df_losses_by_node_range}\n 411 | Worst solubility predictions: 412 | {df_worst_solubility_loss_by_node_range} 413 | """) 414 | 415 | if logger is not None: 416 | logger.add_text( 417 | 'Losses by range', 418 | textwrap.indent(df_losses_by_node_range.to_string(), ' '), 419 | global_step=experiment.samples) 420 | logger.add_text( 421 | 'Worst solubility predictions', 422 | textwrap.indent(df_worst_solubility_loss_by_node_range.to_string(), ' '), 423 | global_step=experiment.samples) 424 | del graphs_df, df_losses_by_node_range, df_worst_solubility_loss_by_node_range 425 | # endregion 426 | 427 | # region Cleanup 428 | if logger is not None: 429 | logger.close() 430 | # endregion 431 | -------------------------------------------------------------------------------- /src/test_edge_linear.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | 5 | import relevance 6 | 7 | import torchgraphs as tg 8 | 9 | torch.set_grad_enabled(False) 10 | relevance.register_function(torch.nn.Linear, relevance.LinearEpsilon) 11 | 12 | graphs = tg.GraphBatch.from_graphs([tg.Graph( 13 | node_features=torch.rand(10, 9), 14 | edge_features=torch.rand(6, 5), 15 | senders=torch.tensor([0, 0, 1, 2, 4, 5]), 16 | receivers=torch.tensor([1, 2, 2, 4, 3, 3]), 17 | global_features=torch.rand(7) 18 | )]) 19 | 
20 | print('bias ef sf rf gf out input'.replace(' ', '\t')) 21 | print('-' * 45) 22 | 23 | for bias, ef, sf, rf, gf in itertools.product( 24 | [True, False, 'zero'], 25 | [5, None], 26 | [9, None], 27 | [9, None], 28 | [7, None] 29 | ): 30 | if ef is sf is rf is gf is None: 31 | continue 32 | net = tg.EdgeLinear(3, edge_features=ef, sender_features=sf, receiver_features=rf, global_features=gf, bias=bool(bias)) 33 | if bias == 'zero': 34 | net.bias.zero_() 35 | 36 | ctx = {} 37 | out = relevance.forward_relevance(net, graphs, ctx=ctx) 38 | torch.testing.assert_allclose(out.edge_features, net(graphs).edge_features) 39 | 40 | rel_out = graphs.evolve( 41 | edge_features=torch.ones_like(out.edge_features) * (out.edge_features != 0).float() 42 | ) 43 | 44 | rel_in = relevance.backward_relevance(net, rel_out, ctx=ctx) 45 | 46 | # If bias==0 then relevance is conserved at a graph level 47 | print( 48 | {True: ' x ', False: ' - ', 'zero': ' 0 '}[bias], ef if ef else '-', sf if sf else '-', 49 | rf if rf else '-', gf if gf else '-', 50 | rel_out.edge_features.sum().item(), 51 | ((rel_in.global_features.sum() if rel_in.global_features is not None else 0) + 52 | (rel_in.edge_features.sum() if rel_in.edge_features is not None else 0) + 53 | (rel_in.node_features.sum() if rel_in.node_features is not None else 0)).item(), 54 | sep='\t' 55 | ) 56 | -------------------------------------------------------------------------------- /src/test_global_linear.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | 5 | import relevance 6 | 7 | import torchgraphs as tg 8 | 9 | torch.set_grad_enabled(False) 10 | relevance.register_function(torch.nn.Linear, relevance.LinearEpsilon) 11 | 12 | graphs = tg.GraphBatch.from_graphs([tg.Graph( 13 | node_features=torch.rand(10, 9), 14 | edge_features=torch.rand(6, 5), 15 | senders=torch.tensor([0, 0, 1, 2, 4, 5]), 16 | receivers=torch.tensor([1, 2, 2, 4, 3, 3]), 17 | global_features=torch.rand(7) 18 | )]) 19 | 20 | print('bias nf ef gf R_out R_in'.replace(' ', '\t')) 21 | print('-' * 45) 22 | 23 | for bias, nf, ef, gf in itertools.product( 24 | [True, False, 'zero'], 25 | [9, None], 26 | [5, None], 27 | [7, None] 28 | ): 29 | if nf is ef is gf is None: 30 | continue 31 | net = tg.GlobalLinear(3, node_features=nf, edge_features=ef, global_features=gf, bias=bool(bias), aggregation='sum') 32 | 33 | if bias == 'zero': 34 | net.bias.zero_() 35 | 36 | ctx = {} 37 | out = relevance.forward_relevance(net, graphs, ctx=ctx) 38 | torch.testing.assert_allclose(out.node_features, net(graphs).node_features) 39 | 40 | rel_out = graphs.evolve( 41 | global_features=torch.ones_like(out.global_features) * (out.global_features != 0).float() 42 | ) 43 | 44 | rel_in = relevance.backward_relevance(net, rel_out, ctx=ctx) 45 | 46 | # If bias==0 then relevance is conserved at a graph level 47 | print( 48 | {True: ' x ', False: ' - ', 'zero': ' 0 '}[bias], nf if nf else '-', ef if ef else '-', gf if gf else '-', 49 | rel_out.global_features.sum().item(), 50 | ((rel_in.global_features.sum() if rel_in.global_features is not None else 0) + 51 | (rel_in.edge_features.sum() if rel_in.edge_features is not None else 0) + 52 | (rel_in.node_features.sum() if rel_in.node_features is not None else 0)).item(), 53 | sep='\t' 54 | ) 55 | -------------------------------------------------------------------------------- /src/test_node_linear.py: -------------------------------------------------------------------------------- 1 | import 
itertools 2 | 3 | import torch 4 | 5 | import relevance 6 | 7 | import torchgraphs as tg 8 | 9 | torch.set_grad_enabled(False) 10 | relevance.register_function(torch.nn.Linear, relevance.LinearEpsilon) 11 | 12 | graphs = tg.GraphBatch.from_graphs([tg.Graph( 13 | node_features=torch.rand(10, 9), 14 | edge_features=torch.rand(6, 5), 15 | senders=torch.tensor([0, 0, 1, 2, 4, 5]), 16 | receivers=torch.tensor([1, 2, 2, 4, 3, 3]), 17 | global_features=torch.rand(7) 18 | )]) 19 | 20 | print('agg bias nf inf outf gf out input'.replace(' ', '\t')) 21 | print('-' * 45) 22 | 23 | for agg, bias, nf, inf, outf, gf in itertools.product( 24 | ['sum', 'avg', 'max'], 25 | [True, False, 'zero'], 26 | [9, None], 27 | [5, None], 28 | [5, None], 29 | [7, None] 30 | ): 31 | if nf is inf is outf is gf is None: 32 | continue 33 | net = tg.NodeLinear(3, node_features=nf, incoming_features=inf, outgoing_features=outf, global_features=gf, bias=bias, aggregation=agg) 34 | if bias == 'zero': 35 | net.bias.zero_() 36 | 37 | ctx = {} 38 | out = relevance.forward_relevance(net, graphs, ctx=ctx) 39 | torch.testing.assert_allclose(out.node_features, net(graphs).node_features) 40 | 41 | rel_out = graphs.evolve( 42 | node_features=torch.ones_like(out.node_features) * (out.node_features != 0).float() 43 | ) 44 | 45 | rel_in = relevance.backward_relevance(net, rel_out, ctx=ctx) 46 | 47 | # If bias==0 then relevance is conserved at a graph level 48 | print( 49 | agg, {True: ' x ', False: ' - ', 'zero': ' 0 '}[bias], nf if nf else '-', inf if inf else '-', 50 | outf if outf else '-', gf if gf else '-', 51 | rel_out.node_features.sum().item(), 52 | ((rel_in.global_features.sum() if rel_in.global_features is not None else 0) + 53 | (rel_in.edge_features.sum() if rel_in.edge_features is not None else 0) + 54 | (rel_in.node_features.sum() if rel_in.node_features is not None else 0)).item(), 55 | sep='\t' 56 | ) 57 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import subprocess 3 | import collections 4 | from typing import Mapping, Iterable, MutableMapping 5 | 6 | import yaml 7 | from munch import munchify, AutoMunch 8 | 9 | 10 | def git_info(): 11 | try: 12 | import git 13 | try: 14 | result = {} 15 | repo = git.Repo(search_parent_directories=True) 16 | try: 17 | result['url'] = repo.remote(name='origin').url 18 | except ValueError: 19 | result['url'] = 'git:/' + repo.working_dir 20 | result['commit'] = repo.head.commit.hexsha 21 | result['dirty'] = repo.is_dirty() 22 | if repo.is_dirty(): 23 | result['diffs'] = [str(diff) for diff in repo.head.commit.diff(other=None, create_patch=True)] 24 | if len(repo.untracked_files) > 0: 25 | result['untracked_files'] = repo.untracked_files 26 | return result 27 | except (git.InvalidGitRepositoryError, ValueError): 28 | pass 29 | except ImportError: 30 | return None 31 | 32 | 33 | def cuda_info(): 34 | from xml.etree import ElementTree 35 | try: 36 | nvidia_smi_xml = subprocess.check_output(['nvidia-smi', '-q', '-x']).decode() 37 | except (FileNotFoundError, OSError, subprocess.CalledProcessError): 38 | return None 39 | 40 | driver = '' 41 | gpus = [] 42 | for child in ElementTree.fromstring(nvidia_smi_xml): 43 | if child.tag == 'driver_version': 44 | driver = child.text 45 | elif child.tag == 'gpu': 46 | gpus.append({ 47 | 'model': child.find('product_name').text, 48 | 'utilization': child.find('utilization').find('gpu_util').text, 49 | 
'memory_used': child.find('fb_memory_usage').find('used').text, 50 | 'memory_total': child.find('fb_memory_usage').find('total').text, 51 | }) 52 | 53 | return {'driver': driver, 'gpus': gpus} 54 | 55 | 56 | def parse_dotted(string): 57 | result_dict = {} 58 | for kv_pair in string.split(' '): 59 | sub_dict = result_dict 60 | name_dotted, value = kv_pair.split('=') 61 | name_head, *name_rest = name_dotted.split('.') 62 | while len(name_rest) > 0: 63 | sub_dict = sub_dict.setdefault(name_head, {}) 64 | name_head, *name_rest = name_rest 65 | sub_dict[name_head] = yaml.safe_load(value) 66 | return result_dict 67 | 68 | 69 | def update_rec(target, source): 70 | for k in source.keys(): 71 | if k in target and isinstance(target[k], Mapping) and isinstance(source[k], Mapping): 72 | update_rec(target[k], source[k]) 73 | else: 74 | # AutoMunch should do its job, but sometimes it doesn't 75 | target[k] = munchify(source[k], AutoMunch) 76 | 77 | 78 | def import_(fullname): 79 | import importlib 80 | package, name = fullname.rsplit('.', maxsplit=1) 81 | package = importlib.import_module(package) 82 | return getattr(package, name) 83 | 84 | 85 | def set_seeds(seed): 86 | import random 87 | import torch 88 | import numpy as np 89 | random.seed(seed) 90 | np.random.seed(seed) 91 | torch.random.manual_seed(seed) 92 | 93 | 94 | def sort_dict(mapping: MutableMapping, order: Iterable): 95 | for key in itertools.chain(filter(mapping.__contains__, order), set(mapping) - set(order)): 96 | mapping[key] = mapping.pop(key) 97 | return mapping 98 | 99 | 100 | class RunningWeightedAverage(object): 101 | def __init__(self): 102 | self.total_weight = 0 103 | self.total_weighted_value = 0 104 | 105 | def add(self, value, weight): 106 | if weight <= 0: 107 | raise ValueError(f'Weight must be positive, got {weight}') 108 | self.total_weighted_value += value * weight 109 | self.total_weight += weight 110 | 111 | def get(self): 112 | if self.total_weight == 0: 113 | return 0 114 | return self.total_weighted_value / self.total_weight 115 | 116 | def __repr__(self): 117 | return f'{self.get()} ({self.total_weight})' 118 | -------------------------------------------------------------------------------- /src/yaml_ext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | 4 | _init = False 5 | 6 | 7 | class TorchDevice: 8 | _tag = '!torch.device' 9 | _class = torch.device 10 | 11 | @staticmethod 12 | def represent(dumper, device): 13 | return dumper.represent_scalar(TorchDevice._tag, str(device)) 14 | 15 | @staticmethod 16 | def construct(loader, node): 17 | return torch.device(loader.construct_scalar(node)) 18 | 19 | @classmethod 20 | def register(cls): 21 | yaml.add_representer(cls._class, cls.represent, yaml.SafeDumper) 22 | yaml.add_constructor(cls._tag, cls.construct, yaml.SafeLoader) 23 | 24 | 25 | def init_ext(): 26 | global _init 27 | if not _init: 28 | TorchDevice.register() 29 | _init = True 30 | --------------------------------------------------------------------------------
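
The three `test_*_linear.py` scripts above all exercise the same property of the LRP epsilon-rule: when the layer's bias is absent or zeroed, the relevance summed over its inputs equals the relevance summed over its outputs, while a non-zero bias absorbs part of it. Below is a minimal, self-contained sketch of that conservation check for a plain `torch.nn.Linear`, assuming nothing from `torchgraphs` or the project's `relevance` package; the helper name `lrp_epsilon_linear` and the epsilon value are illustrative choices, not repository code.

```python
import torch

torch.set_grad_enabled(False)


def lrp_epsilon_linear(layer: torch.nn.Linear, x, rel_out, eps=1e-6):
    # Epsilon rule: R_i = x_i * sum_j( w_ji * R_j / (z_j + eps * sign(z_j)) )
    z = layer(x)                                          # (batch, out_features)
    stabilizer = torch.where(z >= 0, torch.full_like(z, eps), torch.full_like(z, -eps))
    s = rel_out / (z + stabilizer)                        # stabilized output relevance
    return x * (s @ layer.weight)                         # (batch, in_features)


layer = torch.nn.Linear(9, 3, bias=False)                 # no bias -> relevance is conserved
x = torch.rand(10, 9)
out = layer(x)
rel_out = torch.ones_like(out) * (out != 0).float()       # seed relevance as the test scripts do
rel_in = lrp_epsilon_linear(layer, x, rel_out)

print(rel_out.sum().item(), rel_in.sum().item())          # the two totals agree up to eps
```

Re-running the same check with `bias=True` makes the two totals diverge, which is the asymmetry that the tables printed by the test scripts expose for `EdgeLinear`, `GlobalLinear` and `NodeLinear`.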