├── .gitignore
├── README.md
├── conda.yaml
├── config
│   ├── count_nodes
│   │   ├── dataset.yaml
│   │   ├── full.yaml
│   │   ├── minimal.yaml
│   │   └── train.yaml
│   ├── infection
│   │   ├── datasets.yaml
│   │   └── train.yaml
│   └── solubility
│       └── train.yaml
├── data
│   └── delaney-processed.csv
├── models
│   ├── infection
│   │   ├── lr.001_nodes.1_count1_wd.0001_MPUKHT
│   │   │   ├── events.out.tfevents.1555228025
│   │   │   ├── experiment.latest.yaml
│   │   │   ├── model.latest.pt
│   │   │   └── optimizer.latest.pt
│   │   ├── lr.001_nodes.1_count1_wd0_QYNATD
│   │   │   ├── events.out.tfevents.1555269841
│   │   │   ├── experiment.latest.yaml
│   │   │   ├── model.latest.pt
│   │   │   └── optimizer.latest.pt
│   │   ├── max_bias
│   │   │   ├── events.out.tfevents.1555604341
│   │   │   ├── experiment.latest.yaml
│   │   │   ├── model.latest.pt
│   │   │   └── optimizer.latest.pt
│   │   ├── max_nobias
│   │   │   ├── events.out.tfevents.1555604167
│   │   │   ├── experiment.latest.yaml
│   │   │   ├── model.latest.pt
│   │   │   └── optimizer.latest.pt
│   │   ├── sum_bias
│   │   │   ├── events.out.tfevents.1555603943
│   │   │   ├── experiment.latest.yaml
│   │   │   ├── model.latest.pt
│   │   │   └── optimizer.latest.pt
│   │   └── sum_nobias
│   │       ├── events.out.tfevents.1555603768
│   │       ├── experiment.latest.yaml
│   │       ├── model.latest.pt
│   │       └── optimizer.latest.pt
│   └── solubility
│       └── layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG
│           ├── events.out.tfevents.1555363822
│           ├── experiment.latest.yaml
│           ├── model.latest.pt
│           └── optimizer.latest.pt
├── notebooks
│   ├── Hijacking-Autograd-for-LRP-Graph.ipynb
│   ├── Hijacking-Autograd-for-LRP.ipynb
│   ├── Infection-Explanation-BigGraph.ipynb
│   ├── Infection-LRP-SimpleExamples.ipynb
│   ├── LRP-index-scatter.ipynb
│   ├── LRP-linear-ReLU.ipynb
│   ├── NxGraphs.ipynb
│   ├── Solubility-GraphFeatures.ipynb
│   ├── Solubility-LRP.ipynb
│   └── biggraph.pt
├── resources
│   ├── sucrose-atoms.png
│   ├── sucrose-bonds.png
│   └── sucrose.png
├── scripts.sh
├── setup.py
└── src
    ├── __init__.py
    ├── config.py
    ├── count_nodes
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── layout.py
    │   ├── networks.py
    │   ├── notes.md
    │   └── train.py
    ├── guidedbackprop
    │   ├── __init__.py
    │   ├── autograd_tricks.py
    │   └── graphs.py
    ├── infection
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── layout.py
    │   ├── networks.py
    │   ├── notes.md
    │   ├── predict.py
    │   └── train.py
    ├── relevance
    │   ├── __init__.py
    │   ├── autograd_tricks.py
    │   ├── graphs.py
    │   ├── oldies.py
    │   └── patch.py
    ├── relevance_regression.py
    ├── saver.py
    ├── solubility
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── layout.py
    │   ├── networks.py
    │   ├── notes.md
    │   ├── predict.py
    │   └── train.py
    ├── test_edge_linear.py
    ├── test_global_linear.py
    ├── test_node_linear.py
    ├── utils.py
    └── yaml_ext.py
/.gitignore:
--------------------------------------------------------------------------------
1 | notebooks/
2 | data
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Jupyter Notebook
57 | .ipynb_checkpoints
58 |
59 | # IPython
60 | profile_default/
61 | ipython_config.py
62 |
63 | # Environments
64 | .env
65 | .venv
66 | env/
67 | venv/
68 | ENV/
69 | env.bak/
70 | venv.bak/
71 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Explainability Techniques for Graph Convolutional Networks
2 |
3 | Code and notebooks for the paper ["Explainability Techniques for Graph Convolutional Networks"](https://arxiv.org/abs/1905.13686)
4 | accepted at the ICML 2019 Workshop ["Learning and Reasoning with Graph-Structured Data"](https://graphreason.github.io/).
5 |
6 | ## Overview
7 | A Graph Network trained to predict the solubility of organic molecules is applied to _sucrose_;
8 | the prediction is explained using [Layer-wise Relevance Propagation](http://heatmapping.org), which assigns
9 | positive and negative relevance to the nodes and edges of the molecular graph:
10 |
11 | 
12 |
13 | The predicted solubility can be broken down into the individual features of the atoms and their bonds:
14 |
15 | 
16 | 
17 |
18 | ## Code structure
19 | - `src`, `config`, `data` contain code, configuration files and data for the experiments
20 | - `infection`, `solubility` contain the code for the two experiments in the paper
21 | - `torchgraphs` contains the core graph network library
22 | - `guidedbackprop`, `relevance` contain the code to run Guided Backpropagation and Layer-wise Relevance Propagation on top of PyTorch's `autograd`
23 | - `notebooks`, `models` contain a visualization of the datasets, the trained models and the results of our experiments
24 | - `test` contains unit tests for the `torchgraphs` module (core GN library)
25 | - `conda.yaml` contains the conda environment for the project
26 |
27 | ## Setup
28 | The project is built on top of Python 3.7, PyTorch 1.1+,
29 | [torchgraphs](https://github.com/baldassarreFe/torchgraphs) 0.0.1 and many other open source projects.
30 |
31 | A [Conda](https://conda.io) environment for the project can be installed as:
32 | ```bash
33 | conda env create -n gn-exp -f conda.yaml
34 | conda activate gn-exp
35 | python setup.py develop
36 | pytest
37 | ```
38 |
39 | ## Training
40 | Detailed instructions for data processing, training and hyperparameter search can be found in the respective subfolders:
41 | - Infection: [infection/notes.md](./src/infection/notes.md)
42 | - Solubility: [solubility/notes.md](./src/solubility/notes.md)
43 |
44 | ## Experimental results
45 | The results of our experiments are visualized through the notebooks in [`notebooks`](./notebooks):
46 | ```bash
47 | conda activate gn-exp
48 | cd notebooks
49 | jupyter lab
50 | ```
51 |
--------------------------------------------------------------------------------
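The Overview in the README describes Layer-wise Relevance Propagation assigning positive and negative relevance to the nodes and edges of the molecular graph. In this repository that is done on top of PyTorch's `autograd` (see `src/relevance` and the `Hijacking-Autograd-for-LRP*` notebooks); the snippet below is only a minimal, self-contained sketch of the LRP-ε rule for a single linear layer, not the code used in the experiments.

```python
import torch

def lrp_linear(layer: torch.nn.Linear, x: torch.Tensor, rel_out: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """LRP-epsilon for one linear layer: redistribute output relevance onto the inputs."""
    z = layer(x)                        # forward pass, shape (batch, out_features)
    s = rel_out / (z + eps * z.sign())  # stabilised relevance/activation ratio
    c = s @ layer.weight                # back-project through the weights, shape (batch, in_features)
    return x * c                        # relevance of each input feature

# Toy usage: the relevance of the output (here the prediction itself) is split over the two inputs
layer = torch.nn.Linear(2, 1)
x = torch.tensor([[1.0, -2.0]])
input_relevance = lrp_linear(layer, x, rel_out=layer(x).detach())
```
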
/conda.yaml:
--------------------------------------------------------------------------------
1 | name: tg-experiments
2 | channels:
3 | - rdkit
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - attrs=19.1.0=py37_1
8 | - backcall=0.1.0=py37_0
9 | - blas=1.0=mkl
10 | - bleach=3.1.0=py37_0
11 | - bzip2=1.0.6=h14c3975_5
12 | - ca-certificates=2019.1.23=0
13 | - cairo=1.14.12=h8948797_3
14 | - certifi=2019.3.9=py37_0
15 | - cffi=1.12.2=py37h2e261b9_1
16 | - cudatoolkit=10.0.130=0
17 | - cudnn=7.3.1=cuda10.0_0
18 | - cycler=0.10.0=py37_0
19 | - dbus=1.13.6=h746ee38_0
20 | - decorator=4.4.0=py37_1
21 | - defusedxml=0.5.0=py37_1
22 | - entrypoints=0.3=py37_0
23 | - expat=2.2.6=he6710b0_0
24 | - fontconfig=2.13.0=h9420a91_0
25 | - freetype=2.9.1=h8a8886c_1
26 | - glib=2.56.2=hd408876_0
27 | - gmp=6.1.2=h6c8ec71_1
28 | - gst-plugins-base=1.14.0=hbbd80ab_1
29 | - gstreamer=1.14.0=hb453b48_1
30 | - icu=58.2=h9c2bf20_1
31 | - intel-openmp=2019.3=199
32 | - ipykernel=5.1.0=py37h39e3cac_0
33 | - ipython=7.4.0=py37h39e3cac_0
34 | - ipython_genutils=0.2.0=py37_0
35 | - jedi=0.13.3=py37_0
36 | - jinja2=2.10.1=py37_0
37 | - jpeg=9b=h024ee3a_2
38 | - jsonschema=3.0.1=py37_0
39 | - jupyter_client=5.2.4=py37_0
40 | - jupyter_console=6.0.0=py37_0
41 | - jupyter_core=4.4.0=py37_0
42 | - jupyterlab=0.35.4=py37hf63ae98_0
43 | - jupyterlab_server=0.2.0=py37_0
44 | - kiwisolver=1.0.1=py37hf484d3e_0
45 | - libboost=1.67.0=h46d08c1_4
46 | - libedit=3.1.20181209=hc058e9b_0
47 | - libffi=3.2.1=hd88cf55_4
48 | - libgcc-ng=8.2.0=hdf63c60_1
49 | - libgfortran-ng=7.3.0=hdf63c60_0
50 | - libpng=1.6.36=hbc83047_0
51 | - libsodium=1.0.16=h1bed415_0
52 | - libstdcxx-ng=8.2.0=hdf63c60_1
53 | - libtiff=4.0.10=h2733197_2
54 | - libuuid=1.0.3=h1bed415_2
55 | - libxcb=1.13=h1bed415_1
56 | - libxml2=2.9.9=he19cac6_0
57 | - markupsafe=1.1.1=py37h7b6447c_0
58 | - matplotlib=3.0.3=py37h5429711_0
59 | - mistune=0.8.4=py37h7b6447c_0
60 | - mkl=2019.3=199
61 | - mkl_fft=1.0.12=py37ha843d7b_0
62 | - mkl_random=1.0.2=py37hd81dba3_0
63 | - munch=2.3.2=py37_0
64 | - nbconvert=5.4.1=py37_3
65 | - nbformat=4.4.0=py37_0
66 | - ncurses=6.1=he6710b0_1
67 | - ninja=1.9.0=py37hfd86e86_0
68 | - notebook=5.7.8=py37_0
69 | - numpy-base=1.16.4=py37hde5b4d6_0
70 | - olefile=0.46=py37_0
71 | - openssl=1.1.1c=h7b6447c_1
72 | - pandas=0.24.2=py37he6710b0_0
73 | - pandoc=2.2.3.2=0
74 | - pandocfilters=1.4.2=py37_1
75 | - parso=0.3.4=py37_0
76 | - pcre=8.43=he6710b0_0
77 | - pexpect=4.6.0=py37_0
78 | - pickleshare=0.7.5=py37_0
79 | - pillow=5.4.1=py37h34e0f95_0
80 | - pip=19.0.3=py37_0
81 | - pixman=0.38.0=h7b6447c_0
82 | - prometheus_client=0.6.0=py37_0
83 | - prompt_toolkit=2.0.9=py37_0
84 | - ptyprocess=0.6.0=py37_0
85 | - py-boost=1.67.0=py37h04863e7_4
86 | - pyaml=18.11.0=py37_0
87 | - pycparser=2.19=py37_0
88 | - pygments=2.3.1=py37_0
89 | - pyparsing=2.4.0=py_0
90 | - pyqt=5.9.2=py37h05f1152_2
91 | - pyrsistent=0.14.11=py37h7b6447c_0
92 | - python=3.7.3=h0371630_0
93 | - python-dateutil=2.8.0=py37_0
94 | - pytorch=1.1.0=py3.7_cuda10.0.130_cudnn7.5.1_0
95 | - pytz=2018.9=py37_0
96 | - pyyaml=5.1=py37h7b6447c_0
97 | - pyzmq=18.0.0=py37he6710b0_0
98 | - qt=5.9.7=h5867ecd_1
99 | - rdkit=2019.03.1.0=py37hc20afe1_1
100 | - readline=7.0=h7b6447c_5
101 | - scikit-learn=0.20.3=py37hd81dba3_0
102 | - scipy=1.2.1=py37h7c811a0_0
103 | - send2trash=1.5.0=py37_0
104 | - setuptools=41.0.0=py37_0
105 | - sip=4.19.8=py37hf484d3e_0
106 | - six=1.12.0=py37_0
107 | - sqlite=3.27.2=h7b6447c_0
108 | - terminado=0.8.1=py37_1
109 | - testpath=0.4.2=py37_0
110 | - tk=8.6.8=hbc83047_0
111 | - tornado=6.0.2=py37h7b6447c_0
112 | - tqdm=4.31.1=py37_1
113 | - traitlets=4.3.2=py37_0
114 | - wcwidth=0.1.7=py37_0
115 | - webencodings=0.5.1=py37_1
116 | - wheel=0.33.1=py37_0
117 | - xz=5.2.4=h14c3975_4
118 | - yaml=0.1.7=had09818_2
119 | - zeromq=4.3.1=he6710b0_3
120 | - zlib=1.2.11=h7b6447c_3
121 | - zstd=1.3.7=h0b5b093_0
122 | - pip:
123 | - networkx==2.3
124 | - numpy==1.16.4
125 | - protobuf==3.7.1
126 | - tensorboardx==1.6
127 | - torch==1.1.0
128 | - torch-scatter==1.2.0
129 | - git+https://github.com/baldassarreFe/torchgraphs@v0.0.1#egg=torchgraphs
130 |
--------------------------------------------------------------------------------
/config/count_nodes/dataset.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | edge_features_shape: 2
3 | node_features_shape: 4
4 | global_features_shape: 2
5 | informative_features: 2
6 | _train_:
7 | min_nodes: 0
8 | max_nodes: 20
9 | num_samples: 100_000
10 | _val_:
11 | min_nodes: 20
12 | max_nodes: 60
13 | num_samples: 50_000
14 | _test_:
15 | min_nodes: 0
16 | max_nodes: 60
17 | num_samples: 100_000
18 |
19 | opts:
20 | seed: _auto_
21 | folder: ~/experiments/count-nodes
--------------------------------------------------------------------------------
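For context, `create_and_save` in `src/count_nodes/dataset.py` (further down in this dump) merges the shared keys of `datasets` with each `_train_`/`_val_`/`_test_` block and passes the result to `NodeCountDataset`. A rough sketch of that merge for the train split, assuming the YAML above is loaded into a plain dict:

```python
import yaml

with open('config/count_nodes/dataset.yaml') as f:  # path relative to the repository root
    cfg = yaml.safe_load(f)

splits = {'_train_', '_val_', '_test_'}
common = {k: v for k, v in cfg['datasets'].items() if k not in splits}
train_kwargs = {**common, **cfg['datasets']['_train_']}
# -> {'edge_features_shape': 2, 'node_features_shape': 4, 'global_features_shape': 2,
#     'informative_features': 2, 'min_nodes': 0, 'max_nodes': 20, 'num_samples': 100000}
# These keyword arguments match the NodeCountDataset constructor in src/count_nodes/dataset.py.
```
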
/config/count_nodes/full.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | _class_: count_nodes.networks.FullGN
3 | in_edge_features_shape: 2
4 | in_node_features_shape: 4
5 | in_global_features_shape: 2
6 | out_edge_features_shape: 2
7 | out_node_features_shape: 1
8 | out_global_features_shape: 1
9 |
--------------------------------------------------------------------------------
/config/count_nodes/minimal.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | _class_: count_nodes.networks.MinimalGN
3 | in_node_features_shape: 4
4 | out_node_features_shape: 1
5 | out_global_features_shape: 1
6 |
--------------------------------------------------------------------------------
/config/count_nodes/train.yaml:
--------------------------------------------------------------------------------
1 | training:
2 | batch_size: 1_000
3 | epochs: 20
4 | save_every: 0
5 | restore: False
6 | l1: 0
7 |
8 | opts:
9 | log: True
10 | folder: ~/experiments/count-nodes
11 | session: _auto_
12 | cpus: _auto_
13 | seed: _auto_
14 | device: _auto_
15 |
16 | optimizer:
17 | _class_: torch.optim.Adam
18 |
--------------------------------------------------------------------------------
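The `_class_` entries in the `count_nodes` configs above name classes by their dotted path (`torch.optim.Adam`, `count_nodes.networks.MinimalGN`). The code that resolves these strings is not shown in this excerpt, so the snippet below is only an assumption of the usual pattern: split the path and import the module at run time.

```python
import importlib

def resolve(dotted_path: str):
    # Hypothetical helper, not from this repository: 'torch.optim.Adam' -> the Adam class
    module_name, _, attr = dotted_path.rpartition('.')
    return getattr(importlib.import_module(module_name), attr)

optimizer_cls = resolve('torch.optim.Adam')
model_cls = resolve('count_nodes.networks.MinimalGN')  # requires the package installed via `python setup.py develop`
```
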
/config/infection/datasets.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | train:
3 | min_nodes: 10
4 | max_nodes: 30
5 | max_percent_sick: .1
6 | max_percent_immune: .3
7 | max_percent_virtual: .3
8 | num_samples: 100_000
9 | val:
10 | min_nodes: 10
11 | max_nodes: 60
12 | max_percent_sick: .4
13 | max_percent_immune: .6
14 | max_percent_virtual: .5
15 | num_samples: 20_000
16 |
17 | folder: ~/experiments/infection/data
--------------------------------------------------------------------------------
/config/infection/train.yaml:
--------------------------------------------------------------------------------
1 | name: infection
2 |
3 | model:
4 | fn: infection.networks.InfectionGN
5 | kwargs:
6 | aggregation: max
7 | bias: yes
8 |
9 | optimizer:
10 | fn: torch.optim.Adam
11 | kwargs:
12 | lr: .001
13 |
14 | session:
15 | epochs: 20
16 | batch_size: 1_000
17 | losses:
18 | nodes: 1
19 | count: 1
20 | l1: 0
21 | data:
22 | folder: ~/experiments/{name}/data
23 | log:
24 | folder: ~/experiments/{name}/runs/{tags}_{rand}
25 | when:
26 | - every batch
27 | checkpoint:
28 | folder: ~/experiments/{name}/runs/{tags}_{rand}
29 | when:
30 | #- every epoch
31 | - last epoch
32 |
--------------------------------------------------------------------------------
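The `{name}`, `{tags}` and `{rand}` placeholders in the folder paths are filled in at run time; judging by the run directories under `models/infection` (e.g. `lr.001_nodes.1_count1_wd.0001_MPUKHT`), the tags are joined with underscores and a short random suffix is appended. The expansion code is not part of this excerpt, so the following is only an illustrative guess:

```python
import random
import string

name = 'infection'
tags = ['lr.001', 'nodes.1', 'count1', 'wd.0001']            # tag list as in the experiment YAMLs under models/
rand = ''.join(random.choices(string.ascii_uppercase, k=6))  # e.g. 'MPUKHT'
folder = '~/experiments/{name}/runs/{tags}_{rand}'.format(name=name, tags='_'.join(tags), rand=rand)
# -> '~/experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_<rand>'
```
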
/config/solubility/train.yaml:
--------------------------------------------------------------------------------
1 | name: solubility
2 |
3 | model:
4 | fn: solubility.networks.SolubilityGN
5 | kwargs:
6 | num_layers: 2
7 | hidden_bias: yes
8 | hidden_node: 16
9 | aggregation: mean
10 |
11 | optimizer:
12 | fn: torch.optim.Adam
13 | kwargs:
14 | lr: .001
15 |
16 | session:
17 | epochs: 20
18 | batch_size: 50
19 | losses:
20 | solubility: 1
21 | l1: 0
22 | data:
23 | path: ~/experiments/{name}/data/delaney-processed.csv
24 | train: .7
25 | val: .3
26 | log:
27 | folder: ~/experiments/{name}/runs/{tags}_{rand}
28 | when:
29 | - every batch
30 | checkpoint:
31 | folder: ~/experiments/{name}/runs/{tags}_{rand}
32 | when:
33 | - last epoch
34 |
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/events.out.tfevents.1555228025:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/events.out.tfevents.1555228025
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/experiment.latest.yaml:
--------------------------------------------------------------------------------
1 | name: infection
2 | tags:
3 | - lr.001
4 | - nodes.1
5 | - count1
6 | - wd.0001
7 | epoch: 20
8 | samples: 2000000
9 | model:
10 | fn: infection.networks.InfectionGN
11 | args: []
12 | kwargs: {}
13 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT/checkpoints/model.e0020.pt
14 | optimizer:
15 | fn: torch.optim.Adam
16 | args: []
17 | kwargs:
18 | lr: 0.001
19 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT/checkpoints/optimizer.e0020.pt
20 | sessions:
21 | - epochs: 20
22 | batch_size: 1000
23 | losses:
24 | nodes: 0.1
25 | count: 1
26 | l1: 0.0001
27 | seed: 60
28 | cpus: 31
29 | device: cuda
30 | status: DONE
31 | datetime_started: 2019-04-14 07:47:39.901626
32 | datetime_completed: 2019-04-14 07:51:05.397617
33 | data:
34 | folder: /experiments/infection/data
35 | log:
36 | when:
37 | - every batch
38 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT
39 | checkpoint:
40 | when:
41 | - last epoch
42 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd.0001_MPUKHT
43 | cuda:
44 | driver: '396.37'
45 | gpus:
46 | - model: GeForce GTX 1080 Ti
47 | utilization: 18 %
48 | memory_used: 1639 MiB
49 | memory_total: 11178 MiB
50 | - model: GeForce GTX 1080 Ti
51 | utilization: 0 %
52 | memory_used: 0 MiB
53 | memory_total: 11178 MiB
54 | - model: GeForce GTX 1080 Ti
55 | utilization: 0 %
56 | memory_used: 0 MiB
57 | memory_total: 11178 MiB
58 | - model: GeForce GTX 1080 Ti
59 | utilization: 0 %
60 | memory_used: 0 MiB
61 | memory_total: 11178 MiB
62 | - model: GeForce GTX 1080 Ti
63 | utilization: 0 %
64 | memory_used: 0 MiB
65 | memory_total: 11178 MiB
66 | - model: GeForce GTX 1080 Ti
67 | utilization: 0 %
68 | memory_used: 0 MiB
69 | memory_total: 11178 MiB
70 |
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/model.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/model.latest.pt
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/optimizer.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd.0001_MPUKHT/optimizer.latest.pt
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/events.out.tfevents.1555269841:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/events.out.tfevents.1555269841
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/experiment.latest.yaml:
--------------------------------------------------------------------------------
1 | name: infection
2 | tags:
3 | - lr.001
4 | - nodes.1
5 | - count1
6 | - wd0
7 | epoch: 20
8 | samples: 2000000
9 | model:
10 | fn: infection.networks.InfectionGN
11 | args: []
12 | kwargs: {}
13 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD/checkpoints/model.e0020.pt
14 | optimizer:
15 | fn: torch.optim.Adam
16 | args: []
17 | kwargs:
18 | lr: 0.001
19 | state_dict: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD/checkpoints/optimizer.e0020.pt
20 | sessions:
21 | - epochs: 20
22 | batch_size: 1000
23 | losses:
24 | nodes: 0.1
25 | count: 1
26 | l1: 0
27 | seed: 76
28 | cpus: 31
29 | device: cuda
30 | status: DONE
31 | datetime_started: 2019-04-14 19:24:34.801354
32 | datetime_completed: 2019-04-14 19:28:47.920841
33 | data:
34 | folder: /experiments/infection/data
35 | log:
36 | when:
37 | - every batch
38 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD
39 | checkpoint:
40 | when:
41 | - last epoch
42 | folder: /experiments/infection/runs/lr.001_nodes.1_count1_wd0_QYNATD
43 | cuda:
44 | driver: '396.37'
45 | gpus:
46 | - model: GeForce GTX 1080 Ti
47 | utilization: 1 %
48 | memory_used: 3086 MiB
49 | memory_total: 11178 MiB
50 | - model: GeForce GTX 1080 Ti
51 | utilization: 0 %
52 | memory_used: 0 MiB
53 | memory_total: 11178 MiB
54 | - model: GeForce GTX 1080 Ti
55 | utilization: 0 %
56 | memory_used: 0 MiB
57 | memory_total: 11178 MiB
58 | - model: GeForce GTX 1080 Ti
59 | utilization: 0 %
60 | memory_used: 0 MiB
61 | memory_total: 11178 MiB
62 | - model: GeForce GTX 1080 Ti
63 | utilization: 0 %
64 | memory_used: 0 MiB
65 | memory_total: 11178 MiB
66 | - model: GeForce GTX 1080 Ti
67 | utilization: 0 %
68 | memory_used: 10 MiB
69 | memory_total: 11178 MiB
70 |
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/model.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/model.latest.pt
--------------------------------------------------------------------------------
/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/optimizer.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/lr.001_nodes.1_count1_wd0_QYNATD/optimizer.latest.pt
--------------------------------------------------------------------------------
/models/infection/max_bias/events.out.tfevents.1555604341:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_bias/events.out.tfevents.1555604341
--------------------------------------------------------------------------------
/models/infection/max_bias/experiment.latest.yaml:
--------------------------------------------------------------------------------
1 | name: infection
2 | tags: []
3 | epoch: 20
4 | samples: 2000000
5 | model:
6 | fn: infection.networks.InfectionGN
7 | args: []
8 | kwargs:
9 | aggregation: max
10 | bias: true
11 | state_dict: /experiments/infection/runs/_YAPSEY/checkpoints/model.e0020.pt
12 | optimizer:
13 | fn: torch.optim.Adam
14 | args: []
15 | kwargs:
16 | lr: 0.001
17 | state_dict: /experiments/infection/runs/_YAPSEY/checkpoints/optimizer.e0020.pt
18 | sessions:
19 | - epochs: 20
20 | batch_size: 1000
21 | losses:
22 | nodes: 1
23 | count: 0
24 | l1: 0.001
25 | seed: 6
26 | cpus: 11
27 | device: cuda
28 | status: DONE
29 | datetime_started: 2019-04-18 16:19:27.499119
30 | datetime_completed: 2019-04-18 16:21:42.604716
31 | data:
32 | folder: /experiments/infection/data
33 | log:
34 | when:
35 | - every batch
36 | folder: /experiments/infection/runs/_YAPSEY
37 | checkpoint:
38 | when:
39 | - last epoch
40 | folder: /experiments/infection/runs/_YAPSEY
41 | cuda:
42 | driver: '418.43'
43 | gpus:
44 | - model: GeForce GTX 1050 Ti with Max-Q Design
45 | utilization: 0 %
46 | memory_used: 10 MiB
47 | memory_total: 4042 MiB
48 |
--------------------------------------------------------------------------------
/models/infection/max_bias/model.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_bias/model.latest.pt
--------------------------------------------------------------------------------
/models/infection/max_bias/optimizer.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_bias/optimizer.latest.pt
--------------------------------------------------------------------------------
/models/infection/max_nobias/events.out.tfevents.1555604167:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_nobias/events.out.tfevents.1555604167
--------------------------------------------------------------------------------
/models/infection/max_nobias/experiment.latest.yaml:
--------------------------------------------------------------------------------
1 | name: infection
2 | tags: []
3 | epoch: 20
4 | samples: 2000000
5 | model:
6 | fn: infection.networks.InfectionGN
7 | args: []
8 | kwargs:
9 | aggregation: max
10 | bias: false
11 | state_dict: /experiments/infection/runs/_VFJYGG/checkpoints/model.e0020.pt
12 | optimizer:
13 | fn: torch.optim.Adam
14 | args: []
15 | kwargs:
16 | lr: 0.001
17 | state_dict: /experiments/infection/runs/_VFJYGG/checkpoints/optimizer.e0020.pt
18 | sessions:
19 | - epochs: 20
20 | batch_size: 1000
21 | losses:
22 | nodes: 1
23 | count: 0
24 | l1: 0.001
25 | seed: 44
26 | cpus: 11
27 | device: cuda
28 | status: DONE
29 | datetime_started: 2019-04-18 16:16:34.420894
30 | datetime_completed: 2019-04-18 16:18:45.621812
31 | data:
32 | folder: /experiments/infection/data
33 | log:
34 | when:
35 | - every batch
36 | folder: /experiments/infection/runs/_VFJYGG
37 | checkpoint:
38 | when:
39 | - last epoch
40 | folder: /experiments/infection/runs/_VFJYGG
41 | cuda:
42 | driver: '418.43'
43 | gpus:
44 | - model: GeForce GTX 1050 Ti with Max-Q Design
45 | utilization: 0 %
46 | memory_used: 10 MiB
47 | memory_total: 4042 MiB
48 |
--------------------------------------------------------------------------------
/models/infection/max_nobias/model.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_nobias/model.latest.pt
--------------------------------------------------------------------------------
/models/infection/max_nobias/optimizer.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/max_nobias/optimizer.latest.pt
--------------------------------------------------------------------------------
/models/infection/sum_bias/events.out.tfevents.1555603943:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_bias/events.out.tfevents.1555603943
--------------------------------------------------------------------------------
/models/infection/sum_bias/experiment.latest.yaml:
--------------------------------------------------------------------------------
1 | name: infection
2 | tags: []
3 | epoch: 20
4 | samples: 2000000
5 | model:
6 | fn: infection.networks.InfectionGN
7 | args: []
8 | kwargs:
9 | aggregation: sum
10 | bias: true
11 | state_dict: /experiments/infection/runs/_EHRTVG/checkpoints/model.e0020.pt
12 | optimizer:
13 | fn: torch.optim.Adam
14 | args: []
15 | kwargs:
16 | lr: 0.001
17 | state_dict: /experiments/infection/runs/_EHRTVG/checkpoints/optimizer.e0020.pt
18 | sessions:
19 | - epochs: 20
20 | batch_size: 1000
21 | losses:
22 | nodes: 1
23 | count: 0
24 | l1: 0.001
25 | seed: 89
26 | cpus: 11
27 | device: cuda
28 | status: DONE
29 | datetime_started: 2019-04-18 16:12:49.619048
30 | datetime_completed: 2019-04-18 16:15:03.664174
31 | data:
32 | folder: /experiments/infection/data
33 | log:
34 | when:
35 | - every batch
36 | folder: /experiments/infection/runs/_EHRTVG
37 | checkpoint:
38 | when:
39 | - last epoch
40 | folder: /experiments/infection/runs/_EHRTVG
41 | cuda:
42 | driver: '418.43'
43 | gpus:
44 | - model: GeForce GTX 1050 Ti with Max-Q Design
45 | utilization: 0 %
46 | memory_used: 10 MiB
47 | memory_total: 4042 MiB
48 |
--------------------------------------------------------------------------------
/models/infection/sum_bias/model.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_bias/model.latest.pt
--------------------------------------------------------------------------------
/models/infection/sum_bias/optimizer.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_bias/optimizer.latest.pt
--------------------------------------------------------------------------------
/models/infection/sum_nobias/events.out.tfevents.1555603768:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_nobias/events.out.tfevents.1555603768
--------------------------------------------------------------------------------
/models/infection/sum_nobias/experiment.latest.yaml:
--------------------------------------------------------------------------------
1 | name: infection
2 | tags: []
3 | epoch: 20
4 | samples: 2000000
5 | model:
6 | fn: infection.networks.InfectionGN
7 | args: []
8 | kwargs:
9 | aggregation: sum
10 | bias: false
11 | state_dict: /experiments/infection/runs/_QDOVMP/checkpoints/model.e0020.pt
12 | optimizer:
13 | fn: torch.optim.Adam
14 | args: []
15 | kwargs:
16 | lr: 0.001
17 | state_dict: /experiments/infection/runs/_QDOVMP/checkpoints/optimizer.e0020.pt
18 | sessions:
19 | - epochs: 20
20 | batch_size: 1000
21 | losses:
22 | nodes: 1
23 | count: 0
24 | l1: 0.001
25 | seed: 30
26 | cpus: 11
27 | device: cuda
28 | status: DONE
29 | datetime_started: 2019-04-18 16:09:54.821299
30 | datetime_completed: 2019-04-18 16:12:04.285842
31 | data:
32 | folder: /experiments/infection/data
33 | log:
34 | when:
35 | - every batch
36 | folder: /experiments/infection/runs/_QDOVMP
37 | checkpoint:
38 | when:
39 | - last epoch
40 | folder: /experiments/infection/runs/_QDOVMP
41 | cuda:
42 | driver: '418.43'
43 | gpus:
44 | - model: GeForce GTX 1050 Ti with Max-Q Design
45 | utilization: 0 %
46 | memory_used: 10 MiB
47 | memory_total: 4042 MiB
48 |
--------------------------------------------------------------------------------
/models/infection/sum_nobias/model.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_nobias/model.latest.pt
--------------------------------------------------------------------------------
/models/infection/sum_nobias/optimizer.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/infection/sum_nobias/optimizer.latest.pt
--------------------------------------------------------------------------------
/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/events.out.tfevents.1555363822:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/events.out.tfevents.1555363822
--------------------------------------------------------------------------------
/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/experiment.latest.yaml:
--------------------------------------------------------------------------------
1 | name: solubility
2 | tags:
3 | - layers3
4 | - lr.01
5 | - biasyes
6 | - size64
7 | - wd.001
8 | - dryes
9 | - e50
10 | - sum
11 | epoch: 50
12 | samples: 39450
13 | model:
14 | fn: solubility.networks.SolubilityGN
15 | args: []
16 | kwargs:
17 | num_layers: 3
18 | hidden_bias: true
19 | hidden_node: 64
20 | aggregation: sum
21 | dropout: 50
22 | state_dict: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/checkpoints/model.e0050.pt
23 | optimizer:
24 | fn: torch.optim.Adam
25 | args: []
26 | kwargs:
27 | lr: 0.01
28 | state_dict: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/checkpoints/optimizer.e0050.pt
29 | sessions:
30 | - epochs: 50
31 | batch_size: 50
32 | losses:
33 | solubility: 1
34 | l1: 0.001
35 | seed: 94
36 | cpus: 31
37 | device: cuda
38 | status: DONE
39 | datetime_started: 2019-04-15 21:30:22.449869
40 | datetime_completed: 2019-04-15 21:42:14.467594
41 | data:
42 | path: /experiments/solubility/data/delaney-processed.csv
43 | train: 0.7
44 | val: 0.3
45 | log:
46 | when:
47 | - every batch
48 | folder: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG
49 | checkpoint:
50 | when:
51 | - last epoch
52 | folder: /experiments/solubility/runs/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG
53 | cuda:
54 | driver: '396.37'
55 | gpus:
56 | - model: GeForce GTX 1080 Ti
57 | utilization: 5 %
58 | memory_used: 2432 MiB
59 | memory_total: 11178 MiB
60 | - model: GeForce GTX 1080 Ti
61 | utilization: 0 %
62 | memory_used: 0 MiB
63 | memory_total: 11178 MiB
64 | - model: GeForce GTX 1080 Ti
65 | utilization: 0 %
66 | memory_used: 0 MiB
67 | memory_total: 11178 MiB
68 | - model: GeForce GTX 1080 Ti
69 | utilization: 0 %
70 | memory_used: 0 MiB
71 | memory_total: 11178 MiB
72 | - model: GeForce GTX 1080 Ti
73 | utilization: 0 %
74 | memory_used: 0 MiB
75 | memory_total: 11178 MiB
76 | - model: GeForce GTX 1080 Ti
77 | utilization: 0 %
78 | memory_used: 0 MiB
79 | memory_total: 11178 MiB
80 | loss_sol_train: 0.5124920198672624
81 | loss_sol_val: 0.5948061519316165
82 |
--------------------------------------------------------------------------------
/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/model.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/model.latest.pt
--------------------------------------------------------------------------------
/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/optimizer.latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/models/solubility/layers3_lr.01_biasyes_size64_wd.001_dryes_e50_sum_KCJGWG/optimizer.latest.pt
--------------------------------------------------------------------------------
/notebooks/Solubility-GraphFeatures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Solubility"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 181,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import torch\n",
17 | "import numpy as np\n",
18 | "import pandas as pd\n",
19 | "from pandas.api.types import CategoricalDtype\n",
20 | "\n",
21 | "from rdkit import Chem\n",
22 | "from rdkit.Chem import AllChem\n",
23 | "\n",
24 | "import torchgraphs as tg"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 182,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "data": {
34 | "text/html": [
35 | "
\n",
36 | "\n",
49 | "
\n",
50 | " \n",
51 | " \n",
52 | " | \n",
53 | " Compound ID | \n",
54 | " ESOL predicted log solubility in mols per litre | \n",
55 | " Minimum Degree | \n",
56 | " Molecular Weight | \n",
57 | " Number of H-Bond Donors | \n",
58 | " Number of Rings | \n",
59 | " Number of Rotatable Bonds | \n",
60 | " Polar Surface Area | \n",
61 | " measured log solubility in mols per litre | \n",
62 | " smiles | \n",
63 | "
\n",
64 | " \n",
65 | " \n",
66 | " \n",
67 | " 0 | \n",
68 | " Amigdalin | \n",
69 | " -0.974 | \n",
70 | " 1 | \n",
71 | " 457.432 | \n",
72 | " 7 | \n",
73 | " 3 | \n",
74 | " 7 | \n",
75 | " 202.32 | \n",
76 | " -0.77 | \n",
77 | " OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)... | \n",
78 | "
\n",
79 | " \n",
80 | " 1 | \n",
81 | " Fenfuram | \n",
82 | " -2.885 | \n",
83 | " 1 | \n",
84 | " 201.225 | \n",
85 | " 1 | \n",
86 | " 2 | \n",
87 | " 2 | \n",
88 | " 42.24 | \n",
89 | " -3.30 | \n",
90 | " Cc1occc1C(=O)Nc2ccccc2 | \n",
91 | "
\n",
92 | " \n",
93 | " 2 | \n",
94 | " citral | \n",
95 | " -2.579 | \n",
96 | " 1 | \n",
97 | " 152.237 | \n",
98 | " 0 | \n",
99 | " 0 | \n",
100 | " 4 | \n",
101 | " 17.07 | \n",
102 | " -2.06 | \n",
103 | " CC(C)=CCCC(C)=CC(=O) | \n",
104 | "
\n",
105 | " \n",
106 | " 3 | \n",
107 | " Picene | \n",
108 | " -6.618 | \n",
109 | " 2 | \n",
110 | " 278.354 | \n",
111 | " 0 | \n",
112 | " 5 | \n",
113 | " 0 | \n",
114 | " 0.00 | \n",
115 | " -7.87 | \n",
116 | " c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43 | \n",
117 | "
\n",
118 | " \n",
119 | " 4 | \n",
120 | " Thiophene | \n",
121 | " -2.232 | \n",
122 | " 2 | \n",
123 | " 84.143 | \n",
124 | " 0 | \n",
125 | " 1 | \n",
126 | " 0 | \n",
127 | " 0.00 | \n",
128 | " -1.33 | \n",
129 | " c1ccsc1 | \n",
130 | "
\n",
131 | " \n",
132 | "
\n",
133 | "
"
134 | ],
135 | "text/plain": [
136 | " Compound ID ESOL predicted log solubility in mols per litre \\\n",
137 | "0 Amigdalin -0.974 \n",
138 | "1 Fenfuram -2.885 \n",
139 | "2 citral -2.579 \n",
140 | "3 Picene -6.618 \n",
141 | "4 Thiophene -2.232 \n",
142 | "\n",
143 | " Minimum Degree Molecular Weight Number of H-Bond Donors Number of Rings \\\n",
144 | "0 1 457.432 7 3 \n",
145 | "1 1 201.225 1 2 \n",
146 | "2 1 152.237 0 0 \n",
147 | "3 2 278.354 0 5 \n",
148 | "4 2 84.143 0 1 \n",
149 | "\n",
150 | " Number of Rotatable Bonds Polar Surface Area \\\n",
151 | "0 7 202.32 \n",
152 | "1 2 42.24 \n",
153 | "2 4 17.07 \n",
154 | "3 0 0.00 \n",
155 | "4 0 0.00 \n",
156 | "\n",
157 | " measured log solubility in mols per litre \\\n",
158 | "0 -0.77 \n",
159 | "1 -3.30 \n",
160 | "2 -2.06 \n",
161 | "3 -7.87 \n",
162 | "4 -1.33 \n",
163 | "\n",
164 | " smiles \n",
165 | "0 OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)... \n",
166 | "1 Cc1occc1C(=O)Nc2ccccc2 \n",
167 | "2 CC(C)=CCCC(C)=CC(=O) \n",
168 | "3 c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43 \n",
169 | "4 c1ccsc1 "
170 | ]
171 | },
172 | "execution_count": 182,
173 | "metadata": {},
174 | "output_type": "execute_result"
175 | }
176 | ],
177 | "source": [
178 | "df = pd.read_csv('../data/delaney-processed.csv')\n",
179 | "df.head()"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": 183,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "molecule = Chem.MolFromSmiles(df.smiles[0])"
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {},
194 | "source": [
195 | "## Atom features"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 184,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "symbols = CategoricalDtype([\n",
205 | " 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',\n",
206 | " 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',\n",
207 | " 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', # H?\n",
208 | " 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',\n",
209 | " 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown'\n",
210 | "], ordered=True)"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 185,
216 | "metadata": {},
217 | "outputs": [
218 | {
219 | "data": {
220 | "text/html": [
221 | "\n",
222 | "\n",
235 | "
\n",
236 | " \n",
237 | " \n",
238 | " | \n",
239 | " degree | \n",
240 | " hydrogens | \n",
241 | " impl_valence | \n",
242 | " symbol | \n",
243 | "
\n",
244 | " \n",
245 | " index | \n",
246 | " | \n",
247 | " | \n",
248 | " | \n",
249 | " | \n",
250 | "
\n",
251 | " \n",
252 | " \n",
253 | " \n",
254 | " 0 | \n",
255 | " 1 | \n",
256 | " 1 | \n",
257 | " 1 | \n",
258 | " O | \n",
259 | "
\n",
260 | " \n",
261 | " 1 | \n",
262 | " 2 | \n",
263 | " 2 | \n",
264 | " 2 | \n",
265 | " C | \n",
266 | "
\n",
267 | " \n",
268 | " 2 | \n",
269 | " 3 | \n",
270 | " 1 | \n",
271 | " 1 | \n",
272 | " C | \n",
273 | "
\n",
274 | " \n",
275 | " 3 | \n",
276 | " 2 | \n",
277 | " 0 | \n",
278 | " 0 | \n",
279 | " O | \n",
280 | "
\n",
281 | " \n",
282 | " 4 | \n",
283 | " 3 | \n",
284 | " 1 | \n",
285 | " 1 | \n",
286 | " C | \n",
287 | "
\n",
288 | " \n",
289 | "
\n",
290 | "
"
291 | ],
292 | "text/plain": [
293 | " degree hydrogens impl_valence symbol\n",
294 | "index \n",
295 | "0 1 1 1 O\n",
296 | "1 2 2 2 C\n",
297 | "2 3 1 1 C\n",
298 | "3 2 0 0 O\n",
299 | "4 3 1 1 C"
300 | ]
301 | },
302 | "execution_count": 185,
303 | "metadata": {},
304 | "output_type": "execute_result"
305 | }
306 | ],
307 | "source": [
308 | "atoms_df = []\n",
309 | "for i in range(molecule.GetNumAtoms()):\n",
310 | " atom = molecule.GetAtomWithIdx(i)\n",
311 | " atoms_df.append({\n",
312 | " 'index': i,\n",
313 | " 'symbol': atom.GetSymbol(),\n",
314 | " 'degree': atom.GetDegree(),\n",
315 | " 'hydrogens': atom.GetTotalNumHs(),\n",
316 | " 'impl_valence': atom.GetImplicitValence(),\n",
317 | " })\n",
318 | "atoms_df = pd.DataFrame.from_records(atoms_df, index='index')\n",
319 | "#atoms_df.degree.cat.set_categories([0, 1, 2, 3, 4, 5])\n",
320 | "#atoms_df.hydrogens.cat.set_categories([0, 1, 2, 3, 4])\n",
321 | "#atoms_df.impl_valence.cat.set_categories([0, 1, 2, 3, 4, 5])\n",
322 | "atoms_df.symbol = atoms_df.symbol.astype(symbols)\n",
323 | "atoms_df.head()"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 186,
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "node_features = torch.tensor(pd.get_dummies(atoms_df, columns=['symbol']).values, dtype=torch.float)"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {},
338 | "source": [
339 | "## Bond features"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 187,
345 | "metadata": {},
346 | "outputs": [],
347 | "source": [
348 | "bonds = CategoricalDtype([\n",
349 | " 'SINGLE',\n",
350 | " 'DOUBLE',\n",
351 | " 'TRIPLE',\n",
352 | " 'AROMATIC'\n",
353 | "], ordered=True)"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 188,
359 | "metadata": {},
360 | "outputs": [
361 | {
362 | "data": {
363 | "text/html": [
364 | "\n",
365 | "\n",
378 | "
\n",
379 | " \n",
380 | " \n",
381 | " | \n",
382 | " | \n",
383 | " conj | \n",
384 | " ring | \n",
385 | " type | \n",
386 | "
\n",
387 | " \n",
388 | " sender | \n",
389 | " receiver | \n",
390 | " | \n",
391 | " | \n",
392 | " | \n",
393 | "
\n",
394 | " \n",
395 | " \n",
396 | " \n",
397 | " 0 | \n",
398 | " 1 | \n",
399 | " -1.0 | \n",
400 | " -1.0 | \n",
401 | " SINGLE | \n",
402 | "
\n",
403 | " \n",
404 | " 1 | \n",
405 | " 0 | \n",
406 | " -1.0 | \n",
407 | " -1.0 | \n",
408 | " SINGLE | \n",
409 | "
\n",
410 | " \n",
411 | " 2 | \n",
412 | " -1.0 | \n",
413 | " -1.0 | \n",
414 | " SINGLE | \n",
415 | "
\n",
416 | " \n",
417 | " 2 | \n",
418 | " 1 | \n",
419 | " -1.0 | \n",
420 | " -1.0 | \n",
421 | " SINGLE | \n",
422 | "
\n",
423 | " \n",
424 | " 3 | \n",
425 | " -1.0 | \n",
426 | " 1.0 | \n",
427 | " SINGLE | \n",
428 | "
\n",
429 | " \n",
430 | "
\n",
431 | "
"
432 | ],
433 | "text/plain": [
434 | " conj ring type\n",
435 | "sender receiver \n",
436 | "0 1 -1.0 -1.0 SINGLE\n",
437 | "1 0 -1.0 -1.0 SINGLE\n",
438 | " 2 -1.0 -1.0 SINGLE\n",
439 | "2 1 -1.0 -1.0 SINGLE\n",
440 | " 3 -1.0 1.0 SINGLE"
441 | ]
442 | },
443 | "execution_count": 188,
444 | "metadata": {},
445 | "output_type": "execute_result"
446 | }
447 | ],
448 | "source": [
449 | "bonds_df = []\n",
450 | "for bond in molecule.GetBonds():\n",
451 | " bonds_df.append({\n",
452 | " 'sender': bond.GetBeginAtomIdx(),\n",
453 | " 'receiver': bond.GetEndAtomIdx(),\n",
454 | " 'type': bond.GetBondType().name,\n",
455 | " 'conj': bond.GetIsConjugated(),\n",
456 | " 'ring': bond.IsInRing()\n",
457 | " })\n",
458 | " bonds_df.append({\n",
459 | " 'receiver': bond.GetBeginAtomIdx(),\n",
460 | " 'sender': bond.GetEndAtomIdx(),\n",
461 | " 'type': bond.GetBondType().name,\n",
462 | " 'conj': bond.GetIsConjugated(),\n",
463 | " 'ring': bond.IsInRing()\n",
464 | " })\n",
465 | "bonds_df = pd.DataFrame.from_records(bonds_df, index=['sender', 'receiver'])\n",
466 | "bonds_df.conj = bonds_df.conj * 2. - 1\n",
467 | "bonds_df.ring = bonds_df.ring * 2. - 1\n",
468 | "bonds_df.type = bonds_df.type.astype(bonds)\n",
469 | "bonds_df.head()"
470 | ]
471 | },
472 | {
473 | "cell_type": "code",
474 | "execution_count": 189,
475 | "metadata": {},
476 | "outputs": [],
477 | "source": [
478 | "edge_features = torch.tensor(pd.get_dummies(bonds_df, columns=['type']).values, dtype=torch.float)\n",
479 | "senders = torch.tensor(bonds_df.index.get_level_values('sender'))\n",
480 | "receivers = torch.tensor(bonds_df.index.get_level_values('receiver'))"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": 191,
486 | "metadata": {},
487 | "outputs": [
488 | {
489 | "data": {
490 | "text/plain": [
491 | "Graph(n=32, e=68, n_shape=torch.Size([47]), e_shape=torch.Size([6]), g_shape=None)"
492 | ]
493 | },
494 | "execution_count": 191,
495 | "metadata": {},
496 | "output_type": "execute_result"
497 | }
498 | ],
499 | "source": [
500 | "def smiles_to_graph(smiles: str) -> tg.Graph:\n",
501 | "    molecule = Chem.MolFromSmiles(smiles)\n",
502 | " \n",
503 | " atoms_df = []\n",
504 | " for i in range(molecule.GetNumAtoms()):\n",
505 | " atom = molecule.GetAtomWithIdx(i)\n",
506 | " atoms_df.append({\n",
507 | " 'index': i,\n",
508 | " 'symbol': atom.GetSymbol(),\n",
509 | " 'degree': atom.GetDegree(),\n",
510 | " 'hydrogens': atom.GetTotalNumHs(),\n",
511 | " 'impl_valence': atom.GetImplicitValence(),\n",
512 | " })\n",
513 | " atoms_df = pd.DataFrame.from_records(atoms_df, index='index')\n",
514 | " atoms_df.symbol = atoms_df.symbol.astype(symbols)\n",
515 | " \n",
516 | " node_features = torch.tensor(pd.get_dummies(atoms_df, columns=['symbol']).values, dtype=torch.float)\n",
517 | " \n",
518 | " bonds_df = []\n",
519 | " for bond in molecule.GetBonds():\n",
520 | " bonds_df.append({\n",
521 | " 'sender': bond.GetBeginAtomIdx(),\n",
522 | " 'receiver': bond.GetEndAtomIdx(),\n",
523 | " 'type': bond.GetBondType().name,\n",
524 | " 'conj': bond.GetIsConjugated(),\n",
525 | " 'ring': bond.IsInRing()\n",
526 | " })\n",
527 | " bonds_df.append({\n",
528 | " 'receiver': bond.GetBeginAtomIdx(),\n",
529 | " 'sender': bond.GetEndAtomIdx(),\n",
530 | " 'type': bond.GetBondType().name,\n",
531 | " 'conj': bond.GetIsConjugated(),\n",
532 | " 'ring': bond.IsInRing()\n",
533 | " })\n",
534 | " bonds_df = pd.DataFrame.from_records(bonds_df, index=['sender', 'receiver'])\n",
535 | " bonds_df.conj = bonds_df.conj * 2. - 1\n",
536 | " bonds_df.ring = bonds_df.ring * 2. - 1\n",
537 | " bonds_df.type = bonds_df.type.astype(bonds)\n",
538 | " \n",
539 | " edge_features = torch.tensor(pd.get_dummies(bonds_df, columns=['type']).values, dtype=torch.float)\n",
540 | " senders = torch.tensor(bonds_df.index.get_level_values('sender'))\n",
541 | " receivers = torch.tensor(bonds_df.index.get_level_values('receiver'))\n",
542 | " \n",
543 | " return tg.Graph(\n",
544 | " num_nodes=molecule.GetNumAtoms(),\n",
545 | " num_edges=molecule.GetNumBonds() * 2,\n",
546 | " node_features=node_features,\n",
547 | " edge_features=edge_features,\n",
548 | " senders=senders,\n",
549 | " receivers=receivers\n",
550 | " )\n",
551 | "\n",
552 | "smiles_to_graph(df.smiles[0])"
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": 219,
558 | "metadata": {},
559 | "outputs": [],
560 | "source": [
561 | "class SolubilityDataset(torch.utils.data.Dataset):\n",
562 | " def __init__(self, path):\n",
563 | " self.df = pd.read_csv(path)\n",
564 | " self.df['molecules'] = self.df.smiles.apply(smiles_to_graph)\n",
565 | "\n",
566 | " def __len__(self):\n",
567 | " return len(self.df)\n",
568 | "\n",
569 | " def __getitem__(self, item):\n",
570 | " mol = self.df['molecules'].iloc[item]\n",
571 | " target = self.df['measured log solubility in mols per litre'].iloc[item]\n",
572 | " return mol, target\n",
573 | " \n",
574 | "sd = SolubilityDataset('../data/delaney-processed.csv')"
575 | ]
576 | }
577 | ],
578 | "metadata": {
579 | "kernelspec": {
580 | "display_name": "Python 3",
581 | "language": "python",
582 | "name": "python3"
583 | },
584 | "language_info": {
585 | "codemirror_mode": {
586 | "name": "ipython",
587 | "version": 3
588 | },
589 | "file_extension": ".py",
590 | "mimetype": "text/x-python",
591 | "name": "python",
592 | "nbconvert_exporter": "python",
593 | "pygments_lexer": "ipython3",
594 | "version": "3.7.1"
595 | }
596 | },
597 | "nbformat": 4,
598 | "nbformat_minor": 2
599 | }
600 |
--------------------------------------------------------------------------------
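The `Graph(n=32, e=68, n_shape=torch.Size([47]), e_shape=torch.Size([6]), ...)` printed by `smiles_to_graph` in the notebook above follows directly from the one-hot encodings it builds: 3 numeric atom columns (`degree`, `hydrogens`, `impl_valence`) plus the 44 entries of the `symbols` categorical give 47 node features, and 2 numeric bond columns (`conj`, `ring`) plus the 4 bond types give 6 edge features; molecule 0 (Amigdalin) has 32 heavy atoms and 34 bonds, stored as 68 directed edges. A quick check, assuming the notebook's `symbols` and `bonds` dtypes are in scope:

```python
# Sanity check of the feature sizes printed by smiles_to_graph
node_feature_size = 3 + len(symbols.categories)  # degree, hydrogens, impl_valence + 44 atom symbols
edge_feature_size = 2 + len(bonds.categories)    # conj, ring + 4 bond types
assert (node_feature_size, edge_feature_size) == (47, 6)
```
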
/notebooks/biggraph.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/notebooks/biggraph.pt
--------------------------------------------------------------------------------
/resources/sucrose-atoms.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/resources/sucrose-atoms.png
--------------------------------------------------------------------------------
/resources/sucrose-bonds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/resources/sucrose-bonds.png
--------------------------------------------------------------------------------
/resources/sucrose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/resources/sucrose.png
--------------------------------------------------------------------------------
/scripts.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Conda environment
4 | conda env export | sed '/prefix: .*/ d' > conda.yaml
5 | conda env create -n gn-exp -f conda.yaml
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='graph-network-explainability',
5 | version='0.0.1',
6 | packages=find_packages(where='src'),
7 | package_dir={"": "src"},
8 | )
9 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/__init__.py
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | from collections import Mapping, defaultdict
2 | from pathlib import Path
3 | from typing import Union
4 |
5 | import yaml
6 | import munch
7 |
8 |
9 | class Config(munch.Munch):
10 |
11 | def __setattr__(self, key, value):
12 | if isinstance(value, Mapping):
13 | value = Config.fromDict(value)
14 | super(Config, self).__setattr__(key, value)
15 |
16 | @staticmethod
17 | def build(*new_configs, **cfg_args):
18 | c = Config()
19 | for new_config in new_configs:
20 | if isinstance(new_config, Mapping):
21 | Config._update_rec(c, new_config)
22 | elif isinstance(new_config, str) and '=' in new_config:
23 | Config._update_rec(c, Config.from_dotted(new_config))
24 | elif Path(new_config).suffix in {'.yml', '.yaml'}:
25 | Config._update_rec(c, Config.from_yaml(new_config))
26 | elif Path(new_config).suffix == '.json':
27 | Config._update_rec(c, Config.from_json(new_config))
28 | Config._update_rec(c, cfg_args)
29 | return c
30 |
31 | @staticmethod
32 | def from_yaml(file: Union[str, Path]):
33 | with open(file, 'r') as f:
34 | return Config.fromDict(yaml.safe_load(f))
35 |
36 | @staticmethod
37 | def from_json(file: Union[str, Path]):
38 | import json
39 | with open(file, 'r') as f:
40 | return Config.fromDict(json.load(f))
41 |
42 | @staticmethod
43 | def from_dotted(dotted_str: str):
44 | """Parse a string of named arguments that use dots to indicate hierarchy, e.g. `name=test opts.cpus=4`
45 | """
46 |
47 | def recursively_defaultdict():
48 | return defaultdict(recursively_defaultdict)
49 |
50 | config = recursively_defaultdict()
51 |
52 | for name_dotted, value in (pair.split('=') for pair in dotted_str.split(' ')):
53 | c = config
54 | name_head, *name_rest = name_dotted.lstrip('-').split('.')
55 | while len(name_rest) > 0:
56 | c = c[name_head]
57 | name_head, *name_rest = name_rest
58 | c[name_head] = yaml.safe_load(value)
59 | return Config.fromDict(config)
60 |
61 | @staticmethod
62 | def _update_rec(old_config, new_config):
63 | for k in new_config.keys():
64 | if k in old_config and isinstance(old_config[k], Mapping) and isinstance(new_config[k], Mapping):
65 | Config._update_rec(old_config[k], new_config[k])
66 | else:
67 | setattr(old_config, k, new_config[k])
68 |
--------------------------------------------------------------------------------
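A short, illustrative use of the `Config` class above (the file path and override values are made up for the example): `Config.build` merges YAML files, dotted `key=value` strings and keyword arguments, with later sources taking precedence over earlier ones.

```python
from config import Config

cfg = Config.build(
    'config/count_nodes/train.yaml',  # loaded via Config.from_yaml
    'opts.cpus=4 training.epochs=5',  # parsed via Config.from_dotted, overrides the YAML values
)
print(cfg.training.epochs)    # 5
print(cfg.optimizer._class_)  # torch.optim.Adam
```
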
/src/count_nodes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/count_nodes/__init__.py
--------------------------------------------------------------------------------
/src/count_nodes/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 |
4 | import torch
5 | from torch.utils import data
6 |
7 | import torchgraphs as tg
8 |
9 |
10 | class NodeCountDataset(data.Dataset):
11 | def __init__(self, min_nodes, max_nodes, num_samples, informative_features,
12 | edge_features_shape, node_features_shape, global_features_shape):
13 | self.num_samples = num_samples
14 | self.min_nodes = min_nodes
15 | self.max_nodes = max_nodes
16 | self.node_features_shape = node_features_shape
17 | self.edge_features_shape = edge_features_shape
18 | self.global_features_shape = global_features_shape
19 | self.informative_features = informative_features
20 | self.samples = [None] * self.num_samples
21 |
22 | def __len__(self):
23 | return self.num_samples
24 |
25 | def __getitem__(self, item):
26 | self.samples[item] = self.samples[item] if self.samples[item] is not None else self._random_sample()
27 | return self.samples[item]
28 |
29 | def _random_sample(self):
30 | num_nodes = np.random.randint(self.min_nodes, self.max_nodes)
31 | num_edges = np.random.randint(0, num_nodes * (num_nodes - 1) + 1)
32 | g_nx = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
33 | g = tg.Graph.from_networkx(g_nx)
34 | g = g.evolve(
35 | node_features=torch.empty(num_nodes, self.node_features_shape).uniform_(-1, 1),
36 | edge_features=torch.empty(num_edges, self.edge_features_shape).uniform_(-1, 1),
37 | global_features=torch.empty(self.global_features_shape).uniform_(-1, 1)
38 | )
39 |
40 | if self.informative_features > 0:
41 | feats = np.random.rand(num_nodes, self.informative_features) > .5
42 | target = np.logical_and(*feats.transpose()).sum()
43 | g.node_features[:, :self.informative_features] = torch.from_numpy(feats.astype(np.float32)) * 2 - 1
44 | else:
45 | target = g.num_nodes
46 |
47 | return g, target
48 |
49 |
50 | def create_and_save():
51 | import random
52 | import argparse
53 | from tqdm import tqdm
54 | from pathlib import Path
55 | from config import Config
56 |
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--yaml', nargs='+', default=[])
59 | parser.add_argument('--dry-run', action='store_true')
60 | args, rest = parser.parse_known_args()
61 |
62 | config = Config()
63 | for y in args.yaml:
64 | config.update_from_yaml(y)
65 | if len(rest) > 0:
66 | if rest[0] != '--':
67 | rest = ' '.join(rest)
68 | print(f"Error: additional config must be separated by '--', got:\n{rest}")
69 | exit(1)
70 | config.update_from_cli(' '.join(rest[1:]))
71 |
72 | print(config.toYAML())
73 | if args.dry_run:
74 | print('Dry run, exiting.')
75 | exit(0)
76 | del args, rest
77 |
78 | random.seed(config.opts.seed)
79 | np.random.seed(config.opts.seed)
80 | torch.random.manual_seed(config.opts.seed)
81 |
82 | folder = Path(config.opts.folder).expanduser().resolve() / 'data'
83 | folder.mkdir(parents=True, exist_ok=True)
84 |
85 | datasets_common = {k: v for k, v in config.datasets.items() if k not in {'_train_', '_val_', '_test_'}}
86 | for name in ['train', 'val', 'test']:
87 | path = folder / f'{name}.pt'
88 | dataset = NodeCountDataset(
89 | **datasets_common,
90 | **{k: v for k, v in config.datasets.get(f'_{name}_', {}).items()}
91 | )
92 | samples = len([s for s in tqdm(dataset, desc=name.capitalize(), unit='samples', leave=True)])
93 | torch.save(dataset, path)
94 | tqdm.write(f'{name.capitalize()}: saved {samples} samples in\t{path}')
95 |
96 |
97 | if __name__ == '__main__':
98 | create_and_save()
99 |
--------------------------------------------------------------------------------
/src/count_nodes/layout.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import tensorflow as tf
5 |
6 | from tensorboard import summary as summary_lib
7 | from tensorboard.plugins.custom_scalar import layout_pb2
8 |
9 | layout_summary = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=[
10 | layout_pb2.Category(
11 | title='Losses',
12 | chart=[
13 | layout_pb2.Chart(
14 | title='Train', multiline=layout_pb2.MultilineChartContent(tag=['loss/train/mse', 'loss/train/l1'])),
15 | layout_pb2.Chart(
16 | title='Val', multiline=layout_pb2.MultilineChartContent(tag=['loss/train/mse', 'loss/val/mse'])),
17 | ])
18 | ]))
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('folder', help='The log folder to place the layout in')
22 | args = parser.parse_args()
23 |
24 | folder = (Path(args.folder) / 'layout').expanduser().resolve()
25 | with tf.summary.FileWriter(folder) as writer:
26 | writer.add_summary(layout_summary)
27 |
28 | print('Layout saved to', folder)
29 |
--------------------------------------------------------------------------------
/src/count_nodes/networks.py:
--------------------------------------------------------------------------------
1 | import torch_scatter
2 |
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | import torchgraphs as tg
7 | from torchgraphs.utils import segment_lengths_to_ids
8 |
9 |
10 | class FullGN(nn.Module):
11 | def __init__(
12 | self,
13 | in_edge_features_shape, in_node_features_shape, in_global_features_shape,
14 | out_edge_features_shape, out_node_features_shape, out_global_features_shape
15 | ):
16 | super().__init__()
17 | self.f_e = nn.Linear(in_edge_features_shape, out_edge_features_shape)
18 | self.f_s = nn.Linear(in_node_features_shape, out_edge_features_shape)
19 | self.f_r = nn.Linear(in_node_features_shape, out_edge_features_shape)
20 | self.f_u = nn.Linear(in_global_features_shape, out_edge_features_shape)
21 |
22 | self.g_n = nn.Linear(in_node_features_shape, out_node_features_shape)
23 | self.g_in = nn.Linear(out_edge_features_shape, out_node_features_shape)
24 | self.g_out = nn.Linear(out_edge_features_shape, out_node_features_shape)
25 | self.g_u = nn.Linear(in_global_features_shape, out_node_features_shape)
26 |
27 | self.h_n = nn.Linear(out_node_features_shape, out_global_features_shape)
28 | self.h_e = nn.Linear(out_edge_features_shape, out_global_features_shape)
29 | self.h_u = nn.Linear(in_global_features_shape, out_global_features_shape)
30 |
31 | def forward(self, graphs: tg.GraphBatch):
32 | edges = F.relu(
33 | self.f_e(graphs.edge_features) +
34 | self.f_s(graphs.node_features).index_select(dim=0, index=graphs.senders) +
35 | self.f_r(graphs.node_features).index_select(dim=0, index=graphs.receivers) +
36 | tg.utils.repeat_tensor(self.f_u(graphs.global_features), graphs.num_edges_by_graph)
37 | )
38 | nodes = F.relu(
39 | self.g_n(graphs.node_features) +
40 | self.g_in(torch_scatter.scatter_add(edges, graphs.receivers, dim=0, dim_size=graphs.num_nodes)) +
41 | self.g_out(torch_scatter.scatter_add(edges, graphs.senders, dim=0, dim_size=graphs.num_nodes)) +
42 | tg.utils.repeat_tensor(self.g_u(graphs.global_features), graphs.num_nodes_by_graph)
43 | )
44 | globals = (
45 | self.h_e(torch_scatter.scatter_add(
46 | edges, segment_lengths_to_ids(graphs.num_edges_by_graph), dim=0, dim_size=graphs.num_graphs)) +
47 | self.h_n(torch_scatter.scatter_add(
48 | nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph), dim=0, dim_size=graphs.num_graphs)) +
49 | self.h_u(graphs.global_features)
50 | )
51 | return graphs.evolve(
52 | edge_features=edges,
53 | node_features=nodes,
54 | global_features=globals,
55 | )
56 |
57 |
58 | class MinimalGN(nn.Module):
59 | def __init__(self, in_node_features_shape, out_node_features_shape, out_global_features_shape):
60 | super().__init__()
61 | self.g_n = nn.Linear(in_node_features_shape, out_node_features_shape)
62 | self.h_n = nn.Linear(out_node_features_shape, out_global_features_shape)
63 |
64 | def forward(self, graphs: tg.GraphBatch):
65 | nodes = F.relu(self.g_n(graphs.node_features))
66 | globals = self.h_n(torch_scatter.scatter_add(
67 | nodes, segment_lengths_to_ids(graphs.num_nodes_by_graph), dim=0, dim_size=graphs.num_graphs))
68 | return graphs.evolve(
69 | num_edges=0,
70 | edge_features=None,
71 | node_features=None,
72 | global_features=globals,
73 | senders=None,
74 | receivers=None
75 | )
76 |
--------------------------------------------------------------------------------
/src/count_nodes/notes.md:
--------------------------------------------------------------------------------
1 | # Count Nodes
2 |
3 | ## Task
4 | Count the number of nodes in the graph. It's a case of simple aggregation from nodes to global.
5 |
6 | Cases:
7 | 1. Node features are all random, just count the number of nodes
8 | 2. Node features are a combination of informative and random features. Only some of the nodes are considered
9 |    active and therefore counted: if the informative features of a node are **all** 1, that node is counted,
10 |    otherwise it is skipped (see the sketch below).
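
A minimal sketch of how the target is derived in case 2 (assuming two informative features; the feature values are made up):

```python
import numpy as np

# One row per node, one column per informative feature (0/1 before the remapping to -1/+1)
feats = np.array([[1, 1],   # all informative features are 1 -> counted
                  [1, 0],   # not all 1 -> skipped
                  [0, 0]])  # skipped
target = feats.all(axis=1).sum()  # -> 1
```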
11 |
12 | ## Generalization strategy
13 | During training the network only sees _small_ graphs, i.e. graphs whose node count is below a threshold.
14 | During testing the network is presented larger graphs, with a number of nodes outside the range used for training.
15 |
16 | ## Weighting function
17 | Weighting is only necessary for case 2:
18 | - The number of nodes is selected at random as `n~Uniform(0, max_nodes)`
19 | - Out of `K` features, `I` are considered independently informative and set to 0 or 1 with prob 0.5
20 | - The `I` informative features are evaluated through a logical AND, i.e. all must be 1 for the node to be counted
21 | - For this reason, the target values for the dataset will follow a binomial distribution
22 |
23 | ```t~Binomial(max_nodes, 0.5^I)```
24 |
25 | with mean `max_nodes * 0.5^I` and variance `max_nodes * 0.5^I (1 - 0.5^I)`
26 | - Graphs that have a very small or very large number of active nodes are rare and should be weighted more.
27 | - A suitable weighting function is the negative log-likelihood of the target value, as sketched below:
28 |
29 | ```- ln Binomial(target | max_nodes, 0.5^I)```
30 |
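A small sketch of this weighting using `scipy.stats` (the values of `max_nodes` and `I` are illustrative; `train.py` computes the same quantity from the training dataset parameters):

```python
from scipy import stats

max_nodes, informative = 20, 2                 # illustrative values for max_nodes and I
binom = stats.binom(n=max_nodes, p=0.5 ** informative)
for target in (0, 5, 10):
    weight = -binom.logpmf(target)             # rare targets receive larger weights
    print(target, weight)
```
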
31 | ## Full network
32 | The full version of the graph network operates as such:
33 | ```
34 | e_{s \to r}^{t+1} = ReLU [ f_e(e_{s \to r}^t) + f_s(n_s^t) + f_r(n_r^t) + f_u(u^t)]
35 |
36 | n_i^{t+1} = ReLU [ g_n(n_i^t) + g_{in}(agg_s(e_{s \to i}^{t+1})) + g_{out}(agg_r(e_{i \to r}^{t+1})) + g_u(u^t)]
37 |
38 | u^{t+1} = h_n(agg_i(n_i^{t+1})) + h_e(agg_{ij}(e_{i \to j}^{t+1})) + h_u(u^t)
39 | ```
40 |
41 | using summation as an aggregation function
42 |
43 | ## Minimal network
44 | For this task, the minimal version of the network should only use:
45 | ```
46 | n_i^{t+1} = ReLU [ g_n(n_i^t) ]
47 |
48 | u^{t+1} = h_n(agg_i(n_i^{t+1}))
49 | ```
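
A tiny numeric sketch of why this is enough for case 1: if `ReLU[ g_n(n_i) ]` maps every node to (roughly) 1 and `h_n` is close to the identity, the summed aggregation returns the node count (illustrative values only):

```python
import torch

num_nodes = 7
nodes = torch.ones(num_nodes, 1)  # ReLU[ g_n(n_i) ] ~ 1 for every node
count = nodes.sum(dim=0)          # h_n(agg_i(n_i)) -> tensor([7.])
```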
50 |
51 | ## Workflow
52 |
53 | 1. Create base folder
54 | ```bash
55 | COUNT_NODES=~/experiments/count-nodes/
56 | mkdir -p "$COUNT_NODES/"{runs,data}
57 | ```
58 | 2. Create dataset
59 | ```bash
60 | python -m count_nodes.dataset --yaml ../config/count_nodes/dataset.yaml
61 | ```
62 | 3. Create tensorboard layout
63 | ```bash
64 | python -m count_nodes.layout "$COUNT_NODES/runs"
65 | ```
66 | 4. Launch experiments
67 | ```bash
68 | # General form: python -m count_nodes.train --yaml <config files> -- [dotted.key=value overrides]
69 |
70 | conda activate tg-experiments
71 | for i in 1 2 3 4 5; do
72 | for type in full minimal; do
73 | for lr in .01 .001; do
74 | for wd in 0 .01 .001; do
75 | python -m count_nodes.train --yaml ../config/count_nodes/{train,${type}}.yaml -- \
76 | opts.session=${type}_lr${lr}_wd${wd}_${i} \
77 | optimizer.lr=${lr} \
78 | training.l1=${wd} training.epochs=40
79 | done
80 | done
81 | done
82 | done
83 | ```
84 | 5. Visualize
85 | ```bash
86 | tensorboard --logdir "$COUNT_NODES/runs"
87 | ```
--------------------------------------------------------------------------------
/src/count_nodes/train.py:
--------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 | import textwrap
4 |
5 | import pandas as pd
6 | from pathlib import Path
7 |
8 | import tqdm
9 | import numpy as np
10 | import torchgraphs as tg
11 | from munch import Munch
12 | from scipy import stats
13 |
14 | import torch
15 | import torch.nn.functional as F
16 | from torch import nn, optim
17 | from torch.utils import data
18 | from tensorboardX import SummaryWriter
19 |
20 | from count_nodes.dataset import NodeCountDataset
21 | from saver import Saver
22 | from utils import load_class
23 | from config import Config
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--yaml', nargs='+', default=[])
27 | parser.add_argument('--dry-run', action='store_true')
28 | args, rest = parser.parse_known_args()
29 |
30 | config = Config()
31 | for y in args.yaml:
32 | config.update_from_yaml(y)
33 | if len(rest) > 0:
34 | if rest[0] != '--':
35 | rest = ' '.join(rest)
36 | print(f"Error: additional config must be separated by '--', got:\n{rest}")
37 | exit(1)
38 | config.update_from_cli(' '.join(rest[1:]))
39 |
40 | print('Config summary:', config.toYAML(), sep='\n')
41 | if args.dry_run:
42 | print('Dry run, exiting.')
43 | exit(0)
44 | del args, rest
45 |
46 | random.seed(config.opts.seed)
47 | np.random.seed(config.opts.seed)
48 | torch.random.manual_seed(config.opts.seed)
49 |
50 | folder_base = (Path(config.opts.folder)).expanduser().resolve()
51 | folder_data = folder_base / 'data'
52 | folder_run = folder_base / 'runs' / config.opts.session
53 | saver = Saver(folder_run)
54 | logger = SummaryWriter(folder_run.as_posix())
55 |
56 | ModelClass = load_class(config.model._class_)
57 | net: nn.Module = ModelClass(**{k: v for k, v in config.model.items() if k != '_class_'})
58 | net.to(config.opts.device)
59 |
60 | OptimizerClass = load_class(config.optimizer._class_)
61 | optimizer: optim.Optimizer = OptimizerClass(params=net.parameters(),
62 | **{k: v for k, v in config.optimizer.items() if k != '_class_'})
63 |
64 | if config.training.restore:
65 | train_state = saver.load(model=net, optimizer=optimizer, device=config.training.device)
66 | else:
67 | train_state = Munch(epochs=0, samples=0)
68 |
69 | if config.opts.log:
70 | with open(folder_run / 'config.yml', mode='w') as f:
71 | f.write(config.toYAML())
72 | logger.add_text(
73 | 'Config',
74 | textwrap.indent(config.toYAML(), ' '),
75 | global_step=train_state.samples)
76 |
77 |
78 | def make_dataloader(dataset, shuffle) -> data.DataLoader:
79 | return data.DataLoader(
80 | dataset,
81 | batch_size=config.training.batch_size,
82 | collate_fn=tg.GraphBatch.collate,
83 | num_workers=config.opts.cpus,
84 | shuffle=shuffle,
85 | pin_memory='cuda' in str(config.opts.device),
86 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1))
87 | )
88 |
89 |
90 | dataset_train: NodeCountDataset = torch.load(folder_data / 'train.pt')
91 | dataset_val: NodeCountDataset = torch.load(folder_data / 'val.pt')
92 | dataset_test: NodeCountDataset = torch.load(folder_data / 'test.pt')
93 | dataloader_train = make_dataloader(dataset_train, shuffle=True)
94 | dataloader_val = make_dataloader(dataset_val, shuffle=False)
95 | dataloader_test = make_dataloader(dataset_test, shuffle=False)
96 |
97 | if dataset_train.informative_features > 0:
98 | binom = stats.binom(n=dataset_train.max_nodes, p=.5 ** dataset_train.informative_features)
99 |
100 |
101 | def weight_fn(targets):
102 | return targets.new_tensor(- binom.logpmf(targets.cpu().numpy()))
103 | else:
104 | def weight_fn(targets):
105 | return torch.ones_like(targets)
106 |
107 | epoch_bar_postfix = {}
108 | epoch_start = train_state.epochs + 1
109 | epoch_end = train_state.epochs + 1 + config.training.epochs
110 | epoch_bar = tqdm.trange(epoch_start, epoch_end, desc='Training', unit='e', leave=True)
111 | for epoch in epoch_bar:
112 | # Training loop
113 | net.train()
114 | loss_mse_train = 0
115 | train_bar_postfix = {}
116 | with tqdm.tqdm(desc=f'Train {epoch}', total=dataloader_train.dataset.num_samples, unit='g') as train_bar:
117 | for graphs, targets in dataloader_train:
118 | graphs = graphs.to(config.opts.device)
119 | targets = targets.float().to(config.opts.device)
120 | weights = weight_fn(targets)
121 |
122 | loss_total = 0
123 | results = net(graphs).global_features.squeeze()
124 | losses_mse = F.mse_loss(results, targets, reduction='none')
125 | losses_mse_weighted = losses_mse * weights
126 | loss_total += losses_mse_weighted.mean()
127 |
128 | if config.training.l1 > 0:
129 | loss_l1 = sum([p.abs().sum() for p in net.parameters()]) * config.training.l1
130 | loss_total += loss_l1
131 | train_bar_postfix['L1'] = f'{loss_l1.item():.5f}'
132 | if config.opts.log:
133 | logger.add_scalar('loss/train/l1', loss_l1.item(), global_step=train_state.samples)
134 |
135 | optimizer.zero_grad()
136 | loss_total.backward()
137 | optimizer.step(closure=None)
138 |
139 | train_state.samples += graphs.num_graphs
140 | loss_mse_train += losses_mse.sum().item()
141 |
142 | train_bar.update(graphs.num_graphs)
143 | train_bar_postfix['MSE'] = f'{losses_mse.mean().item():.5f}'
144 | train_bar.set_postfix(train_bar_postfix)
145 | if config.opts.log:
146 | logger.add_scalar('loss/train/all', loss_total.item(), global_step=train_state.samples)
147 | logger.add_scalar('loss/train/mse', losses_mse.mean().item(), global_step=train_state.samples)
148 | epoch_bar_postfix['train/mse'] = f'{loss_mse_train / dataloader_train.dataset.num_samples:.5f}'
149 | epoch_bar.set_postfix(epoch_bar_postfix)
150 |
151 | # Saving
152 | train_state.epochs += 1
153 | if config.training.save_every > 0 and epoch % config.training.save_every == 0:
154 | saver.save(name=epoch, model=net, optimizer=optimizer, **train_state)
155 |
156 | # Validation loop
157 | net.eval()
158 | loss_mse_val = 0
159 | with torch.no_grad():
160 | with tqdm.tqdm(desc=f'Val {epoch}', total=dataloader_val.dataset.num_samples, unit='g') as val_bar:
161 | for graphs, targets in dataloader_val:
162 | graphs = graphs.to(config.opts.device)
163 | targets = targets.float().to(config.opts.device)
164 |
165 | results = net(graphs).global_features.squeeze()
166 | losses_mse = F.mse_loss(results, targets, reduction='none')
167 |
168 | loss_mse_val += losses_mse.sum().item()
169 | val_bar.update(graphs.num_graphs)
170 | val_bar.set_postfix_str(f'MSE: {losses_mse.mean().item():.5f}')
171 | if config.opts.log:
172 | logger.add_scalar(
173 | 'loss/val/mse', loss_mse_val / dataloader_val.dataset.num_samples, global_step=train_state.samples)
174 | epoch_bar_postfix['val/mse'] = f'{loss_mse_val / dataloader_val.dataset.num_samples:.5f}'
175 | epoch_bar.set_postfix(epoch_bar_postfix)
176 | epoch_bar.close()
177 |
178 | net.eval()
179 | df = {k: [] for k in ['Loss', 'Nodes', 'Edges', 'Predict', 'Target']}
180 | with torch.no_grad():
181 | with tqdm.tqdm(desc='Test', total=dataloader_test.dataset.num_samples, leave=True, unit='g') as test_bar:
182 | for graphs, targets in dataloader_test:
183 | graphs = graphs.to(config.opts.device)
184 | targets = targets.float().to(config.opts.device)
185 |
186 | results = net(graphs).global_features.squeeze()
187 | losses_mse = F.mse_loss(results, targets, reduction='none')
188 |
189 | df['Loss'].append(losses_mse.cpu().numpy())
190 | df['Nodes'].append(graphs.num_nodes_by_graph.cpu().numpy())
191 | df['Edges'].append(graphs.num_edges_by_graph.cpu().numpy())
192 | df['Target'].append(targets.int().cpu().numpy())
193 | df['Predict'].append(results.cpu().numpy())
194 |
195 | test_bar.update(graphs.num_graphs)
196 | test_bar.set_postfix_str(f'MSE: {losses_mse.mean().item():.5f}')
197 | df = pd.DataFrame({k: np.concatenate(v) for k, v in df.items()}).rename_axis('GraphId').reset_index()
198 |
199 | # Split the results based on whether the number of nodes was present in the training set or not
200 | df_train_test = df \
201 | .groupby(np.where(df.Nodes < dataset_train.max_nodes,
202 | f'Train [{dataset_train.min_nodes}, {dataset_train.max_nodes - 1})',
203 | f'Test [{dataset_train.max_nodes}, {dataset_test.max_nodes - 1})')) \
204 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'Loss': 'mean'}) \
205 | .sort_index(ascending=False) \
206 | .rename_axis(index='Dataset') \
207 | .rename(str.capitalize, axis='columns', level=1)
208 |
209 | # Split the results in ranges based on the number of nodes and compute the average loss per range
210 | df_losses_by_node_range = df \
211 | .groupby(df.Nodes // 10) \
212 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'Loss': 'mean'}) \
213 | .rename_axis(index='NodeRange') \
214 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index') \
215 | .rename(str.capitalize, axis='columns', level=1)
216 |
217 | # Split the results in ranges based on the number of nodes and compute the average loss per range
218 | df_worst_graphs_by_node_range = df \
219 | .groupby(df.Nodes // 10) \
220 | .apply(lambda df_gr: df_gr.nlargest(5, 'Loss').set_index('GraphId')) \
221 | .rename_axis(index={'Nodes': 'NodeRange'}) \
222 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0)
223 |
224 | print(
225 | df_train_test.to_string(float_format=lambda x: f'{x:.2f}'),
226 | df_losses_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'),
227 | df_worst_graphs_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'),
228 | sep='\n\n')
229 | if config.opts.log:
230 | logger.add_text(
231 | 'Generalization',
232 | textwrap.indent(df_train_test.to_string(float_format=lambda x: f'{x:.2f}'), ' '),
233 | global_step=train_state.samples)
234 | logger.add_text(
235 | 'Loss by range',
236 | textwrap.indent(df_losses_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'), ' '),
237 | global_step=train_state.samples)
238 | logger.add_text(
239 | 'Samples',
240 | textwrap.indent(df_worst_graphs_by_node_range.to_string(float_format=lambda x: f'{x:.2f}'), ' '),
241 | global_step=train_state.samples)
242 |
243 | params = [f'{name}:\n{param.data.cpu().numpy().round(3)}' for name, param in net.named_parameters()]
244 | print('Parameters:', *params, sep='\n\n')
245 | if config.opts.log:
246 | logger.add_text('Parameters', textwrap.indent('\n\n'.join(params), ' '), global_step=train_state.samples)
247 |
248 | logger.close()
249 |
--------------------------------------------------------------------------------
/src/guidedbackprop/__init__.py:
--------------------------------------------------------------------------------
1 | # TODO split these into separate modules
2 | from .autograd_tricks import add, sum, \
3 | cat, index_select, repeat_tensor, \
4 | scatter_add, scatter_mean, scatter_max, \
5 | linear, relu, get_aggregation
6 | from .graphs import EdgeLinearGuidedBP, NodeLinearGuidedBP, GlobalLinearGuidedBP, \
7 | EdgeReLUGuidedBP, NodeReLUGuidedBP, GlobalReLUGuidedBP
8 |
--------------------------------------------------------------------------------
/src/guidedbackprop/autograd_tricks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch_scatter
5 |
6 |
7 | def hook(grad):
8 | return grad.clamp(min=0)
9 |
10 |
11 | def add(a, b):
12 | out = torch.add(a, b)
13 | out.register_hook(hook)
14 | return out
15 |
16 |
17 | def sum(tensor, dim=None, keepdim=False):
18 | out = torch.sum(tensor, dim=dim, keepdim=keepdim)
19 | out.register_hook(hook)
20 | return out
21 |
22 |
23 | def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
24 | out = torch_scatter.scatter_add(src, index, dim, out, dim_size, fill_value)
25 | out.register_hook(hook)
26 | return out
27 |
28 |
29 | def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
30 | out = torch_scatter.scatter_mean(src, index, dim, out, dim_size, fill_value)
31 | out.register_hook(hook)
32 | return out
33 |
34 |
35 | def scatter_max(src, index, dim=-1, dim_size=None, fill_value=0):
36 | out = torch_scatter.scatter_max(src, index, dim, None, dim_size, fill_value)
37 | out[0].register_hook(hook)
38 | return out
39 |
40 |
41 | def linear(input, weight, bias=None):
42 | out = torch.nn.functional.linear(input, weight, bias)
43 | out.register_hook(hook)
44 | return out
45 |
46 |
47 | def index_select(src, dim, index):
48 | out = torch.index_select(src, dim, index)
49 | out.register_hook(hook)
50 | return out
51 |
52 |
53 | def cat(tensors, dim=0):
54 | out = torch.cat(tensors, dim)
55 | out.register_hook(hook)
56 | return out
57 |
58 |
59 | def repeat_tensor(src, repeats, dim=0):
60 | idx = src.new_tensor(np.arange(len(repeats)).repeat(repeats.cpu().numpy()), dtype=torch.long)
61 | return index_select(src, dim, idx)
62 |
63 |
64 | def relu(input):
65 | out = torch.nn.functional.relu(input)
66 | out.register_hook(hook)
67 | return out
68 |
69 |
70 | def get_aggregation(name):
71 | if name in ('add', 'sum'):
72 | return scatter_add
73 | elif name in ('mean', 'avg'):
74 | return scatter_mean
75 | elif name == 'max':
76 | from functools import wraps
77 |
78 | @wraps(scatter_max)
79 | def wrapper(*args, **kwargs):
80 | return scatter_max(*args, **kwargs)[0]
81 |
82 | return wrapper
83 |
--------------------------------------------------------------------------------
/src/guidedbackprop/graphs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchgraphs as tg
3 |
4 | from . import autograd_tricks as guidedbp
5 |
6 |
7 | class EdgeLinearGuidedBP(tg.EdgeLinear):
8 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
9 | new_edges = torch.tensor(0)
10 |
11 | if self.W_edge is not None:
12 | new_edges = guidedbp.add(new_edges, guidedbp.linear(graphs.edge_features, self.W_edge))
13 | if self.W_sender is not None:
14 | new_edges = guidedbp.add(
15 | new_edges,
16 | guidedbp.index_select(guidedbp.linear(graphs.node_features, self.W_sender),
17 | dim=0, index=graphs.senders)
18 | )
19 | if self.W_receiver is not None:
20 | new_edges = guidedbp.add(
21 | new_edges,
22 | guidedbp.index_select(guidedbp.linear(graphs.node_features, self.W_receiver),
23 | dim=0, index=graphs.receivers)
24 | )
25 | if self.W_global is not None:
26 | new_edges = guidedbp.add(
27 | new_edges,
28 | guidedbp.repeat_tensor(guidedbp.linear(graphs.global_features, self.W_global),
29 | dim=0, repeats=graphs.num_edges_by_graph)
30 | )
31 | if self.bias is not None:
32 | new_edges = guidedbp.add(new_edges, self.bias)
33 |
34 | return graphs.evolve(edge_features=new_edges)
35 |
36 |
37 | class NodeLinearGuidedBP(tg.NodeLinear):
38 | def __init__(self, out_features, node_features=None, incoming_features=None, outgoing_features=None,
39 | global_features=None, aggregation=None, bias=True):
40 | super(NodeLinearGuidedBP, self).__init__(out_features, node_features, incoming_features,
41 | outgoing_features, global_features,
42 | guidedbp.get_aggregation(aggregation),
43 | bias)
44 |
45 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
46 | new_nodes = torch.tensor(0)
47 |
48 | if self.W_node is not None:
49 | new_nodes = guidedbp.add(
50 | new_nodes,
51 | guidedbp.linear(graphs.node_features, self.W_node)
52 | )
53 | if self.W_incoming is not None:
54 | new_nodes = guidedbp.add(
55 | new_nodes,
56 | guidedbp.linear(
57 | self.aggregation(graphs.edge_features, dim=0, index=graphs.receivers, dim_size=graphs.num_nodes),
58 | self.W_incoming)
59 | )
60 | if self.W_outgoing is not None:
61 | new_nodes = guidedbp.add(
62 | new_nodes,
63 | guidedbp.linear(
64 | self.aggregation(graphs.edge_features, dim=0, index=graphs.senders, dim_size=graphs.num_nodes),
65 | self.W_outgoing)
66 | )
67 | if self.W_global is not None:
68 | new_nodes = guidedbp.add(
69 | new_nodes,
70 | guidedbp.repeat_tensor(guidedbp.linear(graphs.global_features, self.W_global), dim=0,
71 | repeats=graphs.num_nodes_by_graph)
72 | )
73 | if self.bias is not None:
74 | new_nodes = guidedbp.add(new_nodes, self.bias)
75 |
76 | return graphs.evolve(node_features=new_nodes)
77 |
78 |
79 | class GlobalLinearGuidedBP(tg.GlobalLinear):
80 | def __init__(self, out_features, node_features=None, edge_features=None, global_features=None,
81 | aggregation=None, bias=True):
82 | super(GlobalLinearGuidedBP, self).__init__(out_features, node_features, edge_features,
83 | global_features, guidedbp.get_aggregation(aggregation), bias)
84 |
85 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
86 | new_globals = torch.tensor(0)
87 |
88 | if self.W_node is not None:
89 | index = tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph)
90 | new_globals = guidedbp.add(
91 | new_globals,
92 | guidedbp.linear(self.aggregation(graphs.node_features, dim=0, index=index, dim_size=graphs.num_graphs),
93 | self.W_node)
94 | )
95 | if self.W_edges is not None:
96 | index = tg.utils.segment_lengths_to_ids(graphs.num_edges_by_graph)
97 | new_globals = guidedbp.add(
98 | new_globals,
99 | guidedbp.linear(self.aggregation(graphs.edge_features, dim=0, index=index, dim_size=graphs.num_graphs),
100 | self.W_edges)
101 | )
102 | if self.W_global is not None:
103 | new_globals = guidedbp.add(
104 | new_globals,
105 | guidedbp.linear(graphs.global_features, self.W_global)
106 | )
107 | if self.bias is not None:
108 | new_globals = guidedbp.add(new_globals, self.bias)
109 |
110 | return graphs.evolve(global_features=new_globals)
111 |
112 |
113 | class EdgeReLUGuidedBP(tg.EdgeFunction):
114 | def __init__(self):
115 | super(EdgeReLUGuidedBP, self).__init__(guidedbp.relu)
116 |
117 |
118 | class NodeReLUGuidedBP(tg.NodeFunction):
119 | def __init__(self):
120 | super(NodeReLUGuidedBP, self).__init__(guidedbp.relu)
121 |
122 |
123 | class GlobalReLUGuidedBP(tg.GlobalFunction):
124 | def __init__(self):
125 | super(GlobalReLUGuidedBP, self).__init__(guidedbp.relu)
126 |
--------------------------------------------------------------------------------
/src/infection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/infection/__init__.py
--------------------------------------------------------------------------------
/src/infection/dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Tuple
3 |
4 | import torch
5 | import torch.utils.data
6 | import numpy as np
7 | import networkx as nx
8 | import torchgraphs as tg
9 |
10 |
11 | class InfectionDataset(torch.utils.data.Dataset):
12 | def __init__(self, max_percent_immune, max_percent_sick, max_percent_virtual, min_nodes, max_nodes):
13 | if max_percent_sick + max_percent_immune > 1:
14 |             raise ValueError(f"Cannot have a population with `max_percent_sick`={max_percent_sick} "
15 | f"and `max_percent_immune`={max_percent_immune}")
16 | self.min_nodes = min_nodes
17 | self.max_nodes = max_nodes
18 | self.max_percent_immune = max_percent_immune
19 | self.max_percent_sick = max_percent_sick
20 | self.max_percent_virtual = max_percent_virtual
21 | self.node_features_shape = 4
22 | self.edge_features_shape = 2
23 | self.samples: List[Tuple[tg.Graph, tg.Graph]] = []
24 |
25 | def __len__(self):
26 | return len(self.samples)
27 |
28 | def __getitem__(self, item):
29 | return self.samples[item]
30 |
31 | def random_sample(self):
32 | num_nodes = np.random.randint(self.min_nodes, self.max_nodes)
33 | g_nx = nx.barabasi_albert_graph(num_nodes, 2).to_directed()
34 |
35 | # Remove some edges
36 | num_edges = int(.7 * g_nx.number_of_edges())
37 | edges_to_remove = np.random.choice(g_nx.number_of_edges(),
38 | size=g_nx.number_of_edges() - num_edges, replace=False)
39 | edges_to_remove = [list(g_nx.edges)[i] for i in edges_to_remove]
40 | g_nx.remove_edges_from(edges_to_remove)
41 |
42 | # Create node features: sick (at least one), immune and at risk
43 | num_sick = np.random.randint(1, max(1, int(num_nodes * self.max_percent_sick)) + 1)
44 | num_immune = np.random.randint(0, int(num_nodes * self.max_percent_immune) + 1)
45 | sick, immune, atrisk = np.split(g_nx.nodes, [num_sick, num_sick + num_immune])
46 |
47 | node_features = torch.empty(num_nodes, self.node_features_shape)
48 | node_features[:, :2] = -1
49 | node_features[sick, 0] = 1
50 | node_features[immune, 1] = 1
51 | node_features[:, 2:].uniform_(-1, 1)
52 |
53 | # Create edge features: in person and virtual
54 | num_virtual = np.random.randint(0, int(num_edges * self.max_percent_virtual) + 1)
55 | virtual = np.random.randint(0, num_edges, size=num_virtual)
56 | edge_features = torch.empty(num_edges, self.edge_features_shape)
57 | edge_features[:, 0] = -1
58 | edge_features[virtual, 0] = 1
59 | edge_features[:, 1:].uniform_(-1, 1)
60 |
61 | g = tg.Graph.from_networkx(g_nx).evolve(node_features=node_features, edge_features=edge_features)
62 |
63 | # Create target by spreading the infection to the non-virtual neighbors who are not immune
64 | virtual = {list(g_nx.edges)[i] for i in virtual}
65 | infected = list({
66 | infection_target for infection_src in sick for infection_target in g_nx.neighbors(infection_src)
67 | if infection_target not in immune and (infection_src, infection_target) not in virtual
68 | })
69 |
70 | target = torch.zeros((num_nodes, 1), dtype=torch.int8)
71 | target[sick] = 1
72 | target[infected] = 1
73 | target = tg.Graph(node_features=target, global_features=target.sum(dim=0))
74 |
75 | return g, target
76 |
77 |
78 | def generate(cfg):
79 | from tqdm import trange
80 | from utils import set_seeds
81 |
82 | cfg.setdefault('seed', 0)
83 | set_seeds(cfg.seed)
84 | print(f'Random seed: {cfg.seed}')
85 |
86 | folder = Path(cfg.folder).expanduser().resolve()
87 | folder.mkdir(parents=True, exist_ok=True)
88 | print(f'Saving datasets in: {folder}')
89 |
90 | with open(folder / 'datasets.yaml', 'w') as f:
91 | f.write(cfg.toYAML())
92 | for p, params in cfg.datasets.items():
93 | dataset = InfectionDataset(**{k: v for k, v in params.items() if k != 'num_samples'})
94 | dataset.samples = [dataset.random_sample() for _ in
95 | trange(params.num_samples, desc=p.capitalize(), unit='samples', leave=True)]
96 | path = folder.joinpath(p).with_suffix('.pt')
97 | torch.save(dataset, path)
98 | print(f'{p.capitalize()}: saved {len(dataset)} samples in: {path}')
99 |
100 |
101 | def describe(cfg):
102 | import pandas as pd
103 | target = Path(cfg.target).expanduser().resolve()
104 | if target.is_dir():
105 | paths = target.glob('*.pt')
106 | else:
107 | paths = [target]
108 | for p in paths:
109 | print(f"Loading dataset from: {p}")
110 | dataset = torch.load(p)
111 | if not isinstance(dataset, InfectionDataset):
112 | raise ValueError(f'Not an InfectionDataset: {p}')
113 | print(f"{p.with_suffix('').name.capitalize()} contains:\n"
114 | f"min_nodes: {dataset.min_nodes}\n"
115 | f"max_nodes: {dataset.max_nodes}\n"
116 | f"max_percent_immune: {dataset.max_percent_immune}\n"
117 | f"max_percent_sick: {dataset.max_percent_sick}\n"
118 | f"node_features_shape: {dataset.node_features_shape}\n"
119 | f"edge_features_shape: {dataset.edge_features_shape}\n"
120 | f"samples: {len(dataset)}")
121 | df = pd.DataFrame.from_records(
122 | {
123 | 'num_nodes': g.num_nodes,
124 | 'num_edges': g.num_edges,
125 | 'degree': g.degree.float().mean().item(),
126 | 'infected': g.node_features[:, 0].sum().item() / g.num_nodes,
127 | 'immune': g.node_features[:, 1].sum().item() / g.num_nodes,
128 | 'infected_post': t.node_features[:, 0].sum().item() / t.num_nodes,
129 | } for g, t in dataset.samples)
130 | print(f'\n{df.describe()}')
131 |
132 |
133 | def main():
134 | from argparse import ArgumentParser
135 | from config import Config
136 |
137 | parser = ArgumentParser()
138 | subparsers = parser.add_subparsers()
139 |
140 | sp_print = subparsers.add_parser('print', help='Print parsed configuration')
141 | sp_print.add_argument('config', nargs='*')
142 | sp_print.set_defaults(command=lambda c: print(c.toYAML()))
143 |
144 | sp_generate = subparsers.add_parser('generate', help='Generate new datasets')
145 | sp_generate.add_argument('config', nargs='*')
146 | sp_generate.set_defaults(command=generate)
147 |
148 | sp_describe = subparsers.add_parser('describe', help='Describe existing datasets')
149 | sp_describe.add_argument('config', nargs='*')
150 | sp_describe.set_defaults(command=describe)
151 |
152 | args = parser.parse_args()
153 | cfg = Config.build(*args.config)
154 | args.command(cfg)
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
--------------------------------------------------------------------------------
/src/infection/layout.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import tensorflow as tf
5 |
6 | from tensorboard import summary as summary_lib
7 | from tensorboard.plugins.custom_scalar import layout_pb2
8 |
9 | layout_summary = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=[
10 | layout_pb2.Category(
11 | title='losses',
12 | chart=[
13 | # Chart 'losses' (include all losses, exclude upper and lower bounds)
14 | layout_pb2.Chart(
15 | title='losses',
16 | multiline=layout_pb2.MultilineChartContent(
17 | tag=[
18 | r'loss(?!.*bound.*)'
19 | ]
20 | )
21 | ),
22 | ])
23 | ]))
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('folder', help='The log folder to place the layout in')
27 | args = parser.parse_args()
28 |
29 | folder = (Path(args.folder) / 'layout').expanduser().resolve()
30 | with tf.summary.FileWriter(folder) as writer:
31 | writer.add_summary(layout_summary)
32 |
33 | print('Layout saved to', folder)
34 |
--------------------------------------------------------------------------------
/src/infection/networks.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 |
4 | import torch
5 | from torch import nn
6 |
7 | import torchgraphs as tg
8 |
9 |
10 | class InfectionGN(nn.Module):
11 | def __init__(self, aggregation, bias):
12 | super().__init__()
13 | self.encoder = nn.Sequential(OrderedDict({
14 | 'edge': tg.EdgeLinear(4, edge_features=2, bias=bias),
15 | 'edge_relu': tg.EdgeReLU(),
16 | 'node': tg.NodeLinear(8, node_features=4, bias=bias),
17 | 'node_relu': tg.NodeReLU(),
18 | }))
19 | self.hidden = nn.Sequential(OrderedDict({
20 | 'edge': tg.EdgeLinear(8, edge_features=4, sender_features=8, bias=bias),
21 | 'edge_relu': tg.EdgeReLU(),
22 | 'node': tg.NodeLinear(8, node_features=8, incoming_features=8, aggregation=aggregation, bias=bias),
23 | 'node_relu': tg.NodeReLU()
24 | }))
25 | self.readout_nodes = tg.NodeLinear(1, node_features=8, bias=True)
26 | self.readout_globals = tg.GlobalLinear(1, node_features=8, aggregation='sum', bias=bias)
27 |
28 | def forward(self, graphs):
29 | graphs = self.encoder(graphs)
30 | graphs = self.hidden(graphs)
31 | nodes = self.readout_nodes(graphs).node_features
32 | globals = self.readout_globals(graphs).global_features
33 |
34 | return graphs.evolve(
35 | node_features=nodes,
36 | num_edges=0,
37 | num_edges_by_graph=None,
38 | edge_features=None,
39 | global_features=globals,
40 | senders=None,
41 | receivers=None
42 | )
43 |
44 |
45 | def _reset_parameters(module):
46 | for name, param in module.named_parameters():
47 | if 'bias' in name:
48 | bound = 1 / math.sqrt(param.numel())
49 | nn.init.uniform_(param, -bound, bound)
50 | else:
51 | nn.init.kaiming_uniform_(param, a=math.sqrt(5))
52 |
53 |
54 | def describe(cfg):
55 | from pathlib import Path
56 | from utils import import_
57 | klass = import_(cfg.model.klass)
58 | model = klass(*cfg.model.args, **cfg.model.kwargs)
59 | if 'state_dict' in cfg:
60 | model.load_state_dict(torch.load(Path(cfg.state_dict).expanduser().resolve()))
61 | print(model)
62 | print(f'Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
63 |
64 | for name, parameter in model.named_parameters():
65 | print(f'{name} {tuple(parameter.shape)}:')
66 | if 'state_dict' in cfg:
67 |             print(parameter.detach().numpy().round())
68 | print()
69 |
70 |
71 | def main():
72 | from argparse import ArgumentParser
73 | from config import Config
74 |
75 | parser = ArgumentParser()
76 | subparsers = parser.add_subparsers()
77 |
78 | sp_print = subparsers.add_parser('print', help='Print parsed configuration')
79 | sp_print.add_argument('config', nargs='*')
80 | sp_print.set_defaults(command=lambda c: print(c.toYAML()))
81 |
82 | sp_describe = subparsers.add_parser('describe', help='Describe a model')
83 | sp_describe.add_argument('config', nargs='*')
84 | sp_describe.set_defaults(command=describe)
85 |
86 | args = parser.parse_args()
87 | cfg = Config.build(*args.config)
88 | args.command(cfg)
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
--------------------------------------------------------------------------------
/src/infection/notes.md:
--------------------------------------------------------------------------------
1 | # Infection
2 |
3 | ## Task
4 | 1. One node of the graph is _infected_ and identified through a 1 in one element of its feature vector.
5 | The remaining features are unused. The infection is spread to the neighbors (directed) of the infected nodes.
6 |
7 | 2. One node of the graph is _infected_ and identified through a 1 in the first element of its feature vector. Some nodes
8 | are immune and identified through a 1 in the second element of their feature vectors.
9 | Again, the infection is spread to the neighbors, but some of them are immune.
10 |
11 | The network should output a prediction of 1 for nodes that end up infected and 0 for the others, which requires
12 | identifying the neighbors of the infected nodes while taking the connection type and the immunity status into account.
13 | The network also outputs a graph-level prediction that should correspond to the total number of infected nodes.
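
A toy illustration of case 2 on a hand-made graph (ignoring virtual edges and the unused feature dimensions):

```python
import networkx as nx

g = nx.DiGraph([(0, 1), (0, 2), (2, 3)])
sick, immune = {0}, {2}
infected = sick | {t for s in sick for t in g.successors(s) if t not in immune}
# infected == {0, 1}: node 2 is immune, node 3 is not a direct neighbour of a sick node
```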
14 |
15 | ## Data
16 |
17 | ### Training
18 | 100,000 graphs are generated using the [Barabási–Albert model](https://en.wikipedia.org/wiki/Barab%C3%A1si%E2%80%93Albert_model);
19 | every graph contains between 10 and 30 nodes. Up to 10% of the nodes are sick, up to 30% are immune, and up to 30% of the edges are virtual.
20 |
21 | ### Testing
22 | 20,000 graphs for testing are generated in a similar way, but with between 10 and 60 nodes.
23 | Up to 40% of the nodes are sick, up to 60% are immune, and up to 50% of the edges are virtual.
24 |
25 | ## Losses
26 |
27 | Three losses are used for training: a node-level loss, a global-level loss and a regularization term.
28 | The three terms are added together in a weighted sum and constitute the final training objective.
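
Schematically, the objective looks like the weighted sum below (the weight names mirror the `losses.nodes`, `losses.count` and `losses.l1` settings read by `train.py`; the numbers are made up):

```python
weights = {'nodes': 1.0, 'count': 0.1, 'l1': 1e-4}  # loss weights from the session config
loss_bce, loss_count, loss_l1 = 0.7, 2.3, 15.0      # per-batch loss values (illustrative)
loss_total = weights['nodes'] * loss_bce + weights['count'] * loss_count + weights['l1'] * loss_l1
```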
29 |
30 | ### Node-level classification
31 | At node-level, the network has to output a single binary prediction of whether the node will be sick or not.
32 | The loss on this prediction is computed as Binary Cross Entropy.
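
For reference, a minimal example of this loss on made-up logits and targets:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, -1.0, 0.5])  # one raw prediction per node
targets = torch.tensor([1.0, 0.0, 1.0])  # 1 = infected after the spread
loss = F.binary_cross_entropy_with_logits(logits, targets)
```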
33 |
34 | ### Global-level regression
35 | The network should also output a global-level prediction corresponding to the total number of infected nodes.
36 | The loss on this prediction is computed as Mean Squared Error.
37 |
38 | Since the total number of infected nodes in the training set is not homogeneously distributed, we weight the losses
39 | computed on individual graphs using the negative log frequency of the true value. For example, if 15 is the
40 | ground-truth number of infected nodes after the spread for a given input graph, we weight the MSE of the network's
41 | prediction as `-ln(# of graphs with 15 infected nodes in the training set / # of graphs in the training set)`
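
A small sketch of this weighting (the counts below are made up; in `train.py` the frequencies come from the training set):

```python
import numpy as np
import pandas as pd

counts = pd.Series([3, 5, 5, 15, 5, 3])                  # ground-truth infected counts, one per graph
weights = -np.log(counts.value_counts(normalize=True))   # negative log frequency of each count
print(weights.loc[15])                                   # weight applied to a graph with 15 infected nodes
```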
42 |
43 | ### L1 regularization
44 | The weights of the network are regularized with L1 regularization.
45 |
46 | ## Workflow
47 |
48 | 1. Create base folder
49 | ```bash
50 | INFECTION=~/experiments/infection/
51 | mkdir -p "$INFECTION/"{runs,data}
52 | ```
53 | 2. Create dataset (plus a small one for debug)
54 | ```bash
55 | python -m infection.dataset generate ../config/infection/datasets.yaml "folder=${INFECTION}/data"
56 | python -m infection.dataset generate ../config/infection/datasets.yaml \
57 | "folder=${INFECTION}/smalldata" \
58 | datasets.{train,val}.num_samples=5000
59 | ```
60 | 3. Launch one experiment:
61 | ```bash
62 | python -m infection.train \
63 | --experiment ../config/infection/train.yaml \
64 | --model ../config/infection/minimal.yaml
65 | ```
66 | Or make a grid search over the hyperparameters:
67 | ```bash
68 | conda activate tg-experiments
69 | function train {
70 | python -m infection.train \
71 | --experiment ../config/infection/train.yaml \
72 | "tags=[${1},lr${2},nodes${4},count${5},wd${3}]" \
73 | --model "../config/infection/${1}.yaml" \
74 | --optimizer "kwargs.lr=${2}" \
75 | --session "losses.l1=${3}" "losses.nodes=${4}" "losses.count=${5}"
76 | }
77 | export -f train  # run this in bash: `export -f` is not available in other shells
78 | parallel --eta --max-procs 6 --load 80% --noswap 'train {1} {2} {3} {4} {5}' \
79 | `# Architecture` ::: infectionGN \
80 | `# Learning rate` ::: .01 .001 \
81 | `# L1 loss` ::: 0 .0001 \
82 | `# Infection loss` ::: 1 .1 \
83 | `# Count loss` ::: 1 .1
84 | ```
85 | 4. Visualize logs:
86 | ```bash
87 | tensorboard --logdir "$INFECTION/runs"
88 | ```
89 |
--------------------------------------------------------------------------------
/src/infection/predict.py:
--------------------------------------------------------------------------------
1 | # TODO this might be incompatible with recent code changes
2 |
3 | import tqdm
4 | import yaml
5 | import pyaml
6 | import multiprocessing
7 | from pathlib import Path
8 | from argparse import ArgumentParser
9 |
10 | import numpy as np
11 | from munch import AutoMunch, Munch
12 |
13 | import torch
14 | import torch.utils.data
15 | import torch.nn.functional as F
16 | import torchgraphs as tg
17 |
18 | from utils import parse_dotted, update_rec, import_
19 | from .dataset import InfectionDataset
20 |
21 | parser = ArgumentParser()
22 | parser.add_argument('--model', nargs='+', required=True)
23 | parser.add_argument('--data', nargs='+', required=True, default=[])
24 | parser.add_argument('--options', nargs='+', required=False, default=[])
25 | parser.add_argument('--output', type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 |
30 | # region Collecting phase
31 |
32 | # Defaults
33 | model = Munch(fn=None, args=[], kwargs={}, state_dict=None)
34 | data = []
35 | options = AutoMunch()
36 | options.cpus = multiprocessing.cpu_count() - 1
37 | options.device = 'cuda' if torch.cuda.is_available() else 'cpu'
38 | options.output = args.output
39 |
40 | # Model from --model args
41 | for string in args.model:
42 | if '=' in string:
43 | update = parse_dotted(string)
44 | else:
45 | with open(string, 'r') as f:
46 | update = yaml.safe_load(f)
47 | # If the yaml file contains an entry with key `model` use that one instead
48 | if 'model' in update.keys():
49 | update = update['model']
50 | update_rec(model, update)
51 |
52 | # Data from --data args
53 | for path in args.data:
54 | path = Path(path).expanduser().resolve()
55 | if path.is_dir():
56 | data.extend(path.glob('*.pt'))
57 | elif path.is_file() and path.suffix == '.pt':
58 | data.append(path)
59 | else:
60 | raise ValueError(f'Invalid data: {path}')
61 |
62 | # Options from --options args
63 | for string in args.options:
64 | if '=' in string:
65 | update = parse_dotted(string)
66 | else:
67 | with open(string, 'r') as f:
68 | update = yaml.safe_load(f)
69 | update_rec(options, update)
70 |
71 | # Resolving paths
72 | model.state_dict = Path(model.state_dict).expanduser().resolve()
73 | options.output = Path(options.output).expanduser().resolve()
74 |
75 | # Checks (some missing, others redundant)
76 | if model.fn is None:
77 | raise ValueError('Model constructor function not defined')
78 | if model.state_dict is None:
79 | raise ValueError(f'Model state dict is required to predict')
80 | if len(data) == 0:
81 | raise ValueError(f'No data to predict')
82 | if options.cpus < 0:
83 | raise ValueError(f'Invalid number of cpus: {options.cpus}')
84 | if options.output.exists() and not options.output.is_dir():
85 | raise ValueError(f'Invalid output path {options.output}')
86 |
87 |
88 | pyaml.pprint({'model': model, 'options': options, 'data': data}, sort_dicts=False, width=200)
89 | # endregion
90 |
91 | # region Building phase
92 | # Model
93 | net: torch.nn.Module = import_(model.fn)(*model.args, **model.kwargs)
94 | net.load_state_dict(torch.load(model.state_dict))
95 | net.to(options.device)
96 |
97 | # Output folder
98 | options.output.mkdir(parents=True, exist_ok=True)
99 | # endregion
100 |
101 | # region Predict
102 | # Dataset and dataloader
103 | dataset_predict: InfectionDataset = torch.load(data[0])
104 | dataloader_predict = torch.utils.data.DataLoader(
105 | dataset_predict,
106 | shuffle=False,
107 | num_workers=min(options.cpus, 1) if 'cuda' in options.device else options.cpus,
108 | pin_memory='cuda' in options.device,
109 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)),
110 | batch_size=options.batch_size,
111 | collate_fn=tg.GraphBatch.collate,
112 | )
113 |
114 | # Prediction loop
115 | net.eval()
116 | torch.set_grad_enabled(False)
117 | i = 0
118 | with tqdm.tqdm(desc='Predict', total=len(dataloader_predict.dataset), unit='g') as bar:
119 | for graphs, *_ in dataloader_predict:
120 | graphs = graphs.to(options.device)
121 |
122 | results = net(graphs)
123 | results.node_features.sigmoid_()
124 |
125 | for result in results:
126 | torch.save(result.cpu(), options.output / f'output_{i:06d}.pt')
127 | i += 1
128 |
129 | bar.update(graphs.num_graphs)
130 | # endregion
131 |
--------------------------------------------------------------------------------
/src/infection/train.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import yaml
3 | import pyaml
4 | import random
5 | import textwrap
6 | import multiprocessing
7 | from pathlib import Path
8 | from datetime import datetime
9 | from argparse import ArgumentParser
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import sklearn.metrics
14 | from munch import AutoMunch
15 |
16 | import torch
17 | import torch.utils.data
18 | import torch.nn.functional as F
19 | import torch_scatter
20 | import torchgraphs as tg
21 | from tensorboardX import SummaryWriter
22 |
23 | from saver import Saver
24 | from utils import git_info, cuda_info, parse_dotted, update_rec, set_seeds, import_, sort_dict, RunningWeightedAverage
25 | from .dataset import InfectionDataset
26 |
27 | parser = ArgumentParser()
28 | parser.add_argument('--experiment', nargs='+', required=True)
29 | parser.add_argument('--model', nargs='+', required=False, default=[])
30 | parser.add_argument('--optimizer', nargs='+', required=False, default=[])
31 | parser.add_argument('--session', nargs='+', required=False, default=[])
32 |
33 | args = parser.parse_args()
34 |
35 |
36 | # region Collecting phase
37 | class Experiment(AutoMunch):
38 | @property
39 | def session(self):
40 | return self.sessions[-1]
41 |
42 |
43 | experiment = Experiment()
44 |
45 | # Experiment defaults
46 | experiment.name = 'experiment'
47 | experiment.tags = []
48 | experiment.samples = 0
49 | experiment.model = {'fn': None, 'args': [], 'kwargs': {}}
50 | experiment.optimizer = {'fn': None, 'args': [], 'kwargs': {}}
51 | experiment.sessions = []
52 |
53 | # Session defaults
54 | session = AutoMunch()
55 | session.losses = {'nodes': 0, 'count': 0, 'l1': 0}
56 | session.seed = random.randint(0, 99)
57 | session.cpus = multiprocessing.cpu_count() - 1
58 | session.device = 'cuda' if torch.cuda.is_available() else 'cpu'
59 | session.log = {'when': []}
60 | session.checkpoint = {'when': []}
61 |
62 | # Experiment configuration
63 | for string in args.experiment:
64 | if '=' in string:
65 | update = parse_dotted(string)
66 | else:
67 | with open(string, 'r') as f:
68 | update = yaml.safe_load(f)
69 | # If the current session is defined inside the experiment update the session instead
70 | if 'session' in update:
71 | update_rec(session, update.pop('session'))
72 | update_rec(experiment, update)
73 |
74 | # Model from --model args
75 | for string in args.model:
76 | if '=' in string:
77 | update = parse_dotted(string)
78 | else:
79 | with open(string, 'r') as f:
80 | update = yaml.safe_load(f)
81 | # If the yaml document contains a single entry with key `model` use that one instead
82 | if update.keys() == {'model'}:
83 | update = update['model']
84 | update_rec(experiment.model, update)
85 | del update
86 |
87 | # Optimizer from --optimizer args
88 | for string in args.optimizer:
89 | if '=' in string:
90 | update = parse_dotted(string)
91 | else:
92 | with open(string, 'r') as f:
93 | update = yaml.safe_load(f)
94 | # If the yaml document contains a single entry with key `optimizer` use that one instead
95 | if update.keys() == {'optimizer'}:
96 | update = update['optimizer']
97 | update_rec(experiment.optimizer, update)
98 | del update
99 |
100 | # Session from --session args
101 | for string in args.session:
102 | if '=' in string:
103 | update = parse_dotted(string)
104 | else:
105 | with open(string, 'r') as f:
106 | update = yaml.safe_load(f)
107 | # If the yaml document contains a single entry with key `session` use that one instead
108 | if update.keys() == {'session'}:
109 | update = update['session']
110 | update_rec(session, update)
111 | del update
112 |
113 | # Checks (some missing, others redundant)
114 | if experiment.name is None or len(experiment.name) == 0:
115 | raise ValueError(f'Experiment name is empty: {experiment.name}')
116 | if experiment.tags is None:
117 | raise ValueError('Experiment tags is None')
118 | if experiment.model.fn is None:
119 | raise ValueError('Model constructor function not defined')
120 | if experiment.optimizer.fn is None:
121 | raise ValueError('Optimizer constructor function not defined')
122 | if session.cpus < 0:
123 | raise ValueError(f'Invalid number of cpus: {session.cpus}')
124 | if any(l < 0 for l in session.losses.values()) or all(l == 0 for l in session.losses.values()):
125 | raise ValueError(f'Invalid losses: {session.losses}')
126 | if len(experiment.sessions) > 0 and ('state_dict' not in experiment.model or 'state_dict' not in experiment.optimizer):
127 | raise ValueError(f'Model and optimizer state dicts are required to restore training')
128 |
129 | # Experiment computed fields
130 | experiment.epoch = sum((s.epochs for s in experiment.sessions), 0)
131 |
132 | # Session computed fields
133 | session.status = 'NEW'
134 | session.datetime_started = None
135 | session.datetime_completed = None
136 | git = git_info()
137 | if git is not None:
138 | session.git = git
139 | if 'cuda' in session.device:
140 | session.cuda = cuda_info()
141 |
142 | # Resolving paths
143 | rand_id = ''.join(chr(random.randint(ord('A'), ord('Z'))) for _ in range(6))
144 | session.data.folder = Path(session.data.folder.replace('{name}', experiment.name)).expanduser().resolve().as_posix()
145 | session.log.folder = session.log.folder \
146 | .replace('{name}', experiment.name) \
147 | .replace('{tags}', '_'.join(experiment.tags)) \
148 | .replace('{rand}', rand_id)
149 | if len(session.log.when) > 0:
150 |     session.log.folder = Path(session.log.folder).expanduser().resolve().as_posix()
151 | if len(session.checkpoint.when) > 0:
152 | session.checkpoint.folder = session.checkpoint.folder \
153 | .replace('{name}', experiment.name) \
154 | .replace('{tags}', '_'.join(experiment.tags)) \
155 | .replace('{rand}', rand_id)
156 | session.checkpoint.folder = Path(session.checkpoint.folder).expanduser().resolve().as_posix()
157 | if 'state_dict' in experiment.model:
158 | experiment.model.state_dict = Path(experiment.model.state_dict).expanduser().resolve().as_posix()
159 | if 'state_dict' in experiment.optimizer:
160 | experiment.optimizer.state_dict = Path(experiment.optimizer.state_dict).expanduser().resolve().as_posix()
161 |
162 | sort_dict(experiment, ['name', 'tags', 'epoch', 'samples', 'model', 'optimizer', 'sessions'])
163 | sort_dict(session, ['epochs', 'batch_size', 'losses', 'seed', 'cpus', 'device', 'samples', 'status',
164 | 'datetime_started', 'datetime_completed', 'data', 'log', 'checkpoint', 'git', 'gpus'])
165 | experiment.sessions.append(session)
166 | pyaml.pprint(experiment, sort_dicts=False, width=200)
167 | del session
168 | # endregion
169 |
170 | # region Building phase
171 | # Seeds (set them after the random run id is generated)
172 | set_seeds(experiment.session.seed)
173 |
174 | # Model
175 | model: torch.nn.Module = import_(experiment.model.fn)(*experiment.model.args, **experiment.model.kwargs)
176 | if 'state_dict' in experiment.model:
177 | model.load_state_dict(torch.load(experiment.model.state_dict))
178 | model.to(experiment.session.device)
179 |
180 | # Optimizer
181 | optimizer: torch.optim.Optimizer = import_(experiment.optimizer.fn)(
182 | model.parameters(), *experiment.optimizer.args, **experiment.optimizer.kwargs)
183 | if 'state_dict' in experiment.optimizer:
184 | optimizer.load_state_dict(torch.load(experiment.optimizer.state_dict))
185 |
186 | # Logger
187 | if len(experiment.session.log.when) > 0:
188 | logger = SummaryWriter(experiment.session.log.folder)
189 | logger.add_text(
190 | 'Experiment', textwrap.indent(pyaml.dump(experiment, safe=True, sort_dicts=False), ' '), experiment.samples)
191 | else:
192 | logger = None
193 |
194 | # Saver
195 | if len(experiment.session.checkpoint.when) > 0:
196 | saver = Saver(experiment.session.checkpoint.folder)
197 | if experiment.epoch == 0:
198 | saver.save_experiment(experiment, suffix=f'e{experiment.epoch:04d}')
199 | else:
200 | saver = None
201 | # endregion
202 |
203 | # Datasets and dataloaders
204 | dataloader_kwargs = dict(
205 | num_workers=min(experiment.session.cpus, 1) if 'cuda' in experiment.session.device else experiment.session.cpus,
206 | pin_memory='cuda' in experiment.session.device,
207 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)),
208 | batch_size=experiment.session.batch_size,
209 | collate_fn=tg.GraphBatch.collate,
210 | )
211 |
212 | dataset_train: InfectionDataset = torch.load(Path(experiment.session.data.folder) / 'train.pt')
213 | dataloader_train = torch.utils.data.DataLoader(
214 | dataset_train,
215 | shuffle=True,
216 | **dataloader_kwargs
217 | )
218 | count_weights = pd.Series(t.global_features.item() for g, t in dataset_train) \
219 | .value_counts(normalize=True) \
220 | .apply(np.log) \
221 | .apply(np.negative) \
222 | .astype(np.float32) \
223 | .sort_index()
224 |
225 | dataset_val: InfectionDataset = torch.load(Path(experiment.session.data.folder) / 'val.pt')
226 | dataloader_val = torch.utils.data.DataLoader(
227 | dataset_val,
228 | shuffle=False,
229 | **dataloader_kwargs
230 | )
231 | del dataloader_kwargs
232 |
233 | # region Training
234 | # Train and validation loops
235 | experiment.session.status = 'RUNNING'
236 | experiment.session.datetime_started = datetime.utcnow()
237 |
238 | graphs_df = {k: [] for k in ['LossInfection', 'LossCount', 'Nodes', 'Edges', 'InfectedStart', 'InfectedEnd',
239 | 'InfectedSum', 'InfectedCount', 'MeanPercError', 'AvgPrecision', 'AreaROC']}
240 | nodes_df = {k: [] for k in ['Targets', 'Results']}
241 |
242 | epoch_bar_postfix = {}
243 | epoch_bar = tqdm.trange(1, experiment.session.epochs + 1, desc='Epochs', unit='e', leave=True)
244 | for epoch_idx in epoch_bar:
245 | experiment.epoch += 1
246 |
247 | # region Training loop
248 | model.train()
249 | torch.set_grad_enabled(True)
250 |
251 | train_bar_postfix = {}
252 | metric_mpe_avg = RunningWeightedAverage()
253 | loss_bce_avg = RunningWeightedAverage()
254 | loss_count_avg = RunningWeightedAverage()
255 | loss_l1_avg = RunningWeightedAverage()
256 | loss_total_avg = RunningWeightedAverage()
257 |
258 | train_bar = tqdm.tqdm(desc=f'Train {experiment.epoch}', total=len(dataloader_train.dataset), unit='g')
259 | for graphs, targets in dataloader_train:
260 | graphs = graphs.to(experiment.session.device)
261 | targets = targets.to(experiment.session.device)
262 | results = model(graphs)
263 |
264 | loss_total = torch.tensor(0., device=experiment.session.device)
265 |
266 | if experiment.session.losses.nodes > 0:
267 | loss_bce = F.binary_cross_entropy_with_logits(
268 | results.node_features.squeeze(), targets.node_features.squeeze().float(), reduction='mean')
269 | loss_total += experiment.session.losses.nodes * loss_bce
270 | loss_bce_avg.add(loss_bce.item(), len(graphs))
271 | train_bar_postfix['Infected'] = f'{loss_bce.item():.5f}'
272 | if 'every batch' in experiment.session.log.when:
273 | logger.add_scalar('loss/train/infected', loss_bce.item(), global_step=experiment.samples)
274 |
275 | if experiment.session.losses.count > 0:
276 | loss_count = F.mse_loss(
277 | results.global_features.squeeze(), targets.global_features.squeeze().float(), reduction='none')
278 | weights = loss_count.new_tensor(count_weights.loc[targets.global_features.squeeze().cpu().numpy()].values)
279 | loss_total += experiment.session.losses.count * torch.mean(loss_count * weights)
280 | loss_count_avg.add(loss_count.mean().item(), len(graphs))
281 | train_bar_postfix['Count'] = f'{loss_count.mean().item():.5f}'
282 |
283 | metric_mpe = ((results.global_features - targets.global_features.float()).abs() /
284 | targets.global_features.float()).mean()
285 |                 metric_mpe_avg.add(metric_mpe.item(), len(graphs))
286 |
287 | if 'every batch' in experiment.session.log.when:
288 | logger.add_scalar('loss/train/count', loss_count.mean().item(), global_step=experiment.samples)
289 | logger.add_scalar('metric/train/mpe', metric_mpe.item(), global_step=experiment.samples)
290 |
291 | if experiment.session.losses.l1 > 0:
292 | loss_l1 = sum([p.abs().sum() for p in model.parameters()])
293 | loss_total += experiment.session.losses.l1 * loss_l1
294 | loss_l1_avg.add(loss_l1.item(), len(graphs))
295 | train_bar_postfix['L1'] = f'{loss_l1.item():.5f}'
296 | if 'every batch' in experiment.session.log.when:
297 | logger.add_scalar('loss/train/l1', loss_l1.item(), global_step=experiment.samples)
298 |
299 | loss_total_avg.add(loss_total.item(), len(graphs))
300 | train_bar_postfix['Total'] = f'{loss_total.item():.5f}'
301 | if 'every batch' in experiment.session.log.when:
302 | logger.add_scalar('loss/train/total', loss_total.item(), global_step=experiment.samples)
303 |
304 | optimizer.zero_grad()
305 | loss_total.backward()
306 | optimizer.step(closure=None)
307 |
308 | experiment.samples += len(graphs)
309 | train_bar.update(len(graphs))
310 | train_bar.set_postfix(train_bar_postfix)
311 | train_bar.close()
312 |
313 | epoch_bar_postfix['Train'] = f'{loss_total_avg.get():.4f}'
314 | epoch_bar.set_postfix(epoch_bar_postfix)
315 |
316 | if 'every epoch' in experiment.session.log.when and 'every batch' not in experiment.session.log.when:
317 | logger.add_scalar('loss/train/total', loss_total_avg.get(), global_step=experiment.samples)
318 | if experiment.session.losses.nodes > 0:
319 | logger.add_scalar('loss/train/infected', loss_bce_avg.get(), global_step=experiment.samples)
320 | if experiment.session.losses.count > 0:
321 | logger.add_scalar('loss/train/count', loss_count_avg.get(), global_step=experiment.samples)
322 | logger.add_scalar('metric/train/mpe', metric_mpe_avg.get(), global_step=experiment.samples)
323 | if experiment.session.losses.l1 > 0:
324 | logger.add_scalar('loss/train/l1', loss_l1_avg.get(), global_step=experiment.samples)
325 |
326 | del train_bar, train_bar_postfix, loss_bce_avg, loss_count_avg, loss_l1_avg, loss_total_avg, metric_mpe_avg
327 | # endregion
328 |
329 | # region Validation loop
330 | model.eval()
331 | torch.set_grad_enabled(False)
332 |
333 | val_bar_postfix = {}
334 | metric_mpe_avg = RunningWeightedAverage()
335 | loss_bce_avg = RunningWeightedAverage()
336 | loss_count_avg = RunningWeightedAverage()
337 | loss_total_avg = RunningWeightedAverage()
338 | loss_l1 = sum([p.abs().sum() for p in model.parameters()])
339 | val_bar_postfix['L1'] = f'{loss_l1.item():.5f}'
340 |
341 | val_bar = tqdm.tqdm(desc=f'Val {experiment.epoch}', total=len(dataloader_val.dataset), unit='g')
342 | for batch_idx, (graphs, targets) in enumerate(dataloader_val):
343 | graphs = graphs.to(experiment.session.device)
344 | targets = targets.to(experiment.session.device)
345 | results = model(graphs)
346 |
347 | loss_total = torch.tensor(0., device=experiment.session.device)
348 |
349 | if experiment.session.losses.nodes > 0:
350 | loss_bce = F.binary_cross_entropy_with_logits(
351 | results.node_features.squeeze(), targets.node_features.squeeze().float(), reduction='mean')
352 | loss_total += experiment.session.losses.nodes * loss_bce
353 | loss_bce_avg.add(loss_bce.item(), len(graphs))
354 | val_bar_postfix['Infected'] = f'{loss_bce.item():.5f}'
355 |
356 | if experiment.session.losses.count > 0:
357 | loss_count = F.mse_loss(
358 | results.global_features.squeeze(), targets.global_features.squeeze().float(), reduction='mean')
359 | loss_total += experiment.session.losses.count * loss_count
360 | loss_count_avg.add(loss_count.item(), len(graphs))
361 | val_bar_postfix['Count'] = f'{loss_count.item():.5f}'
362 |
363 | metric_mpe_avg.add(
364 | ((results.global_features - targets.global_features.float()).abs() /
365 | targets.global_features.float()).mean(),
366 | len(graphs)
367 | )
368 |
369 | if experiment.session.losses.l1 > 0:
370 | loss_total += experiment.session.losses.l1 * loss_l1
371 |
372 | val_bar_postfix['Total'] = f'{loss_total.item():.5f}'
373 | loss_total_avg.add(loss_total.item(), len(graphs))
374 |
375 | # region Last epoch
376 | if epoch_idx == experiment.session.epochs:
377 | loss_bce_by_graph = torch_scatter.scatter_mean(
378 | F.binary_cross_entropy_with_logits(
379 | results.node_features.squeeze(), targets.node_features.squeeze().float(), reduction='none'),
380 | index=tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph),
381 | dim=0,
382 | dim_size=graphs.num_graphs
383 | )
384 | loss_count_by_graph = F.mse_loss(
385 | results.global_features.squeeze(), targets.global_features.squeeze().float(), reduction='none')
386 | mpe_by_graph = (
387 | (results.global_features - targets.global_features.float()).abs() /
388 | targets.global_features.float()
389 | ).squeeze()
390 | infected_by_graph_start_true = torch_scatter.scatter_add(
391 | graphs.node_features[:, 0].int().clamp(min=0),
392 | index=tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph),
393 | dim=0,
394 | dim_size=graphs.num_graphs
395 | )
396 | infected_by_graph_sum_pred = torch_scatter.scatter_add(
397 | results.node_features.squeeze().sigmoid(),
398 | index=tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph),
399 | dim=0,
400 | dim_size=graphs.num_graphs
401 | )
402 | avg_prec_by_graph = []
403 | area_roc_by_graph = []
404 | for t, r in zip(targets.node_features_by_graph, results.node_features_by_graph):
405 | avg_prec_by_graph.append(sklearn.metrics.average_precision_score(
406 | y_true=t.squeeze().cpu().int(), # numpy does not work with torch.int8
407 | y_score=r.squeeze().sigmoid().cpu())
408 | )
409 | try:
410 | area_roc_by_graph.append(sklearn.metrics.roc_auc_score(
411 | y_true=t.squeeze().cpu().int(), # numpy does not work with torch.int8
412 | y_score=r.squeeze().sigmoid().cpu())
413 | )
414 | except ValueError:
415 | # ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.
416 | area_roc_by_graph.append(np.nan)
417 |
418 | nodes_df['Targets'].append(targets.node_features.squeeze().cpu().int()) # numpy doesn't convert torch.int8
419 | nodes_df['Results'].append(results.node_features.squeeze().sigmoid().cpu())
420 |
421 | graphs_df['LossInfection'].append(loss_bce_by_graph.cpu())
422 | graphs_df['LossCount'].append(loss_count_by_graph.cpu())
423 | graphs_df['Nodes'].append(graphs.num_nodes_by_graph.cpu())
424 | graphs_df['Edges'].append(graphs.num_edges_by_graph.cpu())
425 | graphs_df['InfectedStart'].append(infected_by_graph_start_true.cpu())
426 | graphs_df['InfectedEnd'].append(targets.global_features.squeeze().cpu())
427 | graphs_df['InfectedSum'].append(infected_by_graph_sum_pred.cpu())
428 | graphs_df['InfectedCount'].append(results.global_features.squeeze().cpu())
429 | graphs_df['MeanPercError'].append(mpe_by_graph.cpu())
430 | graphs_df['AvgPrecision'].append(np.array(avg_prec_by_graph))
431 | graphs_df['AreaROC'].append(np.array(area_roc_by_graph))
432 | # endregion
433 |
434 | val_bar.update(len(graphs))
435 | val_bar.set_postfix(val_bar_postfix)
436 | val_bar.close()
437 |
438 | epoch_bar_postfix['Val'] = f'{loss_total_avg.get():.4f}'
439 | epoch_bar.set_postfix(epoch_bar_postfix)
440 |
441 | if (
442 | 'every batch' in experiment.session.log.when or
443 | 'every epoch' in experiment.session.log.when or
444 |             'last epoch' in experiment.session.log.when and epoch_idx == experiment.session.epochs
445 | ):
446 | logger.add_scalar('loss/val/total', loss_total_avg.get(), global_step=experiment.samples)
447 | if experiment.session.losses.nodes > 0:
448 | logger.add_scalar('loss/val/infected', loss_bce_avg.get(), global_step=experiment.samples)
449 | if experiment.session.losses.count > 0:
450 | logger.add_scalar('loss/val/count', loss_count_avg.get(), global_step=experiment.samples)
451 | logger.add_scalar('metric/val/mpe', metric_mpe_avg.get(), global_step=experiment.samples)
452 | if experiment.session.losses.l1 > 0:
453 | logger.add_scalar('loss/val/l1', loss_l1.item(), global_step=experiment.samples)
454 |
455 | del val_bar, val_bar_postfix, loss_bce_avg, loss_count_avg, loss_l1, loss_total_avg, metric_mpe_avg, batch_idx
456 | # endregion
457 |
458 | # Saving
459 | if epoch_idx == experiment.session.epochs:
460 | experiment.session.status = 'DONE'
461 | experiment.session.datetime_completed = datetime.utcnow()
462 | if (
463 | 'every batch' in experiment.session.checkpoint.when or
464 | 'every epoch' in experiment.session.checkpoint.when or
465 | 'last epoch' in experiment.session.checkpoint.when and epoch_idx == experiment.session.epochs
466 | ):
467 | saver.save(model, experiment, optimizer, suffix=f'e{experiment.epoch:04d}')
468 | epoch_bar.close()
469 | print()
470 | del epoch_bar, epoch_bar_postfix, epoch_idx
471 | # endregion
472 |
473 | # region Final report
474 | pd.options.display.precision = 2
475 | pd.options.display.max_columns = 999
476 | pd.options.display.expand_frame_repr = False
477 |
478 | nodes_df = pd.DataFrame({k: np.concatenate(v) for k, v in nodes_df.items()})
479 | experiment.average_precision = sklearn.metrics.average_precision_score(
480 | y_true=nodes_df.Targets, y_score=nodes_df.Results)
481 | print('Average precision:', experiment.average_precision)
482 | if logger is not None:
483 | logger.add_scalar('metrics/val/avg_precision', experiment.average_precision, global_step=experiment.samples)
484 | logger.add_pr_curve('infection', labels=nodes_df.Targets.values,
485 | predictions=nodes_df.Results.values, global_step=experiment.samples)
486 |
487 | # noinspection PyUnreachableCode
488 | if False:
489 | import matplotlib.pyplot as plt
490 |
491 | precision, recall, _ = sklearn.metrics.precision_recall_curve(
492 | y_true=nodes_df.Targets, probas_pred=nodes_df.Results)
493 | plt.step(recall, precision, color='b', alpha=0.2, where='post')
494 | plt.fill_between(recall, precision, alpha=0.2, color='b', step='post')
495 | plt.xlabel('Recall')
496 | plt.ylabel('Precision')
497 | plt.ylim([0.0, 1.05])
498 | plt.xlim([0.0, 1.0])
499 | plt.title(f'Precision-Recall curve: AP={experiment.average_precision:.2f}')
500 | plt.show()
501 | del nodes_df
502 |
503 | graphs_df = pd.DataFrame({k: np.concatenate(v) for k, v in graphs_df.items()}).rename_axis('GraphId').reset_index()
504 | experiment.loss_count = graphs_df.LossCount.mean()
505 | experiment.loss_infection = graphs_df.LossInfection.mean()
506 | experiment.mpe = graphs_df.MeanPercError.mean()
507 | print('Count MSE:', experiment.loss_count)
508 | print('Infection BCE:', experiment.loss_infection)
509 | print('Count MPE:', experiment.mpe)
510 |
511 | # Split the results based on whether the number of nodes was present in the training set or not
512 | df_train_val = graphs_df \
513 | .groupby(np.where(graphs_df.Nodes < dataset_train.max_nodes,
514 | f'Train [{dataset_train.min_nodes}, {dataset_train.max_nodes - 1})',
515 | f'Val [{dataset_train.max_nodes}, {dataset_val.max_nodes - 1})')) \
516 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'LossInfection': 'mean', 'LossCount': 'mean'}) \
517 | .sort_index(ascending=True) \
518 | .rename_axis(index='Dataset') \
519 | .rename(str.capitalize, axis='columns', level=1)
520 |
521 | # Split the results in ranges based on the number of nodes and compute the average loss per range
522 | df_losses_by_node_range = graphs_df \
523 | .groupby(graphs_df.Nodes // 10) \
524 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'LossInfection': 'mean', 'LossCount': 'mean'}) \
525 | .rename_axis(index='NodeRange') \
526 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index') \
527 | .rename(str.capitalize, axis='columns', level=1)
528 |
529 | # Split the results in ranges based on the number of nodes and keep the N worst predictions w.r.t. node-wise loss
530 | df_worst_infection_loss_by_node_range = graphs_df \
531 | .groupby(graphs_df.Nodes // 10) \
532 | .apply(lambda df_gr: df_gr.nlargest(5, 'LossInfection').set_index('GraphId')) \
533 | .rename_axis(index={'Nodes': 'NodeRange'}) \
534 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0)
535 |
536 | # Split the results in ranges based on the number of nodes and keep the N worst predictions w.r.t. graph-wise loss
537 | df_worst_count_loss_by_node_range = graphs_df \
538 | .groupby(graphs_df.Nodes // 10) \
539 | .apply(lambda df_gr: df_gr.nlargest(5, 'LossCount').set_index('GraphId')) \
540 | .rename_axis(index={'Nodes': 'NodeRange'}) \
541 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0)
542 |
543 | print(f"""
544 | Generalization:
545 | {df_train_val}\n
546 | Losses by range:
547 | {df_losses_by_node_range}\n
548 | Worst infection predictions:
549 | {df_worst_infection_loss_by_node_range}\n
550 | Worst count predictions:
551 | {df_worst_count_loss_by_node_range}
552 | """)
553 |
554 | if logger is not None:
555 | logger.add_text(
556 | 'Generalization',
557 | textwrap.indent(df_train_val.to_string(), ' '),
558 | global_step=experiment.samples)
559 | logger.add_text(
560 | 'Losses by range',
561 | textwrap.indent(df_losses_by_node_range.to_string(), ' '),
562 | global_step=experiment.samples)
563 | logger.add_text(
564 | 'Worst infection predictions',
565 | textwrap.indent(df_worst_infection_loss_by_node_range.to_string(), ' '),
566 |         global_step=experiment.samples)
567 | logger.add_text(
568 | 'Worst count predictions',
569 | textwrap.indent(df_worst_count_loss_by_node_range.to_string(), ' '),
570 | global_step=experiment.samples)
571 | del graphs_df, df_losses_by_node_range, df_worst_infection_loss_by_node_range, df_worst_count_loss_by_node_range
572 |
573 | params = [f'{name}:\n{param.data.cpu().numpy().round(3)}' for name, param in model.named_parameters()]
574 | print('Parameters:', *params, sep='\n\n')
575 | if logger is not None:
576 | logger.add_text('Parameters', textwrap.indent('\n\n'.join(params), ' '), global_step=experiment.samples)
577 | del params
578 | # endregion
579 |
580 | # region Cleanup
581 | if logger is not None:
582 | logger.close()
583 | exit()
584 | # endregion
585 |
--------------------------------------------------------------------------------
/src/relevance/__init__.py:
--------------------------------------------------------------------------------
1 | # TODO split these into separate modules
2 | from .autograd_tricks import add, sum, \
3 | cat, index_select, repeat_tensor, \
4 | scatter_add, scatter_mean, scatter_max, \
5 | linear_eps, relu, get_aggregation
6 | from .graphs import EdgeLinearRelevance, NodeLinearRelevance, GlobalLinearRelevance, \
7 | EdgeReLURelevance, NodeReLURelevance, GlobalReLURelevance
8 |
--------------------------------------------------------------------------------
/src/relevance/autograd_tricks.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 |
4 | import torch
5 | import torch_scatter
6 |
7 |
8 | class AddRelevance(torch.autograd.Function):
9 | @staticmethod
10 | def forward(ctx, a, b):
11 | out = a + b
12 | ctx.save_for_backward(a, b, out)
13 | return out
14 |
15 | @staticmethod
16 | def backward(ctx, rel_out):
17 | a, b, out = ctx.saved_tensors
18 | if ((out == 0) & (rel_out > 0)).any():
19 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
20 | rel_a = torch.where(out != 0, rel_out * a / out, out.new_tensor(0))
21 | rel_b = torch.where(out != 0, rel_out * b / out, out.new_tensor(0))
22 | return rel_a, rel_b
23 |
24 |
25 | def add(a, b):
26 | return AddRelevance.apply(a, b)
27 |
28 |
29 | class SumPooling(torch.autograd.Function):
30 | @staticmethod
31 | def forward(ctx, src, dim, keepdim):
32 | out = torch.sum(src, dim=dim, keepdim=keepdim)
33 | ctx.dim = dim
34 | ctx.keepdim = keepdim
35 | ctx.save_for_backward(src, out)
36 | return out
37 |
38 | @staticmethod
39 | def backward(ctx, rel_out):
40 | src, out = ctx.saved_tensors
41 | if ((out == 0) & (rel_out > 0)).any():
42 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
43 | rel_out = torch.where(out != 0, rel_out / out, out.new_tensor(0))
44 | if not ctx.keepdim and ctx.dim is not None:
45 | rel_out.unsqueeze_(ctx.dim)
46 | return rel_out * src, None, None
47 |
48 |
49 | def sum(tensor, dim=None, keepdim=False):
50 | return SumPooling.apply(tensor, dim, keepdim)
51 |
52 |
53 |
54 | class ScatterAddRelevance(torch.autograd.Function):
55 | @staticmethod
56 | def forward(ctx, src, out, idx, dim):
57 | torch_scatter.scatter_add(src, idx, dim=dim, out=out)
58 | ctx.dim = dim
59 | ctx.save_for_backward(src, idx, out)
60 | return out
61 |
62 | @staticmethod
63 | def backward(ctx, rel_out):
64 | src, idx, out = ctx.saved_tensors
65 | if ((out == 0) & (rel_out > 0)).any():
66 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
67 | rel_out = torch.where(out != 0, rel_out / out, out.new_tensor(0))
68 | rel_src = torch.index_select(rel_out, ctx.dim, idx) * src
69 | return rel_src, None, None, None
70 |
71 |
72 | def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
73 | src, out, _, dim = torch_scatter.utils.gen.gen(src, index, dim, out, dim_size, fill_value)
74 | return ScatterAddRelevance.apply(src, out, index, dim)
75 |
76 |
77 | class ScatterMeanRelevance(torch.autograd.Function):
78 | @staticmethod
79 | def forward(ctx, src, idx, dim, dim_size, fill_value):
80 | sums = torch_scatter.scatter_add(src, idx, dim, None, dim_size, fill_value)
81 | count = torch_scatter.scatter_add(torch.ones_like(src), idx, dim, None, dim_size, fill_value=0)
82 | out = sums / count.clamp(min=1)
83 | ctx.dim = dim
84 | ctx.save_for_backward(src, idx, sums)
85 | return out
86 |
87 | @staticmethod
88 | def backward(ctx, rel_out):
89 | src, idx, sums = ctx.saved_tensors
90 | if ((sums == 0) & (rel_out > 0)).any():
91 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
92 | rel_out = torch.where(sums != 0, rel_out / sums, sums.new_tensor(0))
93 | rel_src = torch.index_select(rel_out, ctx.dim, idx) * src
94 | return rel_src, None, None, None, None
95 |
96 |
97 | def scatter_mean(src, index, dim=-1, dim_size=None, fill_value=0):
98 | return ScatterMeanRelevance.apply(src, index, dim, dim_size, fill_value)
99 |
100 |
101 | class ScatterMaxRelevance(torch.autograd.Function):
102 | @staticmethod
103 | def forward(ctx, src, idx, dim, dim_size, fill_value):
104 | out, idx_maxes = torch_scatter.scatter_max(src, idx, dim=dim, dim_size=dim_size, fill_value=fill_value)
105 | ctx.dim = dim
106 | ctx.dim_size = src.shape[dim]
107 | ctx.save_for_backward(idx, out, idx_maxes)
108 | return out, idx_maxes
109 |
110 | @staticmethod
111 | def backward(ctx, rel_out, rel_idx_maxes):
112 | idx, out, idx_maxes = ctx.saved_tensors
113 | if ((out == 0) & (rel_out > 0)).any():
114 | warnings.warn('Relevance that is propagated back through an output of 0 will be lost')
115 | rel_out = torch.where(out != 0, rel_out, out.new_tensor(0))
116 |
117 | # Where idx_maxes==-1 set idx=0 so that the indexes are valid for scatter_add
118 | # The corresponding relevance should already be 0, but set it relevance=0 to be sure
119 | rel_out = torch.where(idx_maxes != -1, rel_out, torch.zeros_like(rel_out))
120 | idx_maxes = torch.where(idx_maxes != -1, idx_maxes, torch.zeros_like(idx_maxes))
121 |
122 | rel_src = torch_scatter.scatter_add(rel_out, idx_maxes, dim=ctx.dim, dim_size=ctx.dim_size)
123 | return rel_src, None, None, None, None
124 |
125 |
126 | def scatter_max(src, index, dim=-1, dim_size=None, fill_value=0):
127 | return ScatterMaxRelevance.apply(src, index, dim, dim_size, fill_value)
128 |
129 |
130 | class LinearEpsilonRelevance(torch.autograd.Function):
131 | eps = 1e-16
132 |
133 | @staticmethod
134 | def forward(ctx, input, weight, bias):
135 | Z = weight.t()[None, :, :] * input[:, :, None]
136 | Zs = Z.sum(dim=1, keepdim=True)
137 | if bias is not None:
138 | Zs += bias[None, None, :]
139 | ctx.save_for_backward(Z, Zs)
140 | return Zs.squeeze(dim=1)
141 |
142 | @staticmethod
143 | def backward(ctx, rel_out):
144 | Z, Zs = ctx.saved_tensors
145 | eps = rel_out.new_tensor(LinearEpsilonRelevance.eps)
146 | Zs += torch.where(Zs >= 0, eps, -eps)
147 | return (rel_out[:, None, :] * Z / Zs).sum(dim=2), None, None
148 |
149 |
150 | def linear_eps(input, weight, bias=None):
151 | return LinearEpsilonRelevance.apply(input, weight, bias)
152 |
153 |
154 | class IndexSelectRelevance(torch.autograd.Function):
155 | @staticmethod
156 | def forward(ctx, src, dim, idx):
157 | out = torch.index_select(src, dim, idx)
158 | ctx.dim = dim
159 | ctx.dim_size = src.shape[dim]
160 | ctx.save_for_backward(src, idx, out)
161 | return out
162 |
163 | @staticmethod
164 | def backward(ctx, rel_out):
165 | src, idx, out = ctx.saved_tensors
166 | return torch_scatter.scatter_add(rel_out, idx, dim=ctx.dim, dim_size=ctx.dim_size), None, None
167 |
168 |
169 | def index_select(src, dim, index):
170 | return IndexSelectRelevance.apply(src, dim, index)
171 |
172 |
173 | class CatRelevance(torch.autograd.Function):
174 | @staticmethod
175 | def forward(ctx, dim, *tensors):
176 | ctx.dim = dim
177 | ctx.sizes = [t.shape[dim] for t in tensors]
178 | return torch.cat(tensors, dim)
179 |
180 | @staticmethod
181 | def backward(ctx, rel_out):
182 | return (None, *torch.split_with_sizes(rel_out, dim=ctx.dim, split_sizes=ctx.sizes))
183 |
184 |
185 | def cat(tensors, dim=0):
186 | return CatRelevance.apply(dim, *tensors)
187 |
188 |
189 | def repeat_tensor(src, repeats, dim=0):
190 | idx = src.new_tensor(np.arange(len(repeats)).repeat(repeats.cpu().numpy()), dtype=torch.long)
191 | return torch.index_select(src, dim, idx)
192 |
193 |
194 | class ReLuRelevance(torch.autograd.Function):
195 | @staticmethod
196 | def forward(ctx, input):
197 | out = input.clamp(min=0)
198 | ctx.save_for_backward(out)
199 | return out
200 |
201 | @staticmethod
202 | def backward(ctx, rel_out):
203 | return rel_out
204 |
205 |
206 | def relu(input):
207 | return ReLuRelevance.apply(input)
208 |
209 |
210 | def get_aggregation(name):
211 | if name in ('add', 'sum'):
212 | return scatter_add
213 | elif name in ('mean', 'avg'):
214 | return scatter_mean
215 | elif name == 'max':
216 | from functools import wraps
217 |
218 | @wraps(scatter_max)
219 | def wrapper(*args, **kwargs):
220 | return scatter_max(*args, **kwargs)[0]
221 |
222 | return wrapper
223 |
--------------------------------------------------------------------------------
/src/relevance/graphs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchgraphs as tg
3 |
4 | from . import autograd_tricks as lrp
5 |
6 |
7 | class EdgeLinearRelevance(tg.EdgeLinear):
8 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
9 | new_edges = torch.tensor(0)
10 |
11 | if self.W_edge is not None:
12 | new_edges = lrp.add(new_edges, lrp.linear_eps(graphs.edge_features, self.W_edge))
13 | if self.W_sender is not None:
14 | new_edges = lrp.add(
15 | new_edges,
16 | lrp.index_select(lrp.linear_eps(graphs.node_features, self.W_sender),
17 | dim=0, index=graphs.senders)
18 | )
19 | if self.W_receiver is not None:
20 | new_edges = lrp.add(
21 | new_edges,
22 | lrp.index_select(lrp.linear_eps(graphs.node_features, self.W_receiver),
23 | dim=0, index=graphs.receivers)
24 | )
25 | if self.W_global is not None:
26 | new_edges = lrp.add(
27 | new_edges,
28 | lrp.repeat_tensor(lrp.linear_eps(graphs.global_features, self.W_global),
29 | dim=0, repeats=graphs.num_edges_by_graph)
30 | )
31 | if self.bias is not None:
32 | new_edges = lrp.add(new_edges, self.bias)
33 |
34 | return graphs.evolve(edge_features=new_edges)
35 |
36 |
37 | class NodeLinearRelevance(tg.NodeLinear):
38 | def __init__(self, out_features, node_features=None, incoming_features=None, outgoing_features=None,
39 | global_features=None, aggregation=None, bias=True):
40 | super(NodeLinearRelevance, self).__init__(out_features, node_features, incoming_features,
41 | outgoing_features, global_features, lrp.get_aggregation(aggregation),
42 | bias)
43 |
44 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
45 | new_nodes = torch.tensor(0)
46 |
47 | if self.W_node is not None:
48 | new_nodes = lrp.add(
49 | new_nodes,
50 | lrp.linear_eps(graphs.node_features, self.W_node)
51 | )
52 | if self.W_incoming is not None:
53 | new_nodes = lrp.add(
54 | new_nodes,
55 | lrp.linear_eps(
56 | self.aggregation(graphs.edge_features, dim=0, index=graphs.receivers, dim_size=graphs.num_nodes),
57 | self.W_incoming)
58 | )
59 | if self.W_outgoing is not None:
60 | new_nodes = lrp.add(
61 | new_nodes,
62 | lrp.linear_eps(
63 | self.aggregation(graphs.edge_features, dim=0, index=graphs.senders, dim_size=graphs.num_nodes),
64 | self.W_outgoing)
65 | )
66 | if self.W_global is not None:
67 | new_nodes = lrp.add(
68 | new_nodes,
69 | lrp.repeat_tensor(lrp.linear_eps(graphs.global_features, self.W_global), dim=0,
70 | repeats=graphs.num_nodes_by_graph)
71 | )
72 | if self.bias is not None:
73 | new_nodes = lrp.add(new_nodes, self.bias)
74 |
75 | return graphs.evolve(node_features=new_nodes)
76 |
77 |
78 | class GlobalLinearRelevance(tg.GlobalLinear):
79 | def __init__(self, out_features, node_features=None, edge_features=None, global_features=None,
80 | aggregation=None, bias=True):
81 | super(GlobalLinearRelevance, self).__init__(out_features, node_features, edge_features,
82 | global_features, lrp.get_aggregation(aggregation), bias)
83 |
84 | def forward(self, graphs: tg.GraphBatch) -> tg.GraphBatch:
85 | new_globals = torch.tensor(0)
86 |
87 | if self.W_node is not None:
88 | index = tg.utils.segment_lengths_to_ids(graphs.num_nodes_by_graph)
89 | new_globals = lrp.add(
90 | new_globals,
91 | lrp.linear_eps(self.aggregation(graphs.node_features, dim=0, index=index, dim_size=graphs.num_graphs),
92 | self.W_node)
93 | )
94 | if self.W_edges is not None:
95 | index = tg.utils.segment_lengths_to_ids(graphs.num_edges_by_graph)
96 | new_globals = lrp.add(
97 | new_globals,
98 | lrp.linear_eps(self.aggregation(graphs.edge_features, dim=0, index=index, dim_size=graphs.num_graphs),
99 | self.W_edges)
100 | )
101 | if self.W_global is not None:
102 | new_globals = lrp.add(
103 | new_globals,
104 | lrp.linear_eps(graphs.global_features, self.W_global)
105 | )
106 | if self.bias is not None:
107 | new_globals = lrp.add(new_globals, self.bias)
108 |
109 | return graphs.evolve(global_features=new_globals)
110 |
111 |
112 | class EdgeReLURelevance(tg.EdgeFunction):
113 | def __init__(self):
114 | super(EdgeReLURelevance, self).__init__(lrp.relu)
115 |
116 |
117 | class NodeReLURelevance(tg.NodeFunction):
118 | def __init__(self):
119 | super(NodeReLURelevance, self).__init__(lrp.relu)
120 |
121 |
122 | class GlobalReLURelevance(tg.GlobalFunction):
123 | def __init__(self):
124 | super(GlobalReLURelevance, self).__init__(lrp.relu)
125 |
--------------------------------------------------------------------------------
/src/relevance/oldies.py:
--------------------------------------------------------------------------------
1 | """
2 | class DenseW2(RelevanceFunction):
3 | @staticmethod
4 | def forward_relevance(module, inputs, ctx):
5 | output = inputs @ module.weight.t()
6 | if module.bias is not None:
7 | output += module.bias
8 | return output
9 |
10 | @staticmethod
11 | def backward_relevance(module, relevance_outputs, ctx):
12 | return relevance_outputs @ (module.weight.pow(2) / (module.weight.pow(2).sum(dim=1, keepdim=True) + 10e-6))
13 |
14 |
15 | class DenseZPlus(RelevanceFunction):
16 | @staticmethod
17 | def forward_relevance(module, inputs, ctx):
18 | ctx['inputs'] = inputs
19 | output = inputs @ module.weight.t()
20 | if module.bias is not None:
21 | output += module.bias
22 | return output
23 |
24 | @staticmethod
25 | def backward_relevance(module, relevance_outputs, ctx):
26 | inputs = ctx['inputs']
27 | return inputs * (
28 | (relevance_outputs / (inputs @ module.weight.clamp(min=0).t() + 10e-6)) @
29 | module.weight.clamp(min=0)
30 | )
31 | """
32 |
--------------------------------------------------------------------------------
/src/relevance/patch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_scatter
3 | import torchgraphs as tg
4 |
5 | import textwrap
6 |
7 | from . import autograd_tricks as lrp
8 |
9 |
10 |
11 | def patch():
12 | torch.add = lrp.add
13 | torch.cat = lrp.cat
14 | torch.index_select = lrp.index_select
15 |
16 | tg.utils.repeat_tensor = lrp.repeat_tensor
17 |
18 | torch_scatter.scatter_add = lrp.scatter_add
19 | torch_scatter.scatter_mean = lrp.scatter_mean
20 | torch_scatter.scatter_max = lrp.scatter_max
21 |
22 | torch.nn.functional.linear = lrp.linear_eps
23 |
24 |
25 | def computational_graph(op):
26 | if op is None:
27 | return 'None'
28 | res = f'{op.__class__.__name__} at {hex(id(op))}:'
29 | if op.__class__.__name__ == 'AccumulateGrad':
30 | res += f'variable at {hex(id(op.variable))}'
31 | for op in op.next_functions:
32 | res += '\n-' + textwrap.indent(computational_graph(op[0]), ' ')
33 | return res
34 |
--------------------------------------------------------------------------------
/src/relevance_regression.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 |
4 | torch.manual_seed(0)
5 |
6 | x = torch.rand(10000, 5) * 2 - 1
7 | y = (3 * x[:, 0] + 8 * torch.cos(3.14 * x[:, 2]) - 5 * torch.pow(x[:, 4], 2)).view(-1, 1)
8 |
9 | ds = torch.utils.data.TensorDataset(x, y)
10 | dl = torch.utils.data.DataLoader(ds, batch_size=1000, shuffle=True, pin_memory=True, num_workers=0)
11 |
12 | model = torch.nn.Sequential(
13 | torch.nn.Linear(5, 16), torch.nn.ReLU(),
14 | torch.nn.Linear(16, 32), torch.nn.ReLU(),
15 | torch.nn.Linear(32, 16), torch.nn.ReLU(),
16 | torch.nn.Linear(16, 4), torch.nn.ReLU(),
17 | torch.nn.Linear(4, 1)
18 | ).to('cuda')
19 |
20 | model.train()
21 | opt = torch.optim.SGD(model.parameters(), lr=.01)
22 | scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.7)
23 |
24 | for e in range(600):
25 | for xs, ys in dl:
26 | xs = xs.to('cuda')
27 | ys = ys.to('cuda')
28 |
29 | out = model(xs)
30 | model.zero_grad()
31 |
32 | loss = torch.nn.functional.mse_loss(out, ys)
33 | loss.backward()
34 |
35 | opt.step()
36 |
37 | scheduler.step()
38 | if e % 50 == 0:
39 | print(e, loss.item(), opt.param_groups[0]['lr'])
40 |
41 | torch.set_grad_enabled(False)
42 | model.cpu()
43 | model.eval()
44 |
45 | print(y[:10].squeeze())
46 | print(model(x[:10]).squeeze())
47 |
48 | from relevance import forward_relevance, backward_relevance
49 |
50 | ctx = {}
51 |
52 | y = forward_relevance(model, x[:10], ctx=ctx)
53 | print(y.squeeze())
54 | print()
55 | print(*ctx.items(), sep='\n\n')
56 | print()
57 |
58 | rel_out = y
59 | print(rel_out)
60 | print()
61 | rel = backward_relevance(model, rel_out, ctx=ctx)
62 | print(rel)
63 | print(rel / rel_out)
64 | print()
65 | print(rel.abs().mean(dim=0))
66 |
67 | print()
68 | rel_out = torch.ones_like(y)
69 | print(rel_out)
70 | print()
71 | rel = backward_relevance(model, rel_out, ctx=ctx)
72 | print(rel)
73 | print()
74 | print(rel.mean(dim=0))
75 |
76 | print(torch.allclose(rel_out.sum(dim=1), rel.sum(dim=1)))
77 |
--------------------------------------------------------------------------------
/src/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Union
4 |
5 | import pyaml
6 | import torch
7 |
8 |
9 | class Saver(object):
10 | def __init__(self, folder: Union[str, Path]):
11 | self.base_folder = Path(folder).expanduser().resolve()
12 | self.checkpoint_folder = self.base_folder / 'checkpoints'
13 | self.checkpoint_folder.mkdir(parents=True, exist_ok=True)
14 |
15 | def save_model(self, model, suffix=None, is_best=False):
16 | if isinstance(model, torch.nn.DataParallel):
17 | model = model.module
18 | name = 'model.pt' if suffix is None else f'model.{suffix}.pt'
19 | model_path = self.checkpoint_folder / name
20 | torch.save(model.state_dict(), model_path)
21 |
22 | latest_path = self.base_folder / 'model.latest.pt'
23 | if latest_path.exists():
24 | os.unlink(latest_path)
25 | os.link(model_path, latest_path)
26 |
27 | if is_best:
28 | best_path = self.base_folder / 'model.best.pt'
29 | if best_path.exists():
30 |                 os.unlink(best_path)
31 | os.link(model_path, best_path)
32 |
33 | return model_path.as_posix()
34 |
35 | def save_optimizer(self, optimizer, suffix=None):
36 | name = 'optimizer.pt' if suffix is None else f'optimizer.{suffix}.pt'
37 | optimizer_path = self.checkpoint_folder / name
38 | torch.save(optimizer.state_dict(), optimizer_path)
39 |
40 | latest_path = self.base_folder / 'optimizer.latest.pt'
41 | if latest_path.exists():
42 | os.unlink(latest_path)
43 | os.link(optimizer_path, latest_path)
44 |
45 | return optimizer_path.as_posix()
46 |
47 | def save_experiment(self, experiment, suffix=None):
48 | name = 'experiment.yaml' if suffix is None else f'experiment.{suffix}.yaml'
49 | experiment_path = self.checkpoint_folder / name
50 | with open(experiment_path, 'w') as f:
51 | pyaml.dump(experiment, f, safe=True, sort_dicts=False)
52 |
53 | latest_path = self.base_folder / 'experiment.latest.yaml'
54 | if latest_path.exists():
55 | os.unlink(latest_path)
56 | os.link(experiment_path, latest_path)
57 |
58 | return experiment_path.as_posix()
59 |
60 | def save(self, model, experiment, optimizer, suffix=None, is_best=False):
61 | experiment.model.state_dict = self.save_model(model, suffix=suffix, is_best=is_best)
62 | experiment.optimizer.state_dict = self.save_optimizer(optimizer, suffix=suffix)
63 | return {
64 | 'model': experiment.model.state_dict,
65 | 'optimizer': experiment.optimizer.state_dict,
66 | 'experiment': self.save_experiment(experiment, suffix=suffix)
67 | }
68 |
--------------------------------------------------------------------------------
/src/solubility/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baldassarreFe/graph-network-explainability/39f8dd1fa245545f54a9f29582d0cf85c681edec/src/solubility/__init__.py
--------------------------------------------------------------------------------
/src/solubility/dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Tuple
3 |
4 | import torch
5 | import pandas as pd
6 | from pandas.api.types import CategoricalDtype
7 |
8 | from rdkit import Chem
9 |
10 | import torchgraphs as tg
11 |
12 | symbols = CategoricalDtype([
13 | 'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
14 | 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
15 | 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', # H?
16 | 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
17 | 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown'
18 | ], ordered=True)
19 |
20 | bonds = CategoricalDtype([
21 | 'SINGLE',
22 | 'DOUBLE',
23 | 'TRIPLE',
24 | 'AROMATIC'
25 | ], ordered=True)
26 |
27 |
28 | def smiles_to_graph(smiles: str) -> tg.Graph:
29 | molecule = Chem.MolFromSmiles(smiles)
30 |
31 | atoms_df = []
32 | for i in range(molecule.GetNumAtoms()):
33 | atom = molecule.GetAtomWithIdx(i)
34 | atoms_df.append({
35 | 'index': i,
36 | 'symbol': atom.GetSymbol(),
37 | 'degree': atom.GetDegree(),
38 | 'hydrogens': atom.GetTotalNumHs(),
39 | 'impl_valence': atom.GetImplicitValence(),
40 | })
41 | atoms_df = pd.DataFrame.from_records(atoms_df, index='index',
42 | columns=['index', 'symbol', 'degree', 'hydrogens', 'impl_valence'])
43 | atoms_df.symbol = atoms_df.symbol.astype(symbols)
44 |
45 | node_features = torch.tensor(pd.get_dummies(atoms_df, columns=['symbol']).values, dtype=torch.float)
46 |
47 | bonds_df = []
48 | for bond in molecule.GetBonds():
49 | bonds_df.append({
50 | 'sender': bond.GetBeginAtomIdx(),
51 | 'receiver': bond.GetEndAtomIdx(),
52 | 'type': bond.GetBondType().name,
53 | 'conj': bond.GetIsConjugated(),
54 | 'ring': bond.IsInRing()
55 | })
56 | bonds_df.append({
57 | 'sender': bond.GetEndAtomIdx(),
58 | 'receiver': bond.GetBeginAtomIdx(),
59 | 'type': bond.GetBondType().name,
60 | 'conj': bond.GetIsConjugated(),
61 | 'ring': bond.IsInRing()
62 | })
63 | bonds_df = pd.DataFrame.from_records(bonds_df, columns=['sender', 'receiver', 'type', 'conj', 'ring'])\
64 | .set_index(['sender', 'receiver'])
65 | bonds_df.conj = bonds_df.conj * 2. - 1
66 | bonds_df.ring = bonds_df.ring * 2. - 1
67 | bonds_df.type = bonds_df.type.astype(bonds)
68 |
69 | edge_features = torch.tensor(pd.get_dummies(bonds_df, columns=['type']).values.astype(float), dtype=torch.float)
70 | senders = torch.tensor(bonds_df.index.get_level_values('sender'), dtype=torch.long)
71 | receivers = torch.tensor(bonds_df.index.get_level_values('receiver'), dtype=torch.long)
72 |
73 | return tg.Graph(
74 | num_nodes=molecule.GetNumAtoms(),
75 | num_edges=molecule.GetNumBonds() * 2,
76 | node_features=node_features,
77 | edge_features=edge_features,
78 | senders=senders,
79 | receivers=receivers
80 | )
81 |
82 |
83 | class SolubilityDataset(torch.utils.data.Dataset):
84 | def __init__(self, path):
85 | self.df = pd.read_csv(path)
86 | # self.df['molecules'] = self.df.smiles.apply(smiles_to_graph)
87 |
88 | def __len__(self):
89 | return len(self.df)
90 |
91 | def __getitem__(self, item) -> Tuple[tg.Graph, float]:
92 | mol = smiles_to_graph(self.df['smiles'].iloc[item])
93 | target = self.df['measured log solubility in mols per litre'].iloc[item]
94 | return mol, torch.tensor(target)
95 |
96 |
97 | def describe(cfg):
98 | pd.options.display.precision = 2
99 | pd.options.display.max_columns = 999
100 | pd.options.display.expand_frame_repr = False
101 | target = Path(cfg.target).expanduser().resolve()
102 | if target.is_dir():
103 |         paths = target.glob('*.csv')
104 | else:
105 | paths = [target]
106 | for p in paths:
107 | print(f"Loading dataset from: {p}")
108 | dataset = SolubilityDataset(p)
109 | print(f"{p.with_suffix('').name.capitalize()} contains:\n"
110 |               f"{dataset.df.drop(columns=['molecules'], errors='ignore').describe().transpose()}")
111 |
112 |
113 | def main():
114 | from argparse import ArgumentParser
115 | from config import Config
116 |
117 | parser = ArgumentParser()
118 | subparsers = parser.add_subparsers()
119 |
120 | sp_print = subparsers.add_parser('print', help='Print parsed configuration')
121 | sp_print.add_argument('config', nargs='*')
122 | sp_print.set_defaults(command=lambda c: print(c.toYAML()))
123 |
124 | sp_describe = subparsers.add_parser('describe', help='Describe existing datasets')
125 | sp_describe.add_argument('config', nargs='*')
126 | sp_describe.set_defaults(command=describe)
127 |
128 | args = parser.parse_args()
129 | cfg = Config.build(*args.config)
130 | args.command(cfg)
131 |
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
/src/solubility/layout.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import tensorflow as tf
5 |
6 | from tensorboard import summary as summary_lib
7 | from tensorboard.plugins.custom_scalar import layout_pb2
8 |
9 | layout_summary = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=[
10 | layout_pb2.Category(
11 | title='losses',
12 | chart=[
13 | # Chart 'losses' (include all losses, exclude upper and lower bounds)
14 | layout_pb2.Chart(
15 | title='losses',
16 | multiline=layout_pb2.MultilineChartContent(
17 | tag=[
18 | r'loss(?!.*bound.*)'
19 | ]
20 | )
21 | ),
22 | ])
23 | ]))
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('folder', help='The log folder to place the layout in')
27 | args = parser.parse_args()
28 |
29 | folder = (Path(args.folder) / 'layout').expanduser().resolve()
30 | with tf.summary.FileWriter(folder) as writer:
31 | writer.add_summary(layout_summary)
32 |
33 | print('Layout saved to', folder)
34 |
--------------------------------------------------------------------------------
/src/solubility/networks.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 |
6 | import torchgraphs as tg
7 |
8 |
9 | def build_network(num_hidden):
10 | return SolubilityGN(num_hidden)
11 |
12 |
13 | class SolubilityGN(nn.Module):
14 | def __init__(self, num_layers, hidden_bias, hidden_node, dropout, aggregation):
15 | super().__init__()
16 |
17 | hidden_edge = hidden_node // 4
18 | hidden_global = hidden_node // 8
19 |
20 | self.encoder = nn.Sequential(OrderedDict({
21 | 'edge': tg.EdgeLinear(hidden_edge, edge_features=6),
22 | 'edge_relu': tg.EdgeReLU(),
23 | 'node': tg.NodeLinear(hidden_node, node_features=47),
24 | 'node_relu': tg.NodeReLU(),
25 | 'global': tg.GlobalLinear(hidden_global, node_features=hidden_node,
26 | edge_features=hidden_edge, aggregation=aggregation),
27 | 'global_relu': tg.GlobalReLU(),
28 | }))
29 | if dropout:
30 | self.hidden = nn.Sequential(OrderedDict({
31 | f'hidden_{i}': nn.Sequential(OrderedDict({
32 | 'edge': tg.EdgeLinear(hidden_edge, edge_features=hidden_edge,
33 | sender_features=hidden_node, bias=hidden_bias),
34 | 'edge_relu': tg.EdgeReLU(),
35 | 'edge_dropout': tg.EdgeDroput(),
36 | 'node': tg.NodeLinear(hidden_node, node_features=hidden_node, incoming_features=hidden_edge,
37 | aggregation=aggregation, bias=hidden_bias),
38 | 'node_relu': tg.NodeReLU(),
39 | 'node_dropout': tg.EdgeDroput(),
40 | 'global': tg.GlobalLinear(hidden_global, node_features=hidden_node, edge_features=hidden_edge,
41 | global_features=hidden_global, aggregation=aggregation, bias=hidden_bias),
42 | 'global_relu': tg.GlobalReLU(),
43 | 'global_dropout': tg.EdgeDroput(),
44 | }))
45 | for i in range(num_layers)
46 | }))
47 | else:
48 | self.hidden = nn.Sequential(OrderedDict({
49 | f'hidden_{i}': nn.Sequential(OrderedDict({
50 | 'edge': tg.EdgeLinear(hidden_edge, edge_features=hidden_edge,
51 | sender_features=hidden_node, bias=hidden_bias),
52 | 'edge_relu': tg.EdgeReLU(),
53 | 'node': tg.NodeLinear(hidden_node, node_features=hidden_node, incoming_features=hidden_edge,
54 | aggregation=aggregation, bias=hidden_bias),
55 | 'node_relu': tg.NodeReLU(),
56 | 'global': tg.GlobalLinear(hidden_global, node_features=hidden_node, edge_features=hidden_edge,
57 | global_features=hidden_global, aggregation=aggregation, bias=hidden_bias),
58 | 'global_relu': tg.GlobalReLU(),
59 | }))
60 | for i in range(num_layers)
61 | }))
62 | self.readout_globals = tg.GlobalLinear(1, global_features=hidden_global, bias=True)
63 |
64 | def forward(self, graphs):
65 | graphs = self.encoder(graphs)
66 | graphs = self.hidden(graphs)
67 | globals = self.readout_globals(graphs).global_features
68 |
69 | return graphs.evolve(
70 | num_nodes=0,
71 | node_features=None,
72 | num_nodes_by_graph=None,
73 | num_edges=0,
74 | num_edges_by_graph=None,
75 | edge_features=None,
76 | global_features=globals,
77 | senders=None,
78 | receivers=None
79 | )
80 |
81 |
82 | def describe(cfg):
83 | from pathlib import Path
84 | from utils import import_
85 | klass = import_(cfg.model.klass)
86 | model = klass(*cfg.model.args, **cfg.model.kwargs)
87 | if 'state_dict' in cfg:
88 | model.load_state_dict(torch.load(Path(cfg.state_dict).expanduser().resolve()))
89 | print(model)
90 | print(f'Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
91 |
92 | for name, parameter in model.named_parameters():
93 | print(f'{name} {tuple(parameter.shape)}:')
94 | if 'state_dict' in cfg:
95 |             print(parameter.data.numpy().round())
96 | print()
97 |
98 |
99 | def main():
100 | from argparse import ArgumentParser
101 | from config import Config
102 |
103 | parser = ArgumentParser()
104 | subparsers = parser.add_subparsers()
105 |
106 | sp_print = subparsers.add_parser('print', help='Print parsed configuration')
107 | sp_print.add_argument('config', nargs='*')
108 | sp_print.set_defaults(command=lambda c: print(c.toYAML()))
109 |
110 | sp_describe = subparsers.add_parser('describe', help='Describe a model')
111 | sp_describe.add_argument('config', nargs='*')
112 | sp_describe.set_defaults(command=describe)
113 |
114 | args = parser.parse_args()
115 | cfg = Config.build(*args.config)
116 | args.command(cfg)
117 |
118 |
119 | if __name__ == '__main__':
120 | main()
121 |
--------------------------------------------------------------------------------
/src/solubility/notes.md:
--------------------------------------------------------------------------------
1 | # Solubility
2 |
3 | ## Task
4 | We want to predict the water solubility (in `log mol/L`) of organic molecules from their molecular structure.
5 |
6 | ## Data
7 | The molecules in [this dataset](../../data/delaney-processed.csv) are loaded and parsed using RDKit.
8 | Train/test split is 70/30.
9 |
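As a reference for the parsing step implemented in `src/solubility/dataset.py` (`smiles_to_graph`), here is a minimal sketch of how a SMILES string from the CSV becomes a graph: atoms map to nodes and each bond contributes two directed edges. Paths and column names are the ones used in this repository; run it from the repository root.

```python
import pandas as pd
from rdkit import Chem

# Column names as they appear in data/delaney-processed.csv
df = pd.read_csv('data/delaney-processed.csv')
smiles = df['smiles'].iloc[0]
target = df['measured log solubility in mols per litre'].iloc[0]

mol = Chem.MolFromSmiles(smiles)
print(smiles, 'log solubility:', target)
print('atoms (graph nodes):', mol.GetNumAtoms())
print('bonds (directed graph edges):', mol.GetNumBonds() * 2)
```
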
10 | ## Losses
11 |
12 | Two losses are used for training: a global-level loss and a regularization term.
13 | The two terms are added together in a weighted sum and constitute the final training objective.
14 |
15 | ### Global-level regression
16 | The network outputs a single global-level prediction corresponding to the log solubility of the molecule.
17 | The loss on this prediction is computed as Mean Squared Error (MSE).
18 |
19 | No weighting is used to account for more/less common values.
20 |
21 | ### L1 regularization
22 | The weights of the network are regularized with L1 regularization.
23 |
24 | ## Workflow
25 |
26 | 1. Create base folder
27 | ```bash
28 | SOLUBILITY=~/experiments/solubility/
29 | mkdir -p "$SOLUBILITY/"{runs,data}
30 | ```
31 | 2. Launch one experiment (from the root of the repo):
32 | ```bash
33 | python -m solubility.train --experiment config/solubility/train.yaml
34 | ```
35 | Or make a grid search over the hyperparameters (from the root of the repo):
36 | ```bash
37 | conda activate tg-experiments
38 | function train {
39 | python -m solubility.train \
40 | --experiment config/solubility/train.yaml \
41 | "tags=[layers${3},lr${1},bias${4},size${5},wd${2},dr${7},e${6},${8}]" \
42 |     --model "kwargs.num_layers=${3}" "kwargs.hidden_bias=${4}" "kwargs.hidden_node=${5}" "kwargs.dropout=${7}" "kwargs.aggregation=${8}" \
43 | --optimizer "kwargs.lr=${1}" \
44 | --session "losses.l1=${2}" "epochs=${6}"
45 | }
46 | export -f train # use bash otherwise `export -f` won't work
47 | parallel --verbose --max-procs 6 --load 200% --delay 1 --noswap \
48 | 'train {1} {2} {3} {4} {5} {6} {7} {8}' \
49 | `# Learning rate` ::: .01 .001 \
50 | `# L1 loss` ::: .0001 .001 .01 \
51 | `# Hidden layers` ::: 3 4 5 10 \
52 | `# Hidden bias` ::: yes no \
53 | `# Hidden node` ::: 32 64 128 256 512 \
54 | `# Epochs` ::: 50 75 100 \
55 | `# Dropout` ::: yes no \
56 | `# Aggregation` ::: mean sum
57 | ```
58 |
59 | 3. Query logs and visualize
60 | - Tensorboard: `tensorboard --logdir "$SOLUBILITY/runs"`
61 | - Find best model
62 | ```bash
63 | for f in */experiment.latest.yaml; do
64 | echo -e $(grep loss_sol_val < $f) $(dirname $f)
65 | done | sort -k 3 -g -r | cut -f2,3 | tail -n 5
66 | ```
67 |
--------------------------------------------------------------------------------
/src/solubility/predict.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import yaml
3 | import pyaml
4 | import multiprocessing
5 | from pathlib import Path
6 | from argparse import ArgumentParser
7 |
8 | import numpy as np
9 | from munch import AutoMunch, Munch
10 |
11 | import torch
12 | import torch.utils.data
13 | import torch.nn.functional as F
14 | import torchgraphs as tg
15 |
16 | from utils import parse_dotted, update_rec, import_
17 | from .dataset import SolubilityDataset
18 |
19 | parser = ArgumentParser()
20 | parser.add_argument('--model', nargs='+', required=True)
21 | parser.add_argument('--data', nargs='+', required=True, default=[])
22 | parser.add_argument('--options', nargs='+', required=False, default=[])
23 | parser.add_argument('--output', type=str, required=True)
24 |
25 | args = parser.parse_args()
26 |
27 |
28 | # region Collecting phase
29 |
30 | # Defaults
31 | model = Munch(fn=None, args=[], kwargs={}, state_dict=None)
32 | data = []
33 | options = AutoMunch()
34 | options.cpus = multiprocessing.cpu_count() - 1
35 | options.device = 'cuda' if torch.cuda.is_available() else 'cpu'
36 | options.output = args.output
37 |
38 | # Model from --model args
39 | for string in args.model:
40 | if '=' in string:
41 | update = parse_dotted(string)
42 | else:
43 | with open(string, 'r') as f:
44 | update = yaml.safe_load(f)
45 | # If the yaml file contains an entry with key `model` use that one instead
46 | if 'model' in update.keys():
47 | update = update['model']
48 | update_rec(model, update)
49 |
50 | # Data from --data args
51 | for path in args.data:
52 | path = Path(path).expanduser().resolve()
53 | if path.is_dir():
54 | data.extend(path.glob('*.pt'))
55 | elif path.is_file() and path.suffix == '.pt':
56 | data.append(path)
57 | else:
58 | raise ValueError(f'Invalid data: {path}')
59 |
60 | # Options from --options args
61 | for string in args.options:
62 | if '=' in string:
63 | update = parse_dotted(string)
64 | else:
65 | with open(string, 'r') as f:
66 | update = yaml.safe_load(f)
67 | update_rec(options, update)
68 |
69 | # Resolving paths
70 | model.state_dict = Path(model.state_dict).expanduser().resolve()
71 | options.output = Path(options.output).expanduser().resolve()
72 |
73 | # Checks (some missing, others redundant)
74 | if model.fn is None:
75 | raise ValueError('Model constructor function not defined')
76 | if model.state_dict is None:
77 | raise ValueError(f'Model state dict is required to predict')
78 | if len(data) == 0:
79 | raise ValueError(f'No data to predict')
80 | if options.cpus < 0:
81 | raise ValueError(f'Invalid number of cpus: {options.cpus}')
82 | if options.output.exists() and not options.output.is_dir():
83 | raise ValueError(f'Invalid output path {options.output}')
84 |
85 |
86 | pyaml.pprint({'model': model, 'options': options, 'data': data}, sort_dicts=False, width=200)
87 | # endregion
88 |
89 | # region Building phase
90 | # Model
91 | net: torch.nn.Module = import_(model.fn)(*model.args, **model.kwargs)
92 | net.load_state_dict(torch.load(model.state_dict))
93 | net.to(options.device)
94 |
95 | # Output folder
96 | options.output.mkdir(parents=True, exist_ok=True)
97 | # endregion
98 |
99 | # region Prediction
100 | # Dataset and dataloader
101 | dataset_predict: SolubilityDataset = torch.load(data[0])
102 | dataloader_predict = torch.utils.data.DataLoader(
103 | dataset_predict,
104 | shuffle=False,
105 | num_workers=min(options.cpus, 1) if 'cuda' in options.device else options.cpus,
106 | pin_memory='cuda' in options.device,
107 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)),
108 | batch_size=options.batch_size,
109 | collate_fn=tg.GraphBatch.collate,
110 | )
111 |
112 | # region Predict
113 | net.eval()
114 | torch.set_grad_enabled(False)
115 | i = 0
116 | with tqdm.tqdm(desc='Predict', total=len(dataloader_predict.dataset), unit='g') as bar:
117 | for graphs, *_ in dataloader_predict:
118 | graphs = graphs.to(options.device)
119 |
120 | results = net(graphs)
121 |         # Solubility predictions are global features; no per-node activation is applied
122 |
123 | for result in results:
124 | torch.save(result.cpu(), options.output / f'output_{i:06d}.pt')
125 | i += 1
126 |
127 | bar.update(graphs.num_graphs)
128 | # endregion
129 |
--------------------------------------------------------------------------------
/src/solubility/train.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import yaml
3 | import pyaml
4 | import random
5 | import textwrap
6 | import multiprocessing
7 | from pathlib import Path
8 | from datetime import datetime
9 | from argparse import ArgumentParser
10 |
11 | import numpy as np
12 | import pandas as pd
13 | from munch import AutoMunch
14 |
15 | import torch
16 | import torch.utils.data
17 | import torch.nn.functional as F
18 | import torchgraphs as tg
19 | from tensorboardX import SummaryWriter
20 |
21 | from saver import Saver
22 | from utils import git_info, cuda_info, parse_dotted, update_rec, set_seeds, import_, sort_dict, RunningWeightedAverage
23 | from .dataset import SolubilityDataset
24 |
25 | parser = ArgumentParser()
26 | parser.add_argument('--experiment', nargs='+', required=True)
27 | parser.add_argument('--model', nargs='+', required=False, default=[])
28 | parser.add_argument('--optimizer', nargs='+', required=False, default=[])
29 | parser.add_argument('--session', nargs='+', required=False, default=[])
30 |
31 | args = parser.parse_args()
32 |
33 |
34 | # region Collecting phase
35 | class Experiment(AutoMunch):
36 | @property
37 | def session(self):
38 | return self.sessions[-1]
39 |
40 |
41 | experiment = Experiment()
42 |
43 | # Experiment defaults
44 | experiment.name = 'experiment'
45 | experiment.tags = []
46 | experiment.samples = 0
47 | experiment.model = {'fn': None, 'args': [], 'kwargs': {}}
48 | experiment.optimizer = {'fn': None, 'args': [], 'kwargs': {}}
49 | experiment.sessions = []
50 |
51 | # Session defaults
52 | session = AutoMunch()
53 | session.losses = {'solubility': 0, 'l1': 0}
54 | session.seed = random.randint(0, 99)
55 | session.cpus = multiprocessing.cpu_count() - 1
56 | session.device = 'cuda' if torch.cuda.is_available() else 'cpu'
57 | session.log = {'when': []}
58 | session.checkpoint = {'when': []}
59 |
60 | # Experiment configuration
61 | for string in args.experiment:
62 | if '=' in string:
63 | update = parse_dotted(string)
64 | else:
65 | with open(string, 'r') as f:
66 | update = yaml.safe_load(f)
67 | # If the current session is defined inside the experiment update the session instead
68 | if 'session' in update:
69 | update_rec(session, update.pop('session'))
70 | update_rec(experiment, update)
71 |
72 | # Model from --model args
73 | for string in args.model:
74 | if '=' in string:
75 | update = parse_dotted(string)
76 | else:
77 | with open(string, 'r') as f:
78 | update = yaml.safe_load(f)
79 | # If the yaml document contains a single entry with key `model` use that one instead
80 | if update.keys() == {'model'}:
81 | update = update['model']
82 | update_rec(experiment.model, update)
83 | del update
84 |
85 | # Optimizer from --optimizer args
86 | for string in args.optimizer:
87 | if '=' in string:
88 | update = parse_dotted(string)
89 | else:
90 | with open(string, 'r') as f:
91 | update = yaml.safe_load(f)
92 | # If the yaml document contains a single entry with key `optimizer`, use that entry instead
93 | if update.keys() == {'optimizer'}:
94 | update = update['optimizer']
95 | update_rec(experiment.optimizer, update)
96 | del update
97 |
98 | # Session from --session args
99 | for string in args.session:
100 | if '=' in string:
101 | update = parse_dotted(string)
102 | else:
103 | with open(string, 'r') as f:
104 | update = yaml.safe_load(f)
105 | # If the yaml document contains a single entry with key `session`, use that entry instead
106 | if update.keys() == {'session'}:
107 | update = update['session']
108 | update_rec(session, update)
109 | del update
110 |
111 | # Checks (some missing, others redundant)
112 | if experiment.name is None or len(experiment.name) == 0:
113 | raise ValueError(f'Experiment name is empty: {experiment.name}')
114 | if experiment.tags is None:
115 | raise ValueError('Experiment tags is None')
116 | if experiment.model.fn is None:
117 | raise ValueError('Model constructor function not defined')
118 | if experiment.optimizer.fn is None:
119 | raise ValueError('Optimizer constructor function not defined')
120 | if session.cpus < 0:
121 | raise ValueError(f'Invalid number of cpus: {session.cpus}')
122 | if any(l < 0 for l in session.losses.values()) or all(l == 0 for l in session.losses.values()):
123 | raise ValueError(f'Invalid losses: {session.losses}')
124 | if len(experiment.sessions) > 0 and ('state_dict' not in experiment.model or 'state_dict' not in experiment.optimizer):
125 | raise ValueError(f'Model and optimizer state dicts are required to restore training')
126 |
127 | # Experiment computed fields
128 | experiment.epoch = sum((s.epochs for s in experiment.sessions), 0)
129 |
130 | # Session computed fields
131 | session.status = 'NEW'
132 | session.datetime_started = None
133 | session.datetime_completed = None
134 | git = git_info()
135 | if git is not None:
136 | session.git = git
137 | if 'cuda' in session.device:
138 | session.cuda = cuda_info()
139 |
140 | # Resolving paths
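# rand_id: six random uppercase letters that fill the {rand} placeholder in the log/checkpoint folder names.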
141 | rand_id = ''.join(chr(random.randint(ord('A'), ord('Z'))) for _ in range(6))
142 | session.data.path = Path(session.data.path.replace('{name}', experiment.name)).expanduser().resolve().as_posix()
143 | session.log.folder = session.log.folder \
144 | .replace('{name}', experiment.name) \
145 | .replace('{tags}', '_'.join(experiment.tags)) \
146 | .replace('{rand}', rand_id)
147 | if len(session.log.when) > 0:
148 | session.log.folder = Path(session.log.folder).expanduser().resolve().as_posix()
149 | if len(session.checkpoint.when) > 0:
150 | session.checkpoint.folder = session.checkpoint.folder \
151 | .replace('{name}', experiment.name) \
152 | .replace('{tags}', '_'.join(experiment.tags)) \
153 | .replace('{rand}', rand_id)
154 | session.checkpoint.folder = Path(session.checkpoint.folder).expanduser().resolve().as_posix()
155 | if 'state_dict' in experiment.model:
156 | experiment.model.state_dict = Path(experiment.model.state_dict).expanduser().resolve().as_posix()
157 | if 'state_dict' in experiment.optimizer:
158 | experiment.optimizer.state_dict = Path(experiment.optimizer.state_dict).expanduser().resolve().as_posix()
159 |
160 | sort_dict(experiment, ['name', 'tags', 'epoch', 'samples', 'model', 'optimizer', 'sessions'])
161 | sort_dict(session, ['epochs', 'batch_size', 'losses', 'seed', 'cpus', 'device', 'samples', 'status',
162 | 'datetime_started', 'datetime_completed', 'data', 'log', 'checkpoint', 'git', 'gpus'])
163 | experiment.sessions.append(session)
164 | pyaml.pprint(experiment, sort_dicts=False, width=200)
165 | del session
166 | # endregion
167 |
168 | # region Building phase
169 | # Seeds (set them after the random run id is generated)
170 | set_seeds(experiment.session.seed)
171 |
172 | # Model
173 | model: torch.nn.Module = import_(experiment.model.fn)(*experiment.model.args, **experiment.model.kwargs)
174 | if 'state_dict' in experiment.model:
175 | model.load_state_dict(torch.load(experiment.model.state_dict))
176 | model.to(experiment.session.device)
177 |
178 | # Optimizer
179 | optimizer: torch.optim.Optimizer = import_(experiment.optimizer.fn)(
180 | model.parameters(), *experiment.optimizer.args, **experiment.optimizer.kwargs)
181 | if 'state_dict' in experiment.optimizer:
182 | optimizer.load_state_dict(torch.load(experiment.optimizer.state_dict))
183 |
184 | # Logger
185 | if len(experiment.session.log.when) > 0:
186 | logger = SummaryWriter(experiment.session.log.folder)
187 | logger.add_text(
188 | 'Experiment', textwrap.indent(pyaml.dump(experiment, safe=True, sort_dicts=False), ' '), experiment.samples)
189 | else:
190 | logger = None
191 |
192 | # Saver
193 | if len(experiment.session.checkpoint.when) > 0:
194 | saver = Saver(experiment.session.checkpoint.folder)
195 | if experiment.epoch == 0:
196 | saver.save_experiment(experiment, suffix=f'e{experiment.epoch:04d}')
197 | else:
198 | saver = None
199 | # endregion
200 |
201 | # Datasets and dataloaders
202 | dataset = SolubilityDataset(experiment.session.data.path)
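# Shared dataloader settings: at most one worker process when running on CUDA, otherwise one per configured CPU; each worker reseeds NumPy from the torch initial seed.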
203 | dataloader_kwargs = dict(
204 | num_workers=min(experiment.session.cpus, 1) if 'cuda' in experiment.session.device else experiment.session.cpus,
205 | pin_memory='cuda' in experiment.session.device,
206 | worker_init_fn=lambda _: np.random.seed(int(torch.initial_seed()) % (2 ** 32 - 1)),
207 | batch_size=experiment.session.batch_size,
208 | collate_fn=tg.GraphBatch.collate,
209 | )
210 |
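# Deterministic split: the first session.data.train fraction of the dataset is used for training, the remainder for validation.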
211 | dataset_train = torch.utils.data.Subset(
212 | dataset, indices=np.arange(0, int(np.floor(experiment.session.data.train * len(dataset)))))
213 | dataloader_train = torch.utils.data.DataLoader(
214 | dataset_train,
215 | shuffle=True,
216 | **dataloader_kwargs
217 | )
218 |
219 | dataset_val = torch.utils.data.Subset(
220 | dataset, indices=np.arange(int(np.floor(experiment.session.data.train * len(dataset))), len(dataset)))
221 | dataloader_val = torch.utils.data.DataLoader(
222 | dataset_val,
223 | shuffle=False,
224 | **dataloader_kwargs
225 | )
226 | del dataset, dataloader_kwargs
227 |
228 | # region Training
229 | # Train and validation loops
230 | experiment.session.status = 'RUNNING'
231 | experiment.session.datetime_started = datetime.utcnow()
232 |
233 | graphs_df = {k: [] for k in ['LossSolubility', 'Nodes', 'Edges', 'Pred', 'Real']}
234 |
235 | epoch_bar_postfix = {}
236 | epoch_bar = tqdm.trange(1, experiment.session.epochs + 1, desc='Epochs', unit='e', leave=True)
237 | for epoch_idx in epoch_bar:
238 | experiment.epoch += 1
239 |
240 | # region Training loop
241 | model.train()
242 | torch.set_grad_enabled(True)
243 |
244 | train_bar_postfix = {}
245 | loss_sol_avg = RunningWeightedAverage()
246 | loss_l1_avg = RunningWeightedAverage()
247 | loss_total_avg = RunningWeightedAverage()
248 |
249 | train_bar = tqdm.tqdm(desc=f'Train {experiment.epoch}', total=len(dataloader_train.dataset), unit='g')
250 | for graphs, targets in dataloader_train:
251 | graphs = graphs.to(experiment.session.device)
252 | targets = targets.to(experiment.session.device)
253 | results = model(graphs)
254 |
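# Total loss: solubility MSE plus an L1 penalty on all parameters, each weighted by session.losses.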
255 | loss_total = torch.tensor(0., device=experiment.session.device)
256 |
257 | if experiment.session.losses.solubility > 0:
258 | loss_sol = F.mse_loss(
259 | results.global_features.squeeze(), targets, reduction='mean')
260 | loss_total += experiment.session.losses.solubility * loss_sol
261 | loss_sol_avg.add(loss_sol.mean().item(), len(graphs))
262 | train_bar_postfix['Solubility'] = f'{loss_sol.item():.5f}'
263 |
264 | if 'every batch' in experiment.session.log.when:
265 | logger.add_scalar('loss/train/solubility', loss_sol.mean().item(), global_step=experiment.samples)
266 |
267 | if experiment.session.losses.l1 > 0:
268 | loss_l1 = sum([p.abs().sum() for p in model.parameters()])
269 | loss_total += experiment.session.losses.l1 * loss_l1
270 | loss_l1_avg.add(loss_l1.item(), len(graphs))
271 | train_bar_postfix['L1'] = f'{loss_l1.item():.5f}'
272 | if 'every batch' in experiment.session.log.when:
273 | logger.add_scalar('loss/train/l1', loss_l1.item(), global_step=experiment.samples)
274 |
275 | loss_total_avg.add(loss_total.item(), len(graphs))
276 | train_bar_postfix['Total'] = f'{loss_total.item():.5f}'
277 | if 'every batch' in experiment.session.log.when:
278 | logger.add_scalar('loss/train/total', loss_total.item(), global_step=experiment.samples)
279 |
280 | optimizer.zero_grad()
281 | loss_total.backward()
282 | optimizer.step(closure=None)
283 |
284 | experiment.samples += len(graphs)
285 | train_bar.update(len(graphs))
286 | train_bar.set_postfix(train_bar_postfix)
287 | train_bar.close()
288 |
289 | experiment.loss_sol_train = loss_sol_avg.get()
290 | epoch_bar_postfix['Train'] = f'{loss_total_avg.get():.4f}'
291 | epoch_bar.set_postfix(epoch_bar_postfix)
292 |
293 | if 'every epoch' in experiment.session.log.when and 'every batch' not in experiment.session.log.when:
294 | logger.add_scalar('loss/train/total', loss_total_avg.get(), global_step=experiment.samples)
295 | if experiment.session.losses.solubility > 0:
296 | logger.add_scalar('loss/train/solubility', loss_sol_avg.get(), global_step=experiment.samples)
297 | if experiment.session.losses.l1 > 0:
298 | logger.add_scalar('loss/train/l1', loss_l1_avg.get(), global_step=experiment.samples)
299 |
300 | del train_bar, train_bar_postfix, loss_sol_avg, loss_l1_avg, loss_total_avg
301 | # endregion
302 |
303 | # region Validation loop
304 | model.eval()
305 | torch.set_grad_enabled(False)
306 |
307 | val_bar_postfix = {}
308 | loss_sol_avg = RunningWeightedAverage()
309 | loss_total_avg = RunningWeightedAverage()
310 | loss_l1 = sum([p.abs().sum() for p in model.parameters()])
311 | val_bar_postfix['Solubility'] = ''
312 | val_bar_postfix['L1'] = f'{loss_l1.item():.5f}'
313 |
314 | val_bar = tqdm.tqdm(desc=f'Val {experiment.epoch}', total=len(dataloader_val.dataset), unit='g')
315 | for batch_idx, (graphs, targets) in enumerate(dataloader_val):
316 | graphs = graphs.to(experiment.session.device)
317 | targets = targets.to(experiment.session.device)
318 | results = model(graphs)
319 |
320 | loss_total = torch.tensor(0., device=experiment.session.device)
321 |
322 | if experiment.session.losses.solubility > 0:
323 | loss_sol = F.mse_loss(
324 | results.global_features.squeeze(), targets, reduction='mean')
325 | loss_total += experiment.session.losses.solubility * loss_sol
326 | loss_sol_avg.add(loss_sol.item(), len(graphs))
327 | val_bar_postfix['Solubility'] = f'{loss_sol.item():.5f}'
328 |
329 | if experiment.session.losses.l1 > 0:
330 | loss_total += experiment.session.losses.l1 * loss_l1
331 |
332 | val_bar_postfix['Total'] = f'{loss_total.item():.5f}'
333 | loss_total_avg.add(loss_total.item(), len(graphs))
334 |
335 | # region Last epoch
336 | if epoch_idx == experiment.session.epochs:
337 | loss_sol_by_graph = F.mse_loss(
338 | results.global_features.squeeze(), targets, reduction='none')
339 |
340 | graphs_df['LossSolubility'].append(loss_sol_by_graph.cpu())
341 | graphs_df['Nodes'].append(graphs.num_nodes_by_graph.cpu())
342 | graphs_df['Edges'].append(graphs.num_edges_by_graph.cpu())
343 | graphs_df['Pred'].append(results.global_features.squeeze().cpu())
344 | graphs_df['Real'].append(targets.cpu())
345 | # endregion
346 |
347 | val_bar.update(len(graphs))
348 | val_bar.set_postfix(val_bar_postfix)
349 | val_bar.close()
350 |
351 | experiment.loss_sol_val = loss_sol_avg.get()
352 | epoch_bar_postfix['Val'] = f'{loss_total_avg.get():.4f}'
353 | epoch_bar.set_postfix(epoch_bar_postfix)
354 |
355 | if (
356 | 'every batch' in experiment.session.log.when or
357 | 'every epoch' in experiment.session.log.when or
358 | 'last epoch' in experiment.session.log.when and epoch_idx == experiment.session.epochs
359 | ):
360 | logger.add_scalar('loss/val/total', loss_total_avg.get(), global_step=experiment.samples)
361 | if experiment.session.losses.solubility > 0:
362 | logger.add_scalar('loss/val/solubility', loss_sol_avg.get(), global_step=experiment.samples)
363 | if experiment.session.losses.l1 > 0:
364 | logger.add_scalar('loss/val/l1', loss_l1.item(), global_step=experiment.samples)
365 |
366 | del val_bar, val_bar_postfix, loss_sol_avg, loss_l1, loss_total_avg, batch_idx
367 | # endregion
368 |
369 | # Saving
370 | if epoch_idx == experiment.session.epochs:
371 | experiment.session.status = 'DONE'
372 | experiment.session.datetime_completed = datetime.utcnow()
373 | if (
374 | 'every batch' in experiment.session.checkpoint.when or
375 | 'every epoch' in experiment.session.checkpoint.when or
376 | 'last epoch' in experiment.session.checkpoint.when and epoch_idx == experiment.session.epochs
377 | ):
378 | saver.save(model, experiment, optimizer, suffix=f'e{experiment.epoch:04d}')
379 | epoch_bar.close()
380 | print()
381 | del epoch_bar, epoch_bar_postfix, epoch_idx
382 | # endregion
383 |
384 | # region Final report
385 | pd.options.display.precision = 2
386 | pd.options.display.max_columns = 999
387 | pd.options.display.expand_frame_repr = False
388 |
389 | graphs_df = pd.DataFrame({k: np.concatenate(v) for k, v in graphs_df.items()}).rename_axis('GraphId').reset_index()
390 | experiment.loss_sol = graphs_df.LossSolubility.mean()
391 | print('Solubility MSE:', experiment.loss_sol)
392 |
393 | # Split the results into ranges based on the number of nodes and compute the average loss per range
394 | df_losses_by_node_range = graphs_df \
395 | .groupby(graphs_df.Nodes // 10) \
396 | .agg({'Nodes': ['min', 'max'], 'GraphId': 'count', 'LossSolubility': 'mean'}) \
397 | .rename_axis(index='NodeRange') \
398 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index') \
399 | .rename(str.capitalize, axis='columns', level=1)
400 |
401 | # Split the results into ranges based on the number of nodes and keep the 5 worst predictions w.r.t. graph-wise loss
402 | df_worst_solubility_loss_by_node_range = graphs_df \
403 | .groupby(graphs_df.Nodes // 10) \
404 | .apply(lambda df_gr: df_gr.nlargest(5, 'LossSolubility').set_index('GraphId')) \
405 | .rename_axis(index={'Nodes': 'NodeRange'}) \
406 | .rename(lambda node_group_min: f'[{node_group_min * 10}, {node_group_min * 10 + 10})', axis='index', level=0)
407 |
408 | print(f"""
409 | Losses by range:
410 | {df_losses_by_node_range}\n
411 | Worst solubility predictions:
412 | {df_worst_solubility_loss_by_node_range}
413 | """)
414 |
415 | if logger is not None:
416 | logger.add_text(
417 | 'Losses by range',
418 | textwrap.indent(df_losses_by_node_range.to_string(), ' '),
419 | global_step=experiment.samples)
420 | logger.add_text(
421 | 'Worst solubility predictions',
422 | textwrap.indent(df_worst_solubility_loss_by_node_range.to_string(), ' '),
423 | global_step=experiment.samples)
424 | del graphs_df, df_losses_by_node_range, df_worst_solubility_loss_by_node_range
425 | # endregion
426 |
427 | # region Cleanup
428 | if logger is not None:
429 | logger.close()
430 | # endregion
431 |
--------------------------------------------------------------------------------
/src/test_edge_linear.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import torch
4 |
5 | import relevance
6 |
7 | import torchgraphs as tg
8 |
9 | torch.set_grad_enabled(False)
10 | relevance.register_function(torch.nn.Linear, relevance.LinearEpsilon)
11 |
12 | graphs = tg.GraphBatch.from_graphs([tg.Graph(
13 | node_features=torch.rand(10, 9),
14 | edge_features=torch.rand(6, 5),
15 | senders=torch.tensor([0, 0, 1, 2, 4, 5]),
16 | receivers=torch.tensor([1, 2, 2, 4, 3, 3]),
17 | global_features=torch.rand(7)
18 | )])
19 |
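# Sweep bias settings and input feature combinations; when the bias is absent or zeroed, the relevance propagated back to the inputs should sum to the relevance assigned to the edge outputs.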
20 | print('bias ef sf rf gf out input'.replace(' ', '\t'))
21 | print('-' * 45)
22 |
23 | for bias, ef, sf, rf, gf in itertools.product(
24 | [True, False, 'zero'],
25 | [5, None],
26 | [9, None],
27 | [9, None],
28 | [7, None]
29 | ):
30 | if ef is sf is rf is gf is None:
31 | continue
32 | net = tg.EdgeLinear(3, edge_features=ef, sender_features=sf, receiver_features=rf, global_features=gf, bias=bool(bias))
33 | if bias == 'zero':
34 | net.bias.zero_()
35 |
36 | ctx = {}
37 | out = relevance.forward_relevance(net, graphs, ctx=ctx)
38 | torch.testing.assert_allclose(out.edge_features, net(graphs).edge_features)
39 |
40 | rel_out = graphs.evolve(
41 | edge_features=torch.ones_like(out.edge_features) * (out.edge_features != 0).float()
42 | )
43 |
44 | rel_in = relevance.backward_relevance(net, rel_out, ctx=ctx)
45 |
46 | # If bias==0 then relevance is conserved at a graph level
47 | print(
48 | {True: ' x ', False: ' - ', 'zero': ' 0 '}[bias], ef if ef else '-', sf if sf else '-',
49 | rf if rf else '-', gf if gf else '-',
50 | rel_out.edge_features.sum().item(),
51 | ((rel_in.global_features.sum() if rel_in.global_features is not None else 0) +
52 | (rel_in.edge_features.sum() if rel_in.edge_features is not None else 0) +
53 | (rel_in.node_features.sum() if rel_in.node_features is not None else 0)).item(),
54 | sep='\t'
55 | )
56 |
--------------------------------------------------------------------------------
/src/test_global_linear.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import torch
4 |
5 | import relevance
6 |
7 | import torchgraphs as tg
8 |
9 | torch.set_grad_enabled(False)
10 | relevance.register_function(torch.nn.Linear, relevance.LinearEpsilon)
11 |
12 | graphs = tg.GraphBatch.from_graphs([tg.Graph(
13 | node_features=torch.rand(10, 9),
14 | edge_features=torch.rand(6, 5),
15 | senders=torch.tensor([0, 0, 1, 2, 4, 5]),
16 | receivers=torch.tensor([1, 2, 2, 4, 3, 3]),
17 | global_features=torch.rand(7)
18 | )])
19 |
20 | print('bias nf ef gf R_out R_in'.replace(' ', '\t'))
21 | print('-' * 45)
22 |
23 | for bias, nf, ef, gf in itertools.product(
24 | [True, False, 'zero'],
25 | [9, None],
26 | [5, None],
27 | [7, None]
28 | ):
29 | if nf is ef is gf is None:
30 | continue
31 | net = tg.GlobalLinear(3, node_features=nf, edge_features=ef, global_features=gf, bias=bool(bias), aggregation='sum')
32 |
33 | if bias == 'zero':
34 | net.bias.zero_()
35 |
36 | ctx = {}
37 | out = relevance.forward_relevance(net, graphs, ctx=ctx)
38 | torch.testing.assert_allclose(out.global_features, net(graphs).global_features)
39 |
40 | rel_out = graphs.evolve(
41 | global_features=torch.ones_like(out.global_features) * (out.global_features != 0).float()
42 | )
43 |
44 | rel_in = relevance.backward_relevance(net, rel_out, ctx=ctx)
45 |
46 | # If bias==0 then relevance is conserved at a graph level
47 | print(
48 | {True: ' x ', False: ' - ', 'zero': ' 0 '}[bias], nf if nf else '-', ef if ef else '-', gf if gf else '-',
49 | rel_out.global_features.sum().item(),
50 | ((rel_in.global_features.sum() if rel_in.global_features is not None else 0) +
51 | (rel_in.edge_features.sum() if rel_in.edge_features is not None else 0) +
52 | (rel_in.node_features.sum() if rel_in.node_features is not None else 0)).item(),
53 | sep='\t'
54 | )
55 |
--------------------------------------------------------------------------------
/src/test_node_linear.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import torch
4 |
5 | import relevance
6 |
7 | import torchgraphs as tg
8 |
9 | torch.set_grad_enabled(False)
10 | relevance.register_function(torch.nn.Linear, relevance.LinearEpsilon)
11 |
12 | graphs = tg.GraphBatch.from_graphs([tg.Graph(
13 | node_features=torch.rand(10, 9),
14 | edge_features=torch.rand(6, 5),
15 | senders=torch.tensor([0, 0, 1, 2, 4, 5]),
16 | receivers=torch.tensor([1, 2, 2, 4, 3, 3]),
17 | global_features=torch.rand(7)
18 | )])
19 |
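# Same conservation check as test_edge_linear, additionally sweeping the sum/avg/max aggregation modes of NodeLinear.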
20 | print('agg bias nf inf outf gf out input'.replace(' ', '\t'))
21 | print('-' * 45)
22 |
23 | for agg, bias, nf, inf, outf, gf in itertools.product(
24 | ['sum', 'avg', 'max'],
25 | [True, False, 'zero'],
26 | [9, None],
27 | [5, None],
28 | [5, None],
29 | [7, None]
30 | ):
31 | if nf is inf is outf is gf is None:
32 | continue
33 | net = tg.NodeLinear(3, node_features=nf, incoming_features=inf, outgoing_features=outf, global_features=gf, bias=bool(bias), aggregation=agg)
34 | if bias == 'zero':
35 | net.bias.zero_()
36 |
37 | ctx = {}
38 | out = relevance.forward_relevance(net, graphs, ctx=ctx)
39 | torch.testing.assert_allclose(out.node_features, net(graphs).node_features)
40 |
41 | rel_out = graphs.evolve(
42 | node_features=torch.ones_like(out.node_features) * (out.node_features != 0).float()
43 | )
44 |
45 | rel_in = relevance.backward_relevance(net, rel_out, ctx=ctx)
46 |
47 | # If bias==0 then relevance is conserved at a graph level
48 | print(
49 | agg, {True: ' x ', False: ' - ', 'zero': ' 0 '}[bias], nf if nf else '-', inf if inf else '-',
50 | outf if outf else '-', gf if gf else '-',
51 | rel_out.node_features.sum().item(),
52 | ((rel_in.global_features.sum() if rel_in.global_features is not None else 0) +
53 | (rel_in.edge_features.sum() if rel_in.edge_features is not None else 0) +
54 | (rel_in.node_features.sum() if rel_in.node_features is not None else 0)).item(),
55 | sep='\t'
56 | )
57 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import subprocess
3 | import collections
4 | from typing import Mapping, Iterable, MutableMapping
5 |
6 | import yaml
7 | from munch import munchify, AutoMunch
8 |
9 |
10 | def git_info():
11 | try:
12 | import git
13 | try:
14 | result = {}
15 | repo = git.Repo(search_parent_directories=True)
16 | try:
17 | result['url'] = repo.remote(name='origin').url
18 | except ValueError:
19 | result['url'] = 'git:/' + repo.working_dir
20 | result['commit'] = repo.head.commit.hexsha
21 | result['dirty'] = repo.is_dirty()
22 | if repo.is_dirty():
23 | result['diffs'] = [str(diff) for diff in repo.head.commit.diff(other=None, create_patch=True)]
24 | if len(repo.untracked_files) > 0:
25 | result['untracked_files'] = repo.untracked_files
26 | return result
27 | except (git.InvalidGitRepositoryError, ValueError):
28 | pass
29 | except ImportError:
30 | return None
31 |
32 |
33 | def cuda_info():
34 | from xml.etree import ElementTree
35 | try:
36 | nvidia_smi_xml = subprocess.check_output(['nvidia-smi', '-q', '-x']).decode()
37 | except (FileNotFoundError, OSError, subprocess.CalledProcessError):
38 | return None
39 |
40 | driver = ''
41 | gpus = []
42 | for child in ElementTree.fromstring(nvidia_smi_xml):
43 | if child.tag == 'driver_version':
44 | driver = child.text
45 | elif child.tag == 'gpu':
46 | gpus.append({
47 | 'model': child.find('product_name').text,
48 | 'utilization': child.find('utilization').find('gpu_util').text,
49 | 'memory_used': child.find('fb_memory_usage').find('used').text,
50 | 'memory_total': child.find('fb_memory_usage').find('total').text,
51 | })
52 |
53 | return {'driver': driver, 'gpus': gpus}
54 |
55 |
56 | def parse_dotted(string):
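"""Parse space-separated dotted assignments into a nested dict, e.g. 'a.b=1 c=foo' -> {'a': {'b': 1}, 'c': 'foo'}; values go through yaml.safe_load."""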
57 | result_dict = {}
58 | for kv_pair in string.split(' '):
59 | sub_dict = result_dict
60 | name_dotted, value = kv_pair.split('=')
61 | name_head, *name_rest = name_dotted.split('.')
62 | while len(name_rest) > 0:
63 | sub_dict = sub_dict.setdefault(name_head, {})
64 | name_head, *name_rest = name_rest
65 | sub_dict[name_head] = yaml.safe_load(value)
66 | return result_dict
67 |
68 |
69 | def update_rec(target, source):
70 | for k in source.keys():
71 | if k in target and isinstance(target[k], Mapping) and isinstance(source[k], Mapping):
72 | update_rec(target[k], source[k])
73 | else:
74 | # AutoMunch should do its job, but sometimes it doesn't
75 | target[k] = munchify(source[k], AutoMunch)
76 |
77 |
78 | def import_(fullname):
79 | import importlib
80 | package, name = fullname.rsplit('.', maxsplit=1)
81 | package = importlib.import_module(package)
82 | return getattr(package, name)
83 |
84 |
85 | def set_seeds(seed):
86 | import random
87 | import torch
88 | import numpy as np
89 | random.seed(seed)
90 | np.random.seed(seed)
91 | torch.random.manual_seed(seed)
92 |
93 |
94 | def sort_dict(mapping: MutableMapping, order: Iterable):
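"""Reorder `mapping` in place so that keys listed in `order` come first (in that order), followed by any remaining keys."""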
95 | for key in itertools.chain(filter(mapping.__contains__, order), set(mapping) - set(order)):
96 | mapping[key] = mapping.pop(key)
97 | return mapping
98 |
99 |
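# Weighted running mean; the training loops use it to average per-batch losses weighted by the number of graphs in the batch.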
100 | class RunningWeightedAverage(object):
101 | def __init__(self):
102 | self.total_weight = 0
103 | self.total_weighted_value = 0
104 |
105 | def add(self, value, weight):
106 | if weight <= 0:
107 | raise ValueError()
108 | self.total_weighted_value += value * weight
109 | self.total_weight += weight
110 |
111 | def get(self):
112 | if self.total_weight == 0:
113 | return 0
114 | return self.total_weighted_value / self.total_weight
115 |
116 | def __repr__(self):
117 | return f'{self.get()} ({self.total_weight})'
118 |
--------------------------------------------------------------------------------
/src/yaml_ext.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import yaml
3 |
4 | _init = False
5 |
6 |
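# Registers a representer/constructor pair so torch.device values round-trip through yaml's SafeDumper/SafeLoader under the !torch.device tag.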
7 | class TorchDevice:
8 | _tag = '!torch.device'
9 | _class = torch.device
10 |
11 | @staticmethod
12 | def represent(dumper, device):
13 | return dumper.represent_scalar(TorchDevice._tag, str(device))
14 |
15 | @staticmethod
16 | def construct(loader, node):
17 | return torch.device(loader.construct_scalar(node))
18 |
19 | @classmethod
20 | def register(cls):
21 | yaml.add_representer(cls._class, cls.represent, yaml.SafeDumper)
22 | yaml.add_constructor(cls._tag, cls.construct, yaml.SafeLoader)
23 |
24 |
25 | def init_ext():
26 | global _init
27 | if not _init:
28 | TorchDevice.register()
29 | _init = True
30 |
--------------------------------------------------------------------------------