├── LICENSE
├── Makefile
├── README.md
├── data
├── figures
├── models
├── notebooks
│   ├── discrete_diffusion_graph.ipynb
│   ├── discrete_diffusion_graph_controlled_cliques.ipynb
│   ├── discrete_diffusion_graph_controlled_community.ipynb
│   ├── discrete_diffusion_graph_controlled_molecule_like.ipynb
│   ├── discrete_diffusion_graph_controlled_zinc.ipynb
│   ├── discrete_diffusion_graph_degree.ipynb
│   ├── discrete_diffusion_graph_sanity_checks.ipynb
│   ├── discrete_diffusion_mnist.ipynb
│   ├── figures
│   │   ├── controlled_cliques.ipynb
│   │   ├── controlled_community.ipynb
│   │   ├── controlled_molecule_like.ipynb
│   │   ├── discrete_diffusion_summary.ipynb
│   │   ├── mmd_performance.ipynb
│   │   └── mmd_performance_nocache.ipynb
│   ├── generative_performance_controlled_cliques.ipynb
│   ├── generative_performance_controlled_community.ipynb
│   ├── generative_performance_controlled_molecule_like.ipynb
│   ├── generative_performance_kernel_comparison_controlled_cliques.ipynb
│   ├── generative_performance_kernel_comparison_controlled_community.ipynb
│   └── zinc_sandbox.ipynb
├── references
│   ├── bibtex.bib
│   └── thumbnail.png
├── results
└── src
    ├── analysis
    │   ├── graph_metrics.py
    │   ├── mmd.py
    │   └── orca.cpp
    ├── feature
    │   ├── __init__.py
    │   ├── graph_conversions.py
    │   ├── molecule_dataset.py
    │   └── random_graph_dataset.py
    ├── model
    │   ├── __init__.py
    │   ├── digress_gnn.py
    │   ├── discrete_diffusers.py
    │   ├── generate.py
    │   ├── gnn.py
    │   ├── image_unet.py
    │   ├── train_model.py
    │   └── util.py
    └── plot
        ├── __init__.py
        └── plot.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2022, Genentech, Inc.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | install-dependencies:
2 | # Must be running Python 3.8
3 | conda install -y -c pytorch pytorch=1.11.0
4 | TORCH=$$(python -c "import torch; print(torch.__version__)") && \
5 | CUDA=$$(python -c "import torch; print(torch.version.cuda)") && \
6 | pip install torch-scatter -f https://data.pyg.org/whl/torch-$${TORCH}+$${CUDA}.html && \
7 | pip install torch-sparse -f https://data.pyg.org/whl/torch-$${TORCH}+$${CUDA}.html
8 | # Note: torch-scatter/torch-sparse also pull in NumPy and SciPy
9 |
10 | pip install torch-geometric # This installs tqdm and scikit-learn
11 |
12 | pip install networkx==2.6
13 | conda install -y -c anaconda click pymongo jupyter pandas
14 | conda install -y -c conda-forge matplotlib
15 | pip install sacred tables vdom
16 | conda install -y h5py
17 |
18 | # Note about torch-scatter:
19 | # It is possible that this installation pipeline will not install torch-scatter
20 | # successfully. In particular, torch-scatter might import fine, but it may not
21 | # be usable on GPU. Using certain PyTorch Geometric operations might cause a
22 | # "Not compiled with CUDA support" error to be thrown on GPU. To fix this,
23 | # reinstall torch-scatter as follows:
24 | # 1. Navigate to https://data.pyg.org/whl/, which has all the pip wheel files
25 | # for PyTorch Geometric
26 | # 2. Go to https://data.pyg.org/whl/torch-1.11.0+cu113.html
27 | # Replace the versions of PyTorch and CUDA appropriately, using the same
28 | # versions as $TORCH and $CUDA above; for example, cu113 is for CUDA 11.3
29 | # 3. Download the wheel file `torch_scatter-2.0.9-cp38-cp38-linux_x86_64.whl`,
30 | # which is for Linux; the `cp38` refers to Python 3.8, which is what we're
31 | # using
32 | # 4. Uninstall torch-scatter using `pip uninstall torch-scatter` and manually
33 | # install the wheel file using `pip install torch_scatter*.whl`
34 |
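35 | # For example, with the default versions used above (PyTorch 1.11.0, CUDA
36 | # 11.3, Python 3.8 on Linux), the step-3/step-4 reinstall would be:
37 | #   pip uninstall -y torch-scatter
38 | #   pip install torch_scatter-2.0.9-cp38-cp38-linux_x86_64.whl
39 |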
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GraphGUIDE: interpretable and controllable conditional graph generation with discrete Bernoulli diffusion
2 |
3 |
4 |
5 |
6 |
7 | ### Introduction
8 |
9 | Diffusion models achieve state-of-the-art performance in generating realistic objects, and much of the most useful and impactful work in diffusion has been on _conditional_ generation, where objects with specific desired properties are generated.
10 |
11 | Unfortunately, conditional generation has remained difficult on graphs, particularly due to the discrete nature of graphs and the large number of possible structural properties that could be desired in any given generation task.
12 |
13 | This repository implements GraphGUIDE, a framework for graph generation using diffusion models in which edges are flipped or set at each discrete time step. GraphGUIDE enables full control over the conditional generation of arbitrary structural properties without relying on predefined labels. This framework allows for the interpretable conditional generation of graphs, including the generation of drug-like molecules with desired properties in a way that is informed by experimental evidence.
14 |
15 | See the [corresponding paper](https://arxiv.org/abs/2302.03790) for more information.
16 |
17 | This repository houses all of the code used to generate the results for the paper, including code that processes data, trains models, implements GraphGUIDE, and generates all figures in the paper.
18 |
19 | ### Citing this work
20 |
21 | If you found GraphGUIDE to be helpful for your work, please cite the following:
22 |
23 | Tseng, A.M., Diamant, N., Biancalani, T., Scalia, G. GraphGUIDE: interpretable and controllable conditional graph generation with discrete Bernoulli diffusion. arXiv (2023) [Link](https://arxiv.org/abs/2302.03790)
24 |
25 | [\[BibTeX\]](references/bibtex.bib)
26 |
27 | ### Description of files and directories
28 |
29 | ```
30 | ├── Makefile <- Installation of dependencies
31 | ├── data <- Contains data for training and downstream analysis
32 | │ ├── raw <- Raw data, directly downloaded from the source
33 | │ ├── interim <- Intermediate data mid-processing
34 | │ ├── processed <- Processed data ready for training or analysis
35 | │ └── README.md <- Description of data
36 | ├── models
37 | │ └── trained_models <- Trained models
38 | ├── notebooks <- Jupyter notebooks that explore data, plot results, and analyze results
39 | │ └── figures <- Jupyter notebooks that create figures
40 | ├── results <- Saved results
41 | ├── README.md <- This file
42 | └── src <- Source code
43 | ├── feature <- Code for data loading and featurization
44 | ├── model <- Code for model architectures and training
45 | ├── analysis <- Code for analyzing results
46 | └── plot <- Code for plotting and visualization
47 | ```
48 |
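49 | ### Example: Bernoulli edge-flip forward process
50 |
51 | As a minimal, illustrative sketch of the discrete diffusion described above (for intuition only; the repository's actual diffusers live in `src/model/discrete_diffusers.py`), the edge-flip forward process independently flips each edge indicator with some probability at every time step:
52 |
53 | ```python
54 | import torch
55 |
56 | def edge_flip_forward(edge_vec, flip_prob, num_steps):
57 |     """Noise a binary edge vector by independent Bernoulli flips per step."""
58 |     x = edge_vec.clone()
59 |     for _ in range(num_steps):
60 |         flips = torch.bernoulli(torch.full_like(x, flip_prob))
61 |         x = (x + flips) % 2  # flip each edge bit with probability flip_prob
62 |     return x
63 |
64 | # After many steps, the edges approach an uninformative random prior; a
65 | # trained model then reverses this process to generate graphs
66 | noisy = edge_flip_forward(torch.tensor([1., 0., 1., 1., 0.]), 0.1, 50)
67 | ```
68 |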
--------------------------------------------------------------------------------
/data:
--------------------------------------------------------------------------------
1 | /gstore/data/resbioai/tsenga5/discrete_graph_diffusion/data
--------------------------------------------------------------------------------
/figures:
--------------------------------------------------------------------------------
1 | /gstore/data/resbioai/tsenga5/discrete_graph_diffusion/figures/
--------------------------------------------------------------------------------
/models:
--------------------------------------------------------------------------------
1 | /gstore/data/resbioai/tsenga5/discrete_graph_diffusion/models
--------------------------------------------------------------------------------
/notebooks/figures/mmd_performance.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "5b894b89",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import os\n",
11 | "import json\n",
12 | "import numpy as np"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "id": "6f3afe9a",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "# Define MMD values from other papers\n",
23 | "\n",
24 | "metric_list = [\"Degree\", \"Cluster\", \"Spectrum\", \"Orbit\"]\n",
25 | "\n",
26 | "benchmark_mmds = { # From SPECTRE paper\n",
27 | " \"Community (small)\": {\n",
28 | " \"GraphRNN\": [0.08, 0.12, -1, 0.04],\n",
29 | " \"GRAN\": [0.06, 0.11, -1, 0.01],\n",
30 | " \"MolGAN\": [0.06, 0.13, -1, 0.01],\n",
31 | " \"SPECTRE\": [0.02, 0.21, -1, 0.01]\n",
32 | " \n",
33 | " },\n",
34 | " \"Stochastic block models\": {\n",
35 | " \"GraphRNN\": [0.0055, 0.0584, 0.0065, 0.0785],\n",
36 | " \"GRAN\": [0.0113, 0.0553, 0.0054, 0.0540],\n",
37 | " \"MolGAN\": [0.0235, 0.1161, 0.0117, 0.0712],\n",
38 | " \"SPECTRE\": [0.0079, 0.0528, 0.0643, 0.0074]\n",
39 | " }\n",
40 | "}\n",
41 | "\n",
42 | "benchmark_mmd_ratios = { # From DiGress paper\n",
43 | " \"Community (small)\": {\n",
44 | " \"GraphRNN\": [4.0, 1.7, -1, 4.0],\n",
45 | " \"GRAN\": [3.0, 1.6, -1, 1.0],\n",
46 | " \"SPECTRE\": [0.5, 2.7, -1, 2.0],\n",
47 | " \"DiGress\": [1.0, 0.9, -1, 1.0],\n",
48 | " \n",
49 | " },\n",
50 | " \"Stochastic block models\": {\n",
51 | " \"GraphRNN\": [6.9, 1.7, -1, 3.1],\n",
52 | " \"GRAN\": [14.1, 1.7, -1, 2.1],\n",
53 | " \"SPECTRE\": [1.9, 1.6, -1, 1.6],\n",
54 | " \"DiGress\": [1.6, 1.5, -1, 1.7]\n",
55 | " }\n",
56 | "}\n",
57 | "\n",
58 | "benchmark_baselines = { # From SPECTRE paper\n",
59 | " \"Community (small)\": [0.02, 0.07, 1, 0.01],\n",
60 | " \"Stochastic block models\": [0.0008, 0.0332, 0.0063, 0.0255]\n",
61 | "}"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "id": "05ea7b2d",
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "def get_best_mmds(run_dir):\n",
72 | " # First, get the best run based on last loss\n",
73 | " best_loss, best_metrics = float(\"inf\"), None\n",
74 | " for run_num in os.listdir(run_dir):\n",
75 | " if run_num == \"_sources\":\n",
76 | " continue\n",
77 | " metrics_path = os.path.join(run_dir, run_num, \"metrics.json\")\n",
78 | " with open(metrics_path, \"r\") as f:\n",
79 | " metrics = json.load(f)\n",
80 | " last_loss = metrics[\"train_epoch_loss\"][\"values\"][-1]\n",
81 | " if last_loss < best_loss:\n",
82 | " best_loss, best_metrics = last_loss, metrics\n",
83 | " \n",
84 | " # Now return the MMDs and baselines\n",
85 | " return (\n",
86 | " [\n",
87 | " best_metrics[\"degree_mmd\"][\"values\"][0],\n",
88 | " best_metrics[\"cluster_coef_mmd\"][\"values\"][0],\n",
89 | " best_metrics[\"spectra_mmd\"][\"values\"][0],\n",
90 | " best_metrics[\"orbit_mmd\"][\"values\"][0]\n",
91 | " ],\n",
92 | " [\n",
93 | " best_metrics[\"degree_mmd_baseline\"][\"values\"][0],\n",
94 | " best_metrics[\"cluster_coef_mmd_baseline\"][\"values\"][0],\n",
95 | " best_metrics[\"spectra_mmd_baseline\"][\"values\"][0],\n",
96 | " best_metrics[\"orbit_mmd_baseline\"][\"values\"][0]\n",
97 | " ]\n",
98 | " )"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 4,
104 | "id": "aa392c06",
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "# Import MMD and baseline values from training runs\n",
109 | "\n",
110 | "base_path = \"/gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/\"\n",
111 | "my_mmds_and_baselines = {\n",
112 | " \"Community (small)\": {\n",
113 | " \"Edge-flip\": get_best_mmds(os.path.join(base_path, \"benchmark_community-small_edge-flip\")),\n",
114 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark_community-small_edge-addition\")),\n",
115 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark_community-small_edge-deletion\"))\n",
116 | " },\n",
117 | " \"Stochastic block models\": {\n",
118 | " \"Edge-flip\": get_best_mmds(os.path.join(base_path, \"benchmark_sbm_edge-flip\")),\n",
119 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark_sbm_edge-addition\")),\n",
120 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark_sbm_edge-deletion\"))\n",
121 | " }\n",
122 | "}\n",
123 | "\n",
124 | "my_mmds = {d_key : {k_key : vals[0] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}\n",
125 | "my_baselines = {d_key : {k_key : vals[1] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 5,
131 | "id": "90b38078",
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "name": "stdout",
136 | "output_type": "stream",
137 | "text": [
138 | "Community (small)\n",
139 | "GraphRNN & 2.00 & 1.31 & 2.00\n",
140 | "GRAN & 1.73 & 1.25 & 1.00\n",
141 | "MolGAN & 1.73 & 1.36 & 1.00\n",
142 | "SPECTRE & 1.00 & 1.73 & 1.00\n",
143 | "DiGress & 1.00 & 0.95 & 1.00\n",
144 | "Edge-flip & 0.99 & 0.58 & 2.55\n",
145 | "Edge-one & 1.21 & 0.62 & 1.83\n",
146 | "Edge-zero & 1.87 & 1.02 & 4.69\n",
147 | "=========================\n",
148 | "Stochastic block models\n",
149 | "GraphRNN & 2.62 & 1.33 & 1.75\n",
150 | "GRAN & 3.76 & 1.29 & 1.46\n",
151 | "MolGAN & 5.42 & 1.87 & 1.67\n",
152 | "SPECTRE & 3.14 & 1.26 & 0.54\n",
153 | "DiGress & 1.26 & 1.22 & 1.30\n",
154 | "Edge-flip & 2.73 & 1.23 & 0.94\n",
155 | "Edge-one & 1.00 & 1.21 & 0.81\n",
156 | "Edge-zero & 1.31 & 1.19 & 0.80\n",
157 | "=========================\n"
158 | ]
159 | },
160 | {
161 | "name": "stderr",
162 | "output_type": "stream",
163 | "text": [
164 | "/local/60676799/ipykernel_28989/435977596.py:4: RuntimeWarning: invalid value encountered in sqrt\n",
165 | " vals = np.sqrt(vals)\n"
166 | ]
167 | }
168 | ],
169 | "source": [
170 | "# Print out results\n",
171 | "\n",
172 | "def print_vals(key, vals):\n",
173 | " vals = np.sqrt(vals)\n",
174 | " print(\"%s & %.2f & %.2f & %.2f\" % (key, vals[0], vals[1], vals[3]))\n",
175 | "\n",
176 | "for d_key in my_mmds.keys():\n",
177 | " print(d_key)\n",
178 | " \n",
179 | " for bm_key, bm_vals in benchmark_mmds[d_key].items():\n",
180 | " print_vals(bm_key, np.array(bm_vals) / np.array(benchmark_baselines[d_key]))\n",
181 | " # print(bm_key, np.array(bm_vals) / np.array(my_baselines[d_key][\"Edge-flip\"]))\n",
182 | "# for bm_key, bm_vals in benchmark_mmd_ratios[d_key].items():\n",
183 | "# print(bm_key, np.array(bm_vals))\n",
184 | " print_vals(\"DiGress\", np.array(benchmark_mmd_ratios[d_key][\"DiGress\"]))\n",
185 | " \n",
186 | " for my_key, my_vals in my_mmds[d_key].items():\n",
187 | " print_vals(my_key, np.array(my_vals) / np.array(benchmark_baselines[d_key]))\n",
188 | "# print(my_key, np.array(my_vals) / np.array(my_baselines[d_key][my_key]))\n",
189 | " print(\"=========================\")"
190 | ]
191 | }
192 | ],
193 | "metadata": {
194 | "kernelspec": {
195 | "display_name": "Python 3 (ipykernel)",
196 | "language": "python",
197 | "name": "python3"
198 | },
199 | "language_info": {
200 | "codemirror_mode": {
201 | "name": "ipython",
202 | "version": 3
203 | },
204 | "file_extension": ".py",
205 | "mimetype": "text/x-python",
206 | "name": "python",
207 | "nbconvert_exporter": "python",
208 | "pygments_lexer": "ipython3",
209 | "version": "3.8.13"
210 | }
211 | },
212 | "nbformat": 4,
213 | "nbformat_minor": 5
214 | }
215 |
--------------------------------------------------------------------------------
/notebooks/figures/mmd_performance_nocache.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "5b894b89",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import os\n",
11 | "import json\n",
12 | "import numpy as np"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 8,
18 | "id": "05ea7b2d",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "def get_best_mmds(run_dir):\n",
23 | " # First, get the best run based on last loss\n",
24 | " best_loss, best_metrics = float(\"inf\"), None\n",
25 | " for run_num in os.listdir(run_dir):\n",
26 | " if run_num == \"_sources\":\n",
27 | " continue\n",
28 | " metrics_path = os.path.join(run_dir, run_num, \"metrics.json\")\n",
29 | " with open(metrics_path, \"r\") as f:\n",
30 | " metrics = json.load(f)\n",
31 | " try:\n",
32 | " last_loss = metrics[\"train_epoch_loss\"][\"values\"][-1]\n",
33 | " \n",
34 | " # Try and get a metric out\n",
35 | " _ = metrics[\"orbit_mmd\"][\"values\"]\n",
36 | " except KeyError:\n",
37 | " print(\"Warning: Did not find finished run in %s\" % os.path.join(run_dir, run_num))\n",
38 | " last_loss = float(\"inf\")\n",
39 | " if last_loss < best_loss:\n",
40 | " best_loss, best_metrics = last_loss, metrics\n",
41 | " \n",
42 | " # Now return the MMDs and baselines\n",
43 | " return (\n",
44 | " [\n",
45 | " best_metrics[\"degree_mmd\"][\"values\"][0],\n",
46 | " best_metrics[\"cluster_coef_mmd\"][\"values\"][0],\n",
47 | "# best_metrics[\"spectra_mmd\"][\"values\"][0],\n",
48 | " best_metrics[\"orbit_mmd\"][\"values\"][0]\n",
49 | " ],\n",
50 | " [\n",
51 | " best_metrics[\"degree_mmd_baseline\"][\"values\"][0],\n",
52 | " best_metrics[\"cluster_coef_mmd_baseline\"][\"values\"][0],\n",
53 | "# best_metrics[\"spectra_mmd_baseline\"][\"values\"][0],\n",
54 | " best_metrics[\"orbit_mmd_baseline\"][\"values\"][0]\n",
55 | " ]\n",
56 | " )"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 9,
62 | "id": "aa392c06",
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "name": "stdout",
67 | "output_type": "stream",
68 | "text": [
69 | "Warning: Did not find finished run in /gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/benchmark-nocache_sbm_edge-addition/3\n",
70 | "Warning: Did not find finished run in /gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/benchmark-nocache_sbm_edge-deletion/4\n",
71 | "Warning: Did not find finished run in /gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/benchmark-nocache_sbm_edge-deletion/1\n"
72 | ]
73 | }
74 | ],
75 | "source": [
76 | "# Import MMD and baseline values from training runs\n",
77 | "\n",
78 | "base_path = \"/gstore/home/tsenga5/discrete_graph_diffusion/models/trained_models/\"\n",
79 | "my_mmds_and_baselines = {\n",
80 | " \"Community (small)\": {\n",
81 | " \"Edge-flip\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_community-small_edge-flip\")),\n",
82 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_community-small_edge-addition\")),\n",
83 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_community-small_edge-deletion\"))\n",
84 | " },\n",
85 | " \"Stochastic block models\": {\n",
86 | " \"Edge-flip\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_sbm_edge-flip\")),\n",
87 | " \"Edge-one\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_sbm_edge-addition\")),\n",
88 | " \"Edge-zero\": get_best_mmds(os.path.join(base_path, \"benchmark-nocache_sbm_edge-deletion\"))\n",
89 | " }\n",
90 | "}\n",
91 | "\n",
92 | "my_mmds = {d_key : {k_key : vals[0] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}\n",
93 | "my_baselines = {d_key : {k_key : vals[1] for k_key, vals in d_dict.items()} for d_key, d_dict in my_mmds_and_baselines.items()}"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 13,
99 | "id": "667229b5",
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "benchmark_baselines = { # From SPECTRE paper\n",
104 | " \"Community (small)\": [0.02, 0.07, 0.01],\n",
105 | " \"Stochastic block models\": [0.0008, 0.0332, 0.0255]\n",
106 | "}"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 15,
112 | "id": "90b38078",
113 | "metadata": {},
114 | "outputs": [
115 | {
116 | "name": "stdout",
117 | "output_type": "stream",
118 | "text": [
119 | "Community (small)\n",
120 | "Edge-flip & 0.0072 & 0.0319 & 0.0033 & 0.36 & 0.46 & 0.33 \\\\\n",
121 | "Edge-one & 0.1913 & 0.1332 & 0.2977 & 9.56 & 1.90 & 29.77 \\\\\n",
122 | "Edge-zero & 0.0171 & 0.0274 & 0.0046 & 0.85 & 0.39 & 0.46 \\\\\n",
123 | "=========================\n",
124 | "Stochastic block models\n",
125 | "Edge-flip & 0.0409 & 0.0359 & 0.0147 & 51.17 & 1.08 & 0.58 \\\\\n",
126 | "Edge-one & 0.0313 & 0.0328 & 0.0160 & 39.16 & 0.99 & 0.63 \\\\\n",
127 | "Edge-zero & 0.0013 & 0.0337 & 0.0168 & 1.59 & 1.02 & 0.66 \\\\\n",
128 | "=========================\n"
129 | ]
130 | }
131 | ],
132 | "source": [
133 | "# Print out results\n",
134 | "\n",
135 | "def print_vals(key, vals):\n",
136 | " vals = np.sqrt(vals)\n",
137 | " print(\"%s & %.2f & %.2f & %.2f\" % (key, vals[0], vals[1], vals[3]))\n",
138 | "\n",
139 | "for d_key in my_mmds.keys():\n",
140 | " print(d_key) \n",
141 | " for my_key, my_vals in my_mmds[d_key].items():\n",
142 | " v = np.array(my_vals)\n",
143 | " b = np.array(my_baselines[d_key][my_key])\n",
144 | " b = np.array(benchmark_baselines[d_key])\n",
145 | " s = my_key + \" & \"\n",
146 | " s += \" & \".join([\"%.4f\" % x for x in v]) + \" & \"\n",
147 | "# s += \" & \".join([\"%.4f\" % x for x in b]) + \" & \"\n",
148 | " s += \" & \".join([\"%.2f\" % x for x in (v / b)]) + \" \\\\\\\\\"\n",
149 | " print(s)\n",
150 | " print(\"=========================\")"
151 | ]
152 | }
153 | ],
154 | "metadata": {
155 | "kernelspec": {
156 | "display_name": "Python 3 (ipykernel)",
157 | "language": "python",
158 | "name": "python3"
159 | },
160 | "language_info": {
161 | "codemirror_mode": {
162 | "name": "ipython",
163 | "version": 3
164 | },
165 | "file_extension": ".py",
166 | "mimetype": "text/x-python",
167 | "name": "python",
168 | "nbconvert_exporter": "python",
169 | "pygments_lexer": "ipython3",
170 | "version": "3.8.13"
171 | }
172 | },
173 | "nbformat": 4,
174 | "nbformat_minor": 5
175 | }
176 |
--------------------------------------------------------------------------------
/references/bibtex.bib:
--------------------------------------------------------------------------------
1 | @misc{https://doi.org/10.48550/arxiv.2302.03790,
2 |   doi = {10.48550/ARXIV.2302.03790},
3 |   url = {https://arxiv.org/abs/2302.03790},
4 |   author = {Tseng, Alex M. and Diamant, Nathaniel and Biancalani, Tommaso and Scalia, Gabriele},
5 |   keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences},
6 |   title = {GraphGUIDE: interpretable and controllable conditional graph generation with discrete Bernoulli diffusion},
7 |   publisher = {arXiv},
8 |   year = {2023},
9 |   copyright = {Creative Commons Attribution 4.0 International}
10 | }
11 |
--------------------------------------------------------------------------------
/references/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Genentech/GraphGUIDE/dad0dd371268684a5203839441febe0484d8a3e4/references/thumbnail.png
--------------------------------------------------------------------------------
/results:
--------------------------------------------------------------------------------
1 | /gstore/data/resbioai/tsenga5/discrete_graph_diffusion/results
--------------------------------------------------------------------------------
/src/analysis/graph_metrics.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | import scipy.linalg
4 | import tempfile
5 | import os
6 | import subprocess
7 |
8 | ORCA_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "orca")
9 |
10 | def get_degrees(graphs):
11 | """
12 | Computes the degrees of nodes in graphs.
13 | Arguments:
14 | `graphs`: a list of N NetworkX undirected graphs
15 | Returns a list of N NumPy arrays (each of size V) of degrees, where each
16 | array is ordered by the nodes in the graph. Note that V can be different for
17 | each graph.
18 | """
19 | result = []
20 | for g in graphs:
21 | degrees = nx.degree(g)
22 | result.append(np.array([degrees[i] for i in g.nodes]))
23 | return result
24 |
25 |
26 | def get_clustering_coefficients(graphs):
27 | """
28 | Computes the clustering coefficients of nodes in graphs.
29 | Arguments:
30 | `graphs`: a list of N NetworkX undirected graphs
31 | Returns a list of N NumPy arrays (each of size V) of clustering
32 | coefficients, where each array is ordered by the nodes in the graph. Note
33 | that V can be different for each graph.
34 | """
35 | result = []
36 | for g in graphs:
37 | coefs = nx.clustering(g)
38 | result.append(np.array([coefs[i] for i in g.nodes]))
39 | return result
40 |
41 |
42 | def get_spectra(graphs):
43 | """
44 | Computes the spectrum of graphs as the eigenvalues of the normalized
45 | Laplacian.
46 | Arguments:
47 | `graphs`: a list of N NetworkX undirected graphs
48 | Returns a list of N NumPy arrays (each of size V) of eigenvalues. Note that
49 | V can be different for each graph.
50 | """
51 | return [nx.normalized_laplacian_spectrum(g) for g in graphs]
52 |
53 |
54 | def run_orca(graph, max_graphlet_size=4):
55 | """
56 | Runs Orca on a given graph to count the number of orbits of each type each
57 | node belongs in.
58 | Arguments:
59 | `graph`: a NetworkX undirected graph
60 | `max_graphlet_size`: maximum size of graphlets whose orbits to count;
61 | must be either 4 or 5
62 | Returns a V x O NumPy array of orbit counts for each of the V nodes.
63 | """
64 | # Create temporary directory to do work in
65 | temp_dir_obj = tempfile.TemporaryDirectory()
66 | temp_dir = temp_dir_obj.name
67 | in_path = os.path.join(temp_dir, "graph.in")
68 | out_path = os.path.join(temp_dir, "orca.out")
69 |
70 | # Create input file
71 | with open(in_path, "w") as f:
72 | edges = graph.edges
73 | f.write("%d %d\n" % (max(graph.nodes) + 1, len(edges)))
74 | for edge in edges:
75 | f.write("%d %d\n" % edge)
76 |
77 | # Run Orca
78 | with open(os.devnull, "w") as f:
79 | subprocess.check_call([
80 | ORCA_PATH, "node", str(max_graphlet_size), in_path, out_path
81 | ], stdout=f)
82 |
83 | # Read in result
84 | with open(out_path, "r") as f:
85 | result = np.stack([
86 | np.array(list(map(int, line.strip().split()))) for line in f
87 | ])
88 |
89 | temp_dir_obj.cleanup()
90 |
91 | return result
92 |
93 |
94 | def get_orbit_counts(graphs, max_graphlet_size=4):
95 | """
96 | Computes the orbit counts of nodes in graphs. The orbit counts of a node are
97 | the number of times it appears in each orbit of the possible graphlets of
98 | size up to `max_graphlet_size`. For example, for `max_graphlet_size` of 4,
99 | there are 15 orbits possible for a node.
100 | Orbits are computed using Orca:
101 | https://academic.oup.com/bioinformatics/article/30/4/559/205331
102 | Arguments:
103 | `graphs`: a list of N NetworkX undirected graphs
104 | `max_graphlet_size`: maximum size of graphlets whose orbits to count;
105 | must be either 4 or 5
106 | Returns a list of N NumPy arrays (of size V x O) of orbit counts, where each
107 | array is ordered by the nodes in the graph. Note that V can be different for
108 | each graph, but O is the same for the same `max_graphlet_size`).
109 | """
110 | assert max_graphlet_size in (4, 5)
111 | return [run_orca(graph, max_graphlet_size) for graph in graphs]
112 |
113 |
114 | if __name__ == "__main__":
115 | graphs = [
116 | nx.erdos_renyi_graph(
117 | np.random.choice(np.arange(10, 20)),
118 | np.random.choice(np.linspace(0, 1, 10))
119 | )
120 | for _ in range(100)
121 | ]
122 |
123 | degrees = get_degrees(graphs)
124 | cluster_coefs = get_clustering_coefficients(graphs)
125 | spectra = get_spectra(graphs)
126 | orbit_counts = get_orbit_counts(graphs)
127 |
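128 |     # Illustrative shape summary (assumes the calls above succeeded, which
129 |     # requires the compiled `orca` binary next to this file): each metric is
130 |     # one array per graph; the node-level arrays are binned into histograms
131 |     # downstream (see mmd.py), and orbit counts are V x 15 for
132 |     # max_graphlet_size=4
133 |     print("Number of graphs: %d" % len(graphs))
134 |     print("Degree array shape of graph 0: %s" % str(degrees[0].shape))
135 |     print("Orbit count shape of graph 0: %s" % str(orbit_counts[0].shape))
136 |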
--------------------------------------------------------------------------------
/src/analysis/mmd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats
3 |
4 | def make_histograms(
5 | value_arrs, num_bins=None, bin_width=None, bin_array=None, frequency=False,
6 | epsilon=1e-6
7 | ):
8 | """
9 | Given a set of value arrays, converts them into histograms of counts or
10 | frequencies. The bins may be specified as a number of bins, a bin width, or
11 | a pre-defined array of bin edges. This function creates histograms such that
12 | all value arrays given are transformed into the same histogram space.
13 | Arguments:
14 | `value_arrs`: an iterable of N 1D NumPy arrays to make histograms of
15 | `num_bins`: if given, make the histograms have this number of bins total
16 | `bin_width`: if given, make the histograms have bins of this width,
17 | aligned starting at the minimum value
18 | `bin_array`: if given, make the histograms according to this NumPy array
19 | of bin edges
20 | `frequency`: if True, normalize each histogram into frequencies
21 | `epsilon`: small number for stability of last endpoint if `bin_width` is
22 | specified
23 | Returns an N x B array of counts or frequencies (N is parallel to the input
24 | `value_arrs`), where B is the number of bins in the histograms.
25 | """
26 | # Compute bins if needed
27 | if num_bins is not None:
28 | assert bin_width is None and bin_array is None
29 | min_val = min(np.nanmin(arr) for arr in value_arrs)
30 | max_val = max(np.nanmax(arr) for arr in value_arrs)
31 | bin_array = np.linspace(min_val, max_val, num_bins + 1)
32 | elif bin_width is not None:
33 | assert num_bins is None and bin_array is None
34 | min_val = min(np.nanmin(arr) for arr in value_arrs)
35 | max_val = max(np.nanmax(arr) for arr in value_arrs) + bin_width + \
36 | epsilon
37 | bin_array = np.arange(min_val, max_val, bin_width)
38 | elif bin_array is not None:
39 | assert num_bins is None and bin_width is None
40 | else:
41 | raise ValueError(
42 | "Must specify one of `num_bins`, `bin_width`, or `bin_array`"
43 | )
44 |
45 | # Compute histograms
46 | hists = np.empty((len(value_arrs), len(bin_array) - 1))
47 | for i, arr in enumerate(value_arrs):
48 | hist = np.histogram(arr, bins=bin_array)[0]
49 | if frequency:
50 | hist = hist / len(arr)
51 | hists[i] = hist
52 |
53 | return hists
54 |
55 |
56 | def gaussian_kernel(vec_1, vec_2, sigma=1):
57 | """
58 | Computes the Gaussian kernel function on two vectors. This is also known as
59 | the radial basis function.
60 | Arguments:
61 | `vec_1`: a NumPy array of values
62 | `vec_2`: a NumPy array of values; the underlying vector space must be
63 | the same as `vec_1`
64 | `sigma`: standard deviation for the Gaussian kernel
65 | Returns a scalar similarity value between 0 and 1.
66 | """
67 | l2_dist_squared = np.sum(np.square(vec_1 - vec_2))
68 | return np.exp(-l2_dist_squared / (2 * sigma * sigma))
69 |
70 |
71 | def gaussian_wasserstein_kernel(vec_1, vec_2, sigma=1):
72 | """
73 | Computes the Gaussian kernel function on two vectors, where the similarity
74 | metric within the Gaussian is the Wasserstein distance (i.e. Earthmover's
75 | distance). The two vectors must be distributions represented as PMFs over
76 | the same probability space.
77 | Arguments:
78 | `vec_1`: a NumPy array representing a PMF distribution (values are
79 | probabilities)
80 | `vec_2`: a NumPy array representing a PMF distribution (values are
81 | probabilities); the underlying probability space (i.e. support) must
82 | be the same as `vec_1`
83 | `sigma`: standard deviation for the Gaussian kernel
84 | Returns a scalar similarity value between 0 and 1.
85 | """
86 | assert vec_1.shape == vec_2.shape
87 | # Since the vectors are supposed to be PMFs, if everything is 0 then just
88 | # turn it into an (unnormalized) uniform distribution
89 | if np.all(vec_1 == 0):
90 | vec_1 = np.ones_like(vec_1)
91 | if np.all(vec_2 == 0):
92 | vec_2 = np.ones_like(vec_2)
93 | # The SciPy Wasserstein distance function takes in empirical observations
94 | # instead of histograms/distributions as an input, but we can get the same
95 | # result by specifying weights which are the PMF probabilities
96 | w_dist = scipy.stats.wasserstein_distance(
97 | np.arange(len(vec_1)), np.arange(len(vec_1)), vec_1, vec_2
98 | )
99 | return np.exp(-(w_dist * w_dist) / (2 * sigma * sigma))
100 |
101 |
102 | def gaussian_total_variation_kernel(vec_1, vec_2, sigma=1):
103 | """
104 | Computes the Gaussian kernel function on two vectors, where the similarity
105 | metric within the Gaussian is the total variation between the two vectors.
106 | Arguments:
107 | `vec_1`: a NumPy array of values
108 | `vec_2`: a NumPy array of values; the underlying vector space must be
109 | the same as `vec_1`
110 | `sigma`: standard deviation for the Gaussian kernel
111 | Returns a scalar similarity value between 0 and 1.
112 | """
113 | tv_dist = np.sum(np.abs(vec_1 - vec_2)) / 2
114 | return np.exp(-(tv_dist * tv_dist) / (2 * sigma * sigma))
115 |
116 |
117 | def compute_inner_prod_feature_mean(dist_1, dist_2, kernel_type, **kwargs):
118 | """
119 | Given two empirical distributions of vectors, computes the inner product of
120 | the feature means using the specified kernel. This is equivalent to the
121 | expected/average kernel function on all pairs of vectors between the two
122 | distributions.
123 | Arguments:
124 | `dist_1`: an M x D NumPy array of M vectors, each of size D; all vectors
125 | must share the same underlying vector space (or probability space if
126 | the vectors represent a probability distribution) with each other
127 | and with `dist_2`
128 | `dist_2`: an M x D NumPy array of M vectors, each of size D; all vectors
129 | must share the same underlying vector space (or probability space if
130 | the vectors represent a probability distribution) with each other
131 | and with `dist_1`
132 | `kernel_type`: type of kernel to apply for computing the kernelized
133 | inner product; can be "gaussian", "gaussian_wasserstein", or
134 | "gaussian_total_variation"
135 | `kwargs`: extra keyword arguments to be passed to the kernel function
136 | Returns a scalar which is the average kernelized inner product between all
137 | pairs of vectors across the two distributions.
138 | """
139 | if kernel_type == "gaussian":
140 | kernel_func = gaussian_kernel
141 | elif kernel_type == "gaussian_wasserstein":
142 | kernel_func = gaussian_wasserstein_kernel
143 | elif kernel_type == "gaussian_total_variation":
144 | kernel_func = gaussian_total_variation_kernel
145 | else:
146 | raise ValueError("Unknown kernel type: %s" % kernel_type)
147 |
148 | inner_prods = []
149 | for vec_1 in dist_1:
150 | for vec_2 in dist_2:
151 | inner_prods.append(kernel_func(vec_1, vec_2, **kwargs))
152 |
153 | return np.mean(inner_prods)
154 |
155 |
156 | def compute_maximum_mean_discrepancy(
157 | dist_1, dist_2, kernel_type, normalize=True, **kwargs
158 | ):
159 | """
160 | Given two empirical distributions of vectors, computes the maximum mean
161 | discrepancy (MMD) between the two distributions.
162 | Arguments:
163 | `dist_1`: an M x D NumPy array of M vectors, each of size D; all vectors
164 | must share the same underlying vector space (or probability space if
165 | the vectors represent a probability distribution) with each other
166 | and with `dist_2`
167 | `dist_2`: an M x D NumPy array of M vectors, each of size D; all vectors
168 | must share the same underlying vector space (or probability space if
169 | the vectors represent a probability distribution) with each other
170 | and with `dist_1`
171 | `kernel_type`: type of kernel to apply for computing the kernelized
172 | inner product; can be "gaussian", "gaussian_wasserstein", or
173 | "gaussian_total_variation"
174 | `normalize`: if True, normalize each D-vector to sum to 1
175 | `kwargs`: extra keyword arguments to be passed to the kernel function
176 | Returns the scalar MMD value.
177 | """
178 | if normalize:
179 | dist_1 = dist_1 / np.sum(dist_1, axis=1, keepdims=True)
180 | dist_2 = dist_2 / np.sum(dist_2, axis=1, keepdims=True)
181 |
182 | term_1 = compute_inner_prod_feature_mean(
183 | dist_1, dist_1, kernel_type, **kwargs
184 | )
185 | term_2 = compute_inner_prod_feature_mean(
186 | dist_2, dist_2, kernel_type, **kwargs
187 | )
188 | term_3 = compute_inner_prod_feature_mean(
189 | dist_1, dist_2, kernel_type, **kwargs
190 | )
191 | return np.sqrt(term_1 + term_2 - (2 * term_3))
192 |
193 |
194 | if __name__ == "__main__":
195 | import networkx as nx
196 | import graph_metrics
197 |
198 | graphs = [
199 | nx.erdos_renyi_graph(
200 | np.random.choice(np.arange(10, 20)),
201 | np.random.choice(np.linspace(0, 1, 10))
202 | ) for _ in range(50)
203 | ] + [
204 | nx.erdos_renyi_graph(
205 | np.random.choice(np.arange(10, 20)),
206 | np.random.choice(np.linspace(0, 1, 10))
207 | ) for _ in range(50)
208 | ]
209 |
210 | degrees = graph_metrics.get_degrees(graphs)
211 | cluster_coefs = graph_metrics.get_clustering_coefficients(graphs)
212 | spectra = graph_metrics.get_spectra(graphs)
213 | orbit_counts = graph_metrics.get_orbit_counts(graphs)
214 | orbit_counts = np.stack([
215 | np.mean(counts, axis=0) for counts in orbit_counts
216 | ])
217 |
218 | kernel_type = "gaussian_total_variation"
219 |
220 | degree_hists = make_histograms(degrees, bin_width=1)
221 | degree_mmd = compute_maximum_mean_discrepancy(
222 | degree_hists[:50], degree_hists[50:], kernel_type, sigma=1
223 | )
224 | cluster_coef_hists = make_histograms(cluster_coefs, num_bins=100)
225 | cluster_coef_mmd = compute_maximum_mean_discrepancy(
226 | cluster_coef_hists[:50], cluster_coef_hists[50:], kernel_type, sigma=0.1
227 | )
228 | spectra_hists = make_histograms(
229 | spectra, bin_array=np.linspace(-1e-5, 2, 200 + 1)
230 | )
231 | spectra_mmd = compute_maximum_mean_discrepancy(
232 | spectra_hists[:50], spectra_hists[50:], kernel_type, sigma=1
233 | )
234 | orbit_mmd = compute_maximum_mean_discrepancy(
235 | orbit_counts[:50], orbit_counts[50:], kernel_type, normalize=False,
236 | sigma=30
237 | )
238 |
239 | print("MMD values")
240 | print("Degree: %.15f" % np.square(degree_mmd))
241 | print("Clustering coefficient: %.15f" % np.square(cluster_coef_mmd))
242 | print("Spectrum: %.15f" % np.square(spectra_mmd))
243 | print("Orbit: %.15f" % np.square(orbit_mmd))
244 |
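245 |     # Sanity check (illustrative): the MMD of a distribution with itself is
246 |     # built from three identical feature-mean terms, so term_1 + term_2 -
247 |     # 2 * term_3 cancels and the result should be 0
248 |     self_mmd = compute_maximum_mean_discrepancy(
249 |         degree_hists[:50], degree_hists[:50], kernel_type, sigma=1
250 |     )
251 |     print("Self-MMD (should be ~0): %.15f" % np.square(self_mmd))
252 |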
--------------------------------------------------------------------------------
/src/analysis/orca.cpp:
--------------------------------------------------------------------------------
1 | // Downloaded from https://github.com/thocevar/orca
2 | // Compile with `g++ -O2 -std=c++11 -o orca orca.cpp`
3 |
4 |
5 | // [The remainder of this file is garbled in this dump: every span between
6 | // `<` and `>` (include headers, template parameters, comparison operators,
7 | // and long stretches of code in between) was stripped, leaving the roughly
8 | // 1,400 lines that count node and edge orbits for graphlets of up to 4 or 5
9 | // nodes unrecoverable here. See https://github.com/thocevar/orca for the
10 | // intact source.]
--------------------------------------------------------------------------------
/src/feature/graph_conversions.py:
--------------------------------------------------------------------------------
[Lines 1-81 of this file were swallowed by the same garbling; they presumably
contained the module imports, the DEVICE definition used below, and the
signature and start of the edge-vector encoding function whose body resumes
mid-function at line 82.]
82 | # There are more indices in `batch` than there are pointers, so cut off
83 | # the excess
84 | adj_matrix = adj_matrix[:len(graph_sizes)]
85 |
86 | # Create boolean mask of only the top upper triangle of each B x V x V
87 | # matrix, ignoring the diagonal
88 | triu_mask = torch.triu(torch.ones_like(adj_matrix), diagonal=1) == 1
89 | for i, size in enumerate(torch.diff(data.ptr)):
90 | # TODO: make this more efficient
91 | # For each individual graph, limit the upper-triangle mask to
92 | # only the size of that graph
93 | triu_mask[i, :, size:] = False
94 |
95 | # Edge vector is all entries in the graph-specific upper triangle of each
96 | # individual adjacency matrix
97 | edge_vec = adj_matrix[triu_mask]
98 |
99 | if return_batch_inds:
100 | # Number of edges in each graph:
101 | edge_counts = ((graph_sizes * (graph_sizes - 1)) / 2).int()
102 | edge_counts_cumsum = torch.cumsum(edge_counts, dim=0)
103 |
104 | # Create binary marker array, which is all 0s except with 1s
105 | # wherever we switch to a new graph
106 | markers = torch.zeros(
107 | edge_counts_cumsum[-1], dtype=torch.int, device=DEVICE
108 | )
109 | markers[edge_counts_cumsum[:-1]] = 1
110 |
111 | batch_inds = torch.cumsum(markers, dim=0)
112 | return edge_vec, batch_inds
113 | return edge_vec
114 |
115 |
116 | def edge_vector_to_pyg_data(data, edges, reflect=True):
117 | """
118 | Returns the edge-index tensor which would be associated with the
119 | torch-geometric Data object `data`, if the edges in the edge index attribute
120 | were set according to the given edges. Note that self-edges are not allowed.
121 | If `edges` is a scalar 1, then this returns the set of all edges as an
122 | edge-index tensor.
123 | Arguments:
124 | `data`: a torch-geometric Data object
125 | `edges`: a binary E-tensor of edges in canonical order, or the scalar 1
126 | `reflect`: by default, each edge will be represented twice in the
127 | edge-index tensor (no self-edges are allowed); if False, only the
128 | upper-triangular indices will be present (and thus the tensor's size
129 | will be halved)
130 | Returns a 2 x E' edge-index tensor (type long).
131 | """
132 | graph_sizes = torch.diff(data.ptr)
133 | max_size = torch.max(graph_sizes)
134 |
135 | # Create filler adjacency matrix that starts out as all 0s
136 | adj_matrix = torch.zeros(
137 | graph_sizes.shape[0], max_size, max_size, device=DEVICE
138 | )
139 |
140 | # Create boolean mask of only the top upper triangle of each B x V x V
141 | # matrix, ignoring the diagonal
142 | triu_mask = torch.triu(torch.ones_like(adj_matrix), diagonal=1) == 1
143 | for i, size in enumerate(torch.diff(data.ptr)):
144 | # TODO: make this more efficient
145 | # For each individual graph, limit the upper-triangle mask to
146 | # only the size of that graph
147 | triu_mask[i, :, size:] = False
148 |
149 | # Set the upper triangle of each graph-specific adjacency matrix to
150 | # the edges given
151 | adj_matrix[triu_mask] = edges
152 |
153 | # Symmetrize the matrix
154 | if reflect:
155 | adj_matrix = adj_matrix + torch.transpose(adj_matrix, 1, 2)
156 |
157 | # Get indices where the adjacency matrix is nonzero (an E x 3 matrix)
158 | nonzero_inds = adj_matrix.nonzero()
159 |
160 | # The indices are for each individual graph, so add the graph-size
161 | # pointers to each set of indices so later graphs have higher
162 | # indices based on the sizes of earlier graphs
163 | edges_to_set = nonzero_inds[:, 1:] + data.ptr[nonzero_inds[:, 0]][:, None]
164 |
165 | # Convert to a 2 x E matrix
166 | edges_to_set = edges_to_set.t().contiguous()
167 | return torch_geometric.utils.sort_edge_index(edges_to_set).long()
168 |
169 |
170 | def split_pyg_data_to_nx_graphs(data):
171 | """
172 | Given a torch-geometric Data object, splits the objects into an ordered list
173 | of NetworkX graphs. The NetworkX graphs will be undirected (no self-edges
174 | allowed), and node features will be under the attribute "feats".
175 | Arguments:
176 | `data`: a batched torch-geometric Data object
177 | Returns an ordered list of NetworkX graph objects.
178 | """
179 | graphs = []
180 | pointers = data.ptr.cpu().numpy()
181 | graph_sizes = np.diff(pointers)
182 | num_graphs = len(graph_sizes)
183 |
184 | # First, get (padded) adjacency matrix of size B x V x V, where V is
185 | # the maximum number of nodes in any individual graph
186 | adj_matrix = torch_geometric.utils.to_dense_adj(data.edge_index, data.batch)
187 |
188 | for i in range(num_graphs):
189 | graph = nx.empty_graph(graph_sizes[i])
190 |
191 | # Get the indices of adjacency matrix, upper triangle only
192 | edge_indices = torch.triu(
193 | adj_matrix[i], diagonal=1
194 | ).nonzero().cpu().numpy()
195 | for u, v in edge_indices:
196 | graph.add_edge(u, v)
197 |
198 | node_feats = data.x[pointers[i] : pointers[i + 1]].cpu().numpy()
199 | feat_dict = {i : node_feats[i] for i in range(graph.number_of_nodes())}
200 | nx.set_node_attributes(graph, feat_dict, "feats")
201 |
202 | graphs.append(graph)
203 | return graphs
204 |
205 |
206 | def add_virtual_nodes(data):
207 | """
208 | Given a torch-geometric Data object, adds a virtual node to each individual
209 | graph in the batch. This modifies `data` in place. The attributes are
210 | modified as follows:
211 | `x`: new row of all 0s added for each virtual node
212 | `edge_index`: edges from each virtual node to all other nodes in its
213 | graph
214 | `batch`: new entry with unique index for each virtual node
215 | `edge_type`: introduces new edge type for each edge added
216 | `ptr`: unmodified
217 | Arguments:
218 | `data`: a batched torch-geometric Data object
219 | """
220 | # Code adapted from `torch_geometric.transforms.virtual_node.VirtualNode`
221 | # We write our own function to make sure we add a distinct virtual node for
222 | # each graph in the batch, and in light of the issue here:
223 | # https://github.com/pyg-team/pytorch_geometric/issues/5818
224 | updates = {
225 | "x": [data.x],
226 | "edge_index": [data.edge_index],
227 | "batch": [data.batch],
228 | "edge_type": [
229 | data.get("edge_type", torch.zeros_like(data.edge_index[0]))
230 | ]
231 | }
232 | if "num_nodes" in data:
233 | num_nodes = data.num_nodes
234 |
235 | graph_sizes = torch.diff(data.ptr)
236 | num_graphs = len(graph_sizes)
237 | device = data.x.device
238 |
239 | if not updates["edge_type"][0].numel():
240 | new_edge_type = 1
241 | else:
242 | new_edge_type = int(torch.max(updates["edge_type"][0])) + 1
243 | next_node_i = len(data.x)
244 | next_batch_i = num_graphs
245 |
246 | for i in range(num_graphs):
247 | start_node_i, end_node_i = data.ptr[i], data.ptr[i + 1]
248 | graph_size = graph_sizes[i]
249 |
250 | # New edges
251 | new_edges_1 = torch.full((graph_size,), next_node_i, device=device)
252 | new_edges_2 = torch.arange(start_node_i, end_node_i, device=device)
253 | new_edges = torch.cat([
254 | torch.stack([new_edges_1, new_edges_2], dim=0),
255 | torch.stack([new_edges_2, new_edges_1], dim=0),
256 | ], dim=1)
257 | updates["edge_index"].append(new_edges)
258 |
259 | new_edge_types_1 = torch.full(
260 | new_edges_1.shape, new_edge_type, device=device
261 | )
262 | new_edge_types_2 = new_edge_types_1 + 1
263 | new_edge_types = torch.cat([new_edge_types_1, new_edge_types_2], dim=0)
264 | updates["edge_type"].append(new_edge_types)
265 |
266 | # New virtual node
267 | new_x = torch.zeros_like(data.x[0])[None]
268 | updates["x"].append(new_x)
269 | next_node_i = next_node_i + 1
270 |
271 | # New batch
272 | new_batch = torch.tensor([next_batch_i], device=device)
273 | updates["batch"].append(new_batch)
274 | next_batch_i = next_batch_i + 1
275 |
276 | data.edge_index = torch.cat(updates["edge_index"], dim=1)
277 | data.edge_type = torch.cat(updates["edge_type"], dim=0)
278 | data.x = torch.cat(updates["x"], dim=0)
279 | data.batch = torch.cat(updates["batch"], dim=0)
280 |
281 | if "num_nodes" in data:
282 | data.num_nodes = num_nodes + num_graphs
283 |
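284 |
285 | if __name__ == "__main__":
286 |     # Minimal usage sketch (illustrative; not part of the original file):
287 |     # materialize the fully-connected edge index of a 2-graph batch by
288 |     # passing the scalar 1 as `edges`, per `edge_vector_to_pyg_data`
289 |     graph = torch_geometric.data.Data(
290 |         x=torch.zeros(3, 2),
291 |         edge_index=torch.tensor([[0, 1], [1, 0]], dtype=torch.long),
292 |     )
293 |     loader = torch_geometric.loader.DataLoader([graph, graph], batch_size=2)
294 |     batch = next(iter(loader)).to(DEVICE)
295 |     print(edge_vector_to_pyg_data(batch, 1))  # 2 x 12: two complete graphs
296 |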
--------------------------------------------------------------------------------
/src/feature/molecule_dataset.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import rdkit.Chem
3 | import torch
4 | import torch_geometric
5 | import networkx as nx
6 |
7 | # Define device
8 | if torch.cuda.is_available():
9 | DEVICE = "cuda"
10 | else:
11 | DEVICE = "cpu"
12 |
13 |
14 | ZINC250K_PATH = "/gstore/home/tsenga5/discrete_graph_diffusion/data/250k_rndm_zinc_drugs_clean_3.csv"
15 |
16 |
17 | def smiles_to_networkx(smiles):
18 | """
19 | Converts a SMILES string to a NetworkX graph. The graph will retain the
20 | atomic number and bond type for nodes and edges (respectively), under the
21 | keys `atomic_num` and `bond_type` (respectively).
22 | Arguments:
23 | `smiles`: a SMILES string
24 | Returns a NetworkX graph.
25 | """
26 | mol = rdkit.Chem.MolFromSmiles(smiles)
27 | g = nx.Graph()
28 | for atom in mol.GetAtoms():
29 | g.add_node(
30 | atom.GetIdx(),
31 | atomic_num=atom.GetAtomicNum()
32 | )
33 | for bond in mol.GetBonds():
34 | g.add_edge(
35 | bond.GetBeginAtomIdx(),
36 | bond.GetEndAtomIdx(),
37 | bond_type=bond.GetBondType()
38 | )
39 | return g
40 |
41 |
42 | ATOM_MAP = torch.tensor([6, 7, 8, 9, 16, 17, 35, 53, 15])  # C, N, O, F, S, Cl, Br, I, P
43 | BOND_MAP = torch.tensor([1, 2, 3, 12])  # single, double, triple, aromatic (RDKit BondType values)
44 |
45 | def smiles_to_pyg_data(smiles, ignore_edge_attr=False):
46 | """
47 | Converts a SMILES string to a torch-geometric Data object. The data object
48 | will have node attributes and edge attributes under `x` and `edge_attr`,
49 | respectively.
50 | Arguments:
51 | `smiles`: a SMILES string
52 | `ignore_edge_attr`: if True, no edge attributes will be included
53 | Returns a torch-geometric Data object.
54 | """
55 | g = smiles_to_networkx(smiles)
56 | data = torch_geometric.utils.from_networkx(g)
57 |
58 | # Set atom features
59 | atom_inds = torch.argmax(
60 | (data.atomic_num.view(-1, 1) == ATOM_MAP).int(), dim=1
61 | )
62 | data.x = torch.nn.functional.one_hot(atom_inds, num_classes=len(ATOM_MAP))
63 |
64 | if not ignore_edge_attr:
65 | # Set bond features
66 | # For aromatic bonds, set them to be both single and double
67 | aromatic_mask = data.bond_type == BOND_MAP[-1]
68 | bond_inds = torch.argmax(
69 | (data.bond_type.view(-1, 1) == BOND_MAP).int(), dim=1
70 | )
71 | bond_inds[aromatic_mask] = 0
72 | data.edge_attr = torch.nn.functional.one_hot(
73 | bond_inds, num_classes=(len(BOND_MAP) - 1)
74 | )
75 | data.edge_attr[aromatic_mask, 1] = 1
76 |
77 | del data.atomic_num
78 | del data.bond_type
79 |
80 | return data
81 |
82 |
83 | class ZINCDataset(torch.utils.data.Dataset):
84 | def __init__(self, connectivity_only=False):
85 | """
86 | Create a PyTorch map-style Dataset which yields molecular graphs.
87 | Arguments:
88 | `connectivity_only`: if True, only connectivity information is
89 | retained, and no edge attributes will be included
90 | """
91 | super().__init__()
92 |
93 | self.connectivity_only = connectivity_only
94 | self.node_dim = len(ATOM_MAP)
95 |
96 | zinc_table = pd.read_csv(ZINC250K_PATH, sep=",", header=0)
97 | zinc_table["smiles"] = zinc_table["smiles"].str.strip()
98 | self.all_smiles = zinc_table["smiles"]
99 |
100 | def __getitem__(self, index):
101 | """
102 | Returns the torch-geometric Data object representing the molecule at
103 | index `index` in `self.all_smiles`.
104 | """
105 | data = smiles_to_pyg_data(self.all_smiles[index], ignore_edge_attr=self.connectivity_only)
106 | data.edge_index = torch_geometric.utils.sort_edge_index(data.edge_index)
107 | return data.to(DEVICE)
108 |
109 | def __len__(self):
110 | return len(self.all_smiles)
111 |
112 |
113 | if __name__ == "__main__":
114 | dataset = ZINCDataset(connectivity_only=True)
115 | data_loader = torch_geometric.loader.DataLoader(
116 | dataset, batch_size=32, shuffle=False
117 | )
118 | batch = next(iter(data_loader))
119 |
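120 |     # Illustrative check of the aromatic-bond encoding above: benzene's
121 |     # ring bonds are aromatic, so each `edge_attr` row should be hot in
122 |     # both the single- and double-bond slots, i.e. [1, 1, 0]
123 |     benzene = smiles_to_pyg_data("c1ccccc1")
124 |     print(benzene.edge_attr)
125 |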
--------------------------------------------------------------------------------
/src/feature/random_graph_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric
3 | import numpy as np
4 | import networkx as nx
5 | import scipy.spatial
6 |
7 | # Define device
8 | if torch.cuda.is_available():
9 | DEVICE = "cuda"
10 | else:
11 | DEVICE = "cpu"
12 |
13 |
14 | def create_tree(node_dim, num_nodes=10, noise_level=1):
15 | """
16 | Creates a random connected tree. The node attributes will be initialized by
17 | a random vector plus the distance from a randomly selected source node, plus
18 | noise.
19 | Arguments:
20 | `node_dim`: size of node feature vector
21 | `num_nodes`: number of nodes in the graph, or an array to sample from
22 | `noise_level`: standard deviation of Gaussian noise to add to distances
23 | Returns a NetworkX Graph with NumPy arrays as node attributes.
24 | """
25 | if type(num_nodes) is not int:
26 | num_nodes = np.random.choice(num_nodes)
27 |
28 | g = nx.random_tree(num_nodes)
29 |
30 | node_features = np.empty((num_nodes, node_dim))
31 |
32 | # Pick a random source node
33 | source = np.random.choice(num_nodes)
34 |
35 | # Set source node's feature to random vector
36 | source_feat = np.random.randn(node_dim) * 2 * np.sqrt(num_nodes)
37 | node_features[source] = source_feat
38 |
39 | # Run BFS starting from source node; for each layer, set features
40 | # to be the source vector plus the distance plus noise
41 | current_layer = [source]
42 | distance = 1
43 | visited = set(current_layer)
44 | while current_layer:
45 | next_layer = []
46 | for node in current_layer:
47 | for child in g[node]:
48 | if child not in visited:
49 | visited.add(child)
50 | next_layer.append(child)
51 | node_features[child] = source_feat + distance + (
52 | np.random.randn(node_dim) * noise_level
53 | )
54 | current_layer = next_layer
55 | distance += 1
56 |
57 | nx.set_node_attributes(
58 | g, {i : node_features[i] for i in range(num_nodes)}, "feats"
59 | )
60 | return g
61 |
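A small worked example of the labeling scheme above (hypothetical call, with
noise_level=0 so the pattern is exact):

    g = create_tree(node_dim=4, num_nodes=8, noise_level=0)
    feats = nx.get_node_attributes(g, "feats")
    # With zero noise, feats[v] - feats[source] is a constant vector whose
    # entries all equal the BFS distance from the (randomly chosen) source
    # node to v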
62 |
63 | def create_uniform_cliques(
64 | node_dim, num_nodes=10, clique_size=6, noise_level=1
65 | ):
66 | """
67 | Creates a random graph of disconnected cliques. The node attributes will be
68 | initialized by a constant vector for each clique, plus some noise. If
69 | `clique_size` does not divide `num_nodes`, there will be a smaller clique.
70 | Arguments:
71 | `node_dim`: size of node feature vector
72 | `num_nodes`: number of nodes in the graph, or an array to sample from
73 | `clique_size`: size of cliques
74 | `noise_level`: standard deviation of Gaussian noise to add to node
75 | features
76 | Returns a NetworkX Graph with NumPy arrays as node attributes.
77 | """
78 | if type(num_nodes) is not int:
79 | num_nodes = np.random.choice(num_nodes)
80 |
81 | g = nx.empty_graph()
82 |
83 | clique_count = 0
84 | while g.number_of_nodes() < num_nodes:
85 | size = min(clique_size, num_nodes - g.number_of_nodes())
86 |
87 | # Create clique
88 | clique = nx.complete_graph(size)
89 |
90 | # Create the core feature vector for the clique
91 |         core = np.ones((size, 1)) * clique_count * \
92 | (2 * num_nodes / clique_size)
93 |
94 |         # Add a small bit of noise for each node in the clique
95 | node_features = core + (np.random.randn(size, node_dim) * noise_level)
96 | nx.set_node_attributes(
97 | clique, {i : node_features[i] for i in range(size)}, "feats"
98 | )
99 |
100 | # Add the clique to the graph
101 | g = nx.disjoint_union(g, clique)
102 | clique_count += 1
103 |
104 | return g
105 |
106 |
107 | def create_diverse_cliques(
108 | node_dim, num_nodes=10, clique_sizes=[3, 4, 5], repeat=False, noise_level=1,
109 | unity_features=False
110 | ):
111 | """
112 | Creates a random graph of disconnected cliques. The node attributes will be
113 | initialized by a constant vector for each clique (which is the size of the
114 | clique), plus some noise. Leftover nodes will be singleton nodes.
115 | Arguments:
116 | `node_dim`: size of node feature vector
117 | `num_nodes`: number of nodes in the graph, or an array to sample from
118 | `clique_sizes`: iterable of clique sizes to use
119 | `repeat`: if False, all clique sizes will be unique
120 | `noise_level`: standard deviation of Gaussian noise to add to node
121 | features
122 | `unity_features`: if True, use 1 for all features instead of clique size
123 | Returns a NetworkX Graph with NumPy arrays as node attributes.
124 | """
125 | if type(num_nodes) is not int:
126 | num_nodes = np.random.choice(num_nodes)
127 |
128 | g = nx.empty_graph()
129 |
130 | clique_sizes = np.unique(clique_sizes)
131 |
132 | sizes_to_make, size_left = [], num_nodes
133 | if repeat:
134 | while clique_sizes.size:
135 | size = np.random.choice(clique_sizes)
136 | sizes_to_make.append(size)
137 | size_left -= size
138 |             clique_sizes = clique_sizes[clique_sizes <= size_left]  # Keep only sizes that still fit
139 | else:
140 | clique_sizes = np.random.permutation(clique_sizes)
141 | for size in clique_sizes:
142 | if size <= size_left:
143 | sizes_to_make.append(size)
144 | size_left -= size
145 | sizes_to_make.extend([1] * size_left)
146 |
147 | for size in sizes_to_make:
148 | # Create clique
149 | clique = nx.complete_graph(size)
150 |
151 | # Create the core feature vector for the clique
152 | core = np.ones((size, 1)) * (1 if unity_features else size)
153 |
154 |         # Add a small bit of noise for each node in the clique
155 | node_features = core + (np.random.randn(size, node_dim) * noise_level)
156 | nx.set_node_attributes(
157 | clique, {i : node_features[i] for i in range(size)}, "feats"
158 | )
159 |
160 | # Add the clique to the graph
161 | g = nx.disjoint_union(g, clique)
162 |
163 | return g
164 |
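A worked example of the size-selection logic above (repeat=False): with
num_nodes=10 and clique_sizes=[3, 4, 5], the permutation [4, 5, 3] keeps 4
and 5 (size_left goes 10 -> 6 -> 1) but drops 3, so the graph gets a
4-clique, a 5-clique, and one leftover singleton:

    g = create_diverse_cliques(
        node_dim=2, num_nodes=10, clique_sizes=[3, 4, 5], noise_level=0
    )
    # g always has exactly num_nodes nodes; any budget not covered by
    # cliques is filled with singleton nodes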
165 |
166 | def create_degree_graph(
167 | node_dim, num_nodes=10, edge_prob=0.2, noise_level=1
168 | ):
169 | """
170 | Creates a random Erdos-Renyi graph where the node attributes are a constant
171 | vector of the degree of the node, plus some noise.
172 | Arguments:
173 | `node_dim`: size of node feature vector
174 | `num_nodes`: number of nodes in the graph, or an array to sample from
175 | `edge_prob`: probability of edges in Erdos-Renyi graph
176 | `noise_level`: standard deviation of Gaussian noise to add to node
177 | features
178 | Returns a NetworkX Graph with NumPy arrays as node attributes.
179 | """
180 | if type(num_nodes) is not int:
181 | num_nodes = np.random.choice(num_nodes)
182 |
183 | g = nx.erdos_renyi_graph(num_nodes, edge_prob)
184 |
185 | degrees = dict(g.degree())
186 | node_features = np.tile(
187 | np.array([degrees[n] for n in range(len(g))])[:, None],
188 | (1, node_dim)
189 | )
190 | node_features = node_features + \
191 | (np.random.randn(num_nodes, node_dim) * noise_level)
192 |
193 | nx.set_node_attributes(
194 | g, {i : node_features[i] for i in range(num_nodes)}, "feats"
195 | )
196 |
197 | return g
198 |
199 |
200 | def create_planar_graph(node_dim, num_nodes=64):
201 | """
202 | Creates a planar graph using the Delaunay triangulation algorithm.
203 | All nodes will be given a feature vector of all 1s.
204 | Arguments:
205 | `node_dim`: size of node feature vector
206 | `num_nodes`: number of nodes in the graph, or an array to sample from
207 | Returns a NetworkX Graph with NumPy arrays as node attributes.
208 | """
209 | if type(num_nodes) is not int:
210 | num_nodes = np.random.choice(num_nodes)
211 |
212 | # Sample points uniformly at random from unit square
213 | points = np.random.rand(num_nodes, 2)
214 |
215 | # Perform Delaunay triangulation
216 | tri = scipy.spatial.Delaunay(points)
217 |
218 | # Create graph and add edges from triangulation result
219 | g = nx.empty_graph(num_nodes)
220 | indptr, indices = tri.vertex_neighbor_vertices
221 | for i in range(num_nodes):
222 | for j in indices[indptr[i]:indptr[i + 1]]:
223 | g.add_edge(i, j)
224 |
225 | nx.set_node_attributes(
226 | g, {i : np.ones(node_dim) for i in range(num_nodes)}, "feats"
227 | )
228 |
229 | return g
230 |
231 |
232 | def create_community_graph(
233 | node_dim, num_nodes=np.arange(12, 21), num_comms=2,
234 | intra_comm_edge_prob=0.3, inter_comm_edge_frac=0.05
235 | ):
236 | """
237 | Creates a community graph following this paper:
238 | https://arxiv.org/abs/1802.08773
239 | The default values give the definition of a "community-small" graph in the
240 |     above paper. Each community is an Erdos-Renyi graph, and a set number of
241 |     edges (drawn uniformly at random) sparsely connects the communities.
242 | All nodes will be given a feature vector of all 1s.
243 | Arguments:
244 | `node_dim`: size of node feature vector
245 | `num_nodes`: number of nodes in the graph, or an array to sample from
246 | `num_comms`: number of communities to create
247 | `intra_comm_edge_prob`: probability of edge in Erdos-Renyi graph for
248 | each community
249 | `inter_comm_edge_frac`: number of edges to draw between each pair of
250 | communities, as a fraction of `num_nodes`; edges are drawn uniformly
251 | at random between communities
252 | Returns a NetworkX Graph with NumPy arrays as node attributes.
253 | """
254 | if type(num_nodes) is not int:
255 | num_nodes = np.random.choice(num_nodes)
256 |
257 | # Create communities
258 | exp_size = int(num_nodes / num_comms)
259 | comm_sizes = []
260 | total_size = 0
261 | g = nx.empty_graph()
262 | while total_size < num_nodes:
263 | size = min(exp_size, num_nodes - total_size)
264 | g = nx.disjoint_union(
265 | g, nx.erdos_renyi_graph(size, intra_comm_edge_prob)
266 | )
267 | comm_sizes.append(size)
268 | total_size += size
269 |
270 | # Link together communities
271 | node_inds = np.cumsum(comm_sizes)
272 | num_inter_edges = int(num_nodes * inter_comm_edge_frac)
273 | for i in range(num_comms):
274 | for j in range(i):
275 | i_nodes = np.arange(node_inds[i - 1] if i else 0, node_inds[i])
276 | j_nodes = np.arange(node_inds[j - 1] if j else 0, node_inds[j])
277 | for _ in range(num_inter_edges):
278 | g.add_edge(
279 | np.random.choice(i_nodes), np.random.choice(j_nodes)
280 | )
281 |
282 | nx.set_node_attributes(
283 | g, {i : np.ones(node_dim) for i in range(num_nodes)}, "feats"
284 | )
285 |
286 | return g
287 |
288 |
289 | def create_sbm_graph(
290 | node_dim, num_blocks_arr=np.arange(2, 6), block_size_arr=np.arange(20, 41),
291 | intra_block_edge_prob=0.3, inter_block_edge_prob=0.05
292 | ):
293 | """
294 | Creates a stochastic-block-model graph, where the number of blocks and size
295 | of blocks is drawn randomly.
296 | All nodes will be given a feature vector of all 1s.
297 | Arguments:
298 | `node_dim`: size of node feature vector
299 | `num_blocks_arr`: iterable containing possible numbers of blocks to have
300 | (selected uniformly)
301 | `block_size_arr`: iterable containing possible block sizes for each
302 | block (selected uniformly per block)
303 | `intra_block_edge_prob`: probability of edge within blocks
304 | `inter_block_edge_prob`: probability of edge between blocks
305 | Returns a NetworkX Graph with NumPy arrays as node attributes.
306 | """
307 | num_blocks = np.random.choice(num_blocks_arr)
308 | block_sizes = np.random.choice(block_size_arr, num_blocks, replace=True)
309 | num_nodes = np.sum(block_sizes)
310 |
311 | # Create matrix of edge probabilities between blocks
312 | p = np.full((len(block_sizes), len(block_sizes)), inter_block_edge_prob)
313 | np.fill_diagonal(p, intra_block_edge_prob)
314 |
315 | # Create SBM graph
316 | g = nx.stochastic_block_model(block_sizes, p)
317 |
318 | nx.set_node_attributes(
319 | g, {i : np.ones(node_dim) for i in range(num_nodes)}, "feats"
320 | )
321 |
322 | # Delete these two attributes, or else conversion to PyTorch Geometric Data
323 | # object will fail
324 | del g.graph["partition"]
325 | del g.graph["name"]
326 |
327 | return g
328 |
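For concreteness, the edge-probability matrix built above looks like this for
three blocks with the default probabilities:

    # p = [[0.30, 0.05, 0.05],
    #      [0.05, 0.30, 0.05],
    #      [0.05, 0.05, 0.30]]
    # intra_block_edge_prob on the diagonal, inter_block_edge_prob elsewhere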
329 |
330 | def create_molecule_like_graph(
331 | node_dim, num_nodes=10, backbone_size_range=(4, 6),
332 | ornament_size_range=(0, 4), ring_prob=0.5
333 | ):
334 | """
335 | Creates a random graph that looks a bit like a molecule. There are two types
336 | of nodes: backbone nodes (with node features of 0s) and ornamental nodes
337 | (with node features of 1s). Leftover nodes will be either backbone or
338 | ornamental nodes (selected uniformly at random). The backbone may be cyclic.
339 | Each backbone node can be connected to at most 4 other nodes (other backbone
340 | nodes or ornamental nodes).
341 | Arguments:
342 | `node_dim`: size of node feature vector
343 | `num_nodes`: number of nodes in the graph, or an array to sample from
344 | `backbone_size_range`: pair of minimum and maximum sizes of backbones to
345 | choose from
346 | `ornament_size_range`: pair of minimum and maximum number of ornaments
347 | to choose from; the actual number of ornaments added may be smaller
348 | if the backbone selected is too small
349 | `ring_prob`: probability of generating a ring for the backbone
350 | Returns a NetworkX Graph with NumPy arrays as node attributes.
351 | """
352 | if type(num_nodes) is not int:
353 |         assert np.min(num_nodes) >= \
354 | backbone_size_range[1] + ornament_size_range[1]
355 | num_nodes = np.random.choice(num_nodes)
356 | else:
357 | assert num_nodes >= backbone_size_range[1] + ornament_size_range[1]
358 |
359 | backbone_size = np.random.randint(
360 | backbone_size_range[0], backbone_size_range[1] + 1
361 | )
362 | ornament_size = np.random.randint(
363 | ornament_size_range[0],
364 | min(ornament_size_range[1], (backbone_size * 2) + 2) + 1
365 | )
366 | ring = np.random.random() <= ring_prob
367 |
368 | # Create backbone
369 | if ring:
370 | g = nx.circulant_graph(backbone_size, [1])
371 | else:
372 | g = nx.empty_graph()
373 | g.add_node(0)
374 | for _ in range(backbone_size - 1):
375 | # Add leaf node to randomly selected node with degree < 4
376 | degrees = g.degree()
377 | connector = np.random.choice(
378 | [i for i in range(len(g)) if degrees[i] < 4]
379 | )
380 | new_node = len(g)
381 | g.add_node(new_node)
382 | g.add_edge(new_node, connector)
383 |
384 | # Add ornaments
385 | capacities = [(n, 4 - m) for n, m in g.degree() if m < 4]
386 | inds, counts = zip(*capacities)
387 | node_arr = np.repeat(inds, counts)
388 | connectors = np.random.choice(node_arr, size=ornament_size, replace=False)
389 | for connector in connectors:
390 | new_node = len(g)
391 | g.add_node(new_node)
392 | g.add_edge(new_node, connector)
393 |
394 | # Add extra singleton nodes
395 | extra_nodes = num_nodes - len(g)
396 | extra_backbone_size = extra_nodes // 2
397 | extra_ornament_size = extra_nodes - extra_backbone_size
398 | for _ in range(extra_backbone_size):
399 | g.add_node(len(g))
400 | for _ in range(extra_ornament_size):
401 | g.add_node(len(g))
402 |
403 | node_features = np.concatenate([
404 | np.zeros((backbone_size, node_dim)),
405 | np.ones((ornament_size, node_dim)),
406 | np.zeros((extra_backbone_size, node_dim)),
407 | np.ones((extra_ornament_size, node_dim))
408 | ], axis=0)
409 |
410 | nx.set_node_attributes(
411 | g, {i : node_features[i] for i in range(len(g))}, "feats"
412 | )
413 | return g
414 |
415 |
416 | class RandomGraphDataset(torch.utils.data.Dataset):
417 | def __init__(
418 | self, node_dim, graph_type="tree", num_items=1000, static=False,
419 | **kwargs
420 | ):
421 | """
422 |         Create a map-style PyTorch Dataset which yields random graphs.
423 | Arguments:
424 | `node_dim`: size of node feature vector
425 |             `graph_type`: type of graph to generate; can be "tree",
426 |                 "uniform_cliques", "diverse_cliques", "degree", "planar",
427 |                 "community", "sbm", or "molecule_like"
428 | `num_items`: number of items to yield in an epoch
429 | `static`: if True, generate `num_items` graphs initially and only
430 | yield pregenerated graphs
431 | `kwargs`: extra keyword arguments for the graph generator
432 | """
433 | super().__init__()
434 |
435 | self.node_dim = node_dim
436 | self.num_items = num_items
437 | self.static = static
438 | self.kwargs = kwargs
439 |
440 | if graph_type == "tree":
441 | self.graph_creater = create_tree
442 | elif graph_type == "uniform_cliques":
443 | self.graph_creater = create_uniform_cliques
444 | elif graph_type == "diverse_cliques":
445 | self.graph_creater = create_diverse_cliques
446 | elif graph_type == "degree":
447 | self.graph_creater = create_degree_graph
448 | elif graph_type == "planar":
449 | self.graph_creater = create_planar_graph
450 | elif graph_type == "community":
451 | self.graph_creater = create_community_graph
452 | elif graph_type == "sbm":
453 | self.graph_creater = create_sbm_graph
454 | elif graph_type == "molecule_like":
455 | self.graph_creater = create_molecule_like_graph
456 | else:
457 |             raise ValueError("Unrecognized random graph type: %s" % graph_type)
458 |
459 | if static:
460 | self.graph_cache = [
461 | self.graph_creater(node_dim, **kwargs) for _ in range(num_items)
462 | ]
463 |
464 |
465 | def __getitem__(self, index):
466 | """
467 |         Returns a single data point as a torch-geometric Data object; if
468 |         `static` is True, returns the cached graph at `index`, else `index` is ignored.
469 | """
470 | if self.static:
471 | graph = self.graph_cache[index]
472 | else:
473 | graph = self.graph_creater(self.node_dim, **self.kwargs)
474 | data = torch_geometric.utils.from_networkx(
475 | graph, group_node_attrs=["feats"]
476 | )
477 | data.edge_index = torch_geometric.utils.sort_edge_index(data.edge_index)
478 | return data.to(DEVICE)
479 |
480 | def __len__(self):
481 | return self.num_items
482 |
483 |
484 | if __name__ == "__main__":
485 | node_dim = 5
486 |
487 | dataset = RandomGraphDataset(
488 | node_dim, num_items=100,
489 | graph_type="diverse_cliques", num_nodes=np.arange(10, 20),
490 | clique_sizes=[3, 4, 5, 6], noise_level=0
491 | )
492 | data_loader = torch_geometric.loader.DataLoader(
493 | dataset, batch_size=32, shuffle=False
494 | )
495 | batch = next(iter(data_loader))
496 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Genentech/GraphGUIDE/dad0dd371268684a5203839441febe0484d8a3e4/src/model/__init__.py
--------------------------------------------------------------------------------
/src/model/digress_gnn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn.modules.dropout import Dropout
5 | from torch.nn.modules.linear import Linear
6 | from torch.nn.modules.normalization import LayerNorm
7 | from torch.nn import functional as F
8 | from torch import Tensor
9 | import torch_geometric
10 |
11 | def assert_correctly_masked(variable, node_mask):
12 | assert (variable * (1 - node_mask.long())).abs().max().item() < 1e-4, \
13 | 'Variables not masked properly.'
14 |
15 | class PlaceHolder:
16 | def __init__(self, X, E, y):
17 | self.X = X
18 | self.E = E
19 | self.y = y
20 |
21 | def type_as(self, x: torch.Tensor):
22 | """ Changes the device and dtype of X, E, y. """
23 | self.X = self.X.type_as(x)
24 | self.E = self.E.type_as(x)
25 | self.y = self.y.type_as(x)
26 | return self
27 |
28 | def mask(self, node_mask, collapse=False):
29 | x_mask = node_mask.unsqueeze(-1) # bs, n, 1
30 | e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1
31 | e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1
32 |
33 | if collapse:
34 | self.X = torch.argmax(self.X, dim=-1)
35 | self.E = torch.argmax(self.E, dim=-1)
36 |
37 | self.X[node_mask == 0] = - 1
38 | self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
39 | else:
40 | self.X = self.X * x_mask
41 | self.E = self.E * e_mask1 * e_mask2
42 | assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
43 | return self
44 |
45 |
46 | def encode_no_edge(E):
47 | assert len(E.shape) == 4
48 | if E.shape[-1] == 0:
49 | return E
50 | no_edge = torch.sum(E, dim=3) == 0
51 | first_elt = E[:, :, :, 0]
52 | first_elt[no_edge] = 1
53 | E[:, :, :, 0] = first_elt
54 | diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
55 | E[diag] = 0
56 | return E
57 |
58 | def to_dense(x, edge_index, edge_attr, batch):
59 | X, node_mask = torch_geometric.utils.to_dense_batch(x=x, batch=batch)
60 | # node_mask = node_mask.float()
61 | edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
62 | # TODO: carefully check if setting node_mask as a bool breaks the continuous case
63 | max_num_nodes = X.size(1)
64 | if edge_index.numel() == 0:
65 | # We have to do this check otherwise things fail
66 | E = torch_geometric.utils.to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr)
67 | else:
68 | E = torch_geometric.utils.to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
69 | E = encode_no_edge(E)
70 |
71 | return PlaceHolder(X=X, E=E, y=None), node_mask
72 |
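A shape sketch of the dense conversion above (illustrative sizes: B graphs,
at most n nodes each, node feature width f, edge feature width de):

    # x:         (N_total, f)   flat PyG node features
    # batch:     (N_total,)     graph index of each node
    # X:         (B, n, f)      zero-padded dense node features
    # node_mask: (B, n)         True where a node exists
    # E:         (B, n, n, de)  dense edge features; encode_no_edge sets
    #                           channel 0 to 1 for non-edges and zeroes the
    #                           diagonal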
73 |
74 | class Xtoy(nn.Module):
75 | def __init__(self, dx, dy):
76 | """ Map node features to global features """
77 | super().__init__()
78 | self.lin = nn.Linear(4 * dx, dy)
79 |
80 | def forward(self, X):
81 | """ X: bs, n, dx. """
82 | m = X.mean(dim=1)
83 | mi = X.min(dim=1)[0]
84 | ma = X.max(dim=1)[0]
85 | std = X.std(dim=1)
86 | z = torch.hstack((m, mi, ma, std))
87 | out = self.lin(z)
88 | return out
89 |
90 |
91 | class Etoy(nn.Module):
92 | def __init__(self, d, dy):
93 | """ Map edge features to global features. """
94 | super().__init__()
95 | self.lin = nn.Linear(4 * d, dy)
96 |
97 | def forward(self, E):
98 | """ E: bs, n, n, de
99 | Features relative to the diagonal of E could potentially be added.
100 | """
101 | m = E.mean(dim=(1, 2))
102 | mi = E.min(dim=2)[0].min(dim=1)[0]
103 | ma = E.max(dim=2)[0].max(dim=1)[0]
104 | std = torch.std(E, dim=(1, 2))
105 | z = torch.hstack((m, mi, ma, std))
106 | out = self.lin(z)
107 | return out
108 |
109 | class XEyTransformerLayer(nn.Module):
110 |     """ Transformer layer that updates node, edge, and global features.
111 |         dx: node feature dimension
112 |         de: edge feature dimension
113 |         dy: global feature dimension
114 |         n_head: the number of heads in the multi-head attention
115 |         dim_ffX / dim_ffE / dim_ffy: dimensions of the feedforward networks after self-attention
116 |         dropout: dropout probability; 0 to disable
117 |         layer_norm_eps: eps value in layer normalizations.
118 |     """
119 | def __init__(self, dx: int, de: int, dy: int, n_head: int, dim_ffX: int = 2048,
120 | dim_ffE: int = 128, dim_ffy: int = 2048, dropout: float = 0.1,
121 | layer_norm_eps: float = 1e-5, device=None, dtype=None) -> None:
122 | kw = {'device': device, 'dtype': dtype}
123 | super().__init__()
124 |
125 | self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **kw)
126 |
127 | self.linX1 = Linear(dx, dim_ffX, **kw)
128 | self.linX2 = Linear(dim_ffX, dx, **kw)
129 | self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **kw)
130 | self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **kw)
131 | self.dropoutX1 = Dropout(dropout)
132 | self.dropoutX2 = Dropout(dropout)
133 | self.dropoutX3 = Dropout(dropout)
134 |
135 | self.linE1 = Linear(de, dim_ffE, **kw)
136 | self.linE2 = Linear(dim_ffE, de, **kw)
137 | self.normE1 = LayerNorm(de, eps=layer_norm_eps, **kw)
138 | self.normE2 = LayerNorm(de, eps=layer_norm_eps, **kw)
139 | self.dropoutE1 = Dropout(dropout)
140 | self.dropoutE2 = Dropout(dropout)
141 | self.dropoutE3 = Dropout(dropout)
142 |
143 | self.lin_y1 = Linear(dy, dim_ffy, **kw)
144 | self.lin_y2 = Linear(dim_ffy, dy, **kw)
145 | self.norm_y1 = LayerNorm(dy, eps=layer_norm_eps, **kw)
146 | self.norm_y2 = LayerNorm(dy, eps=layer_norm_eps, **kw)
147 | self.dropout_y1 = Dropout(dropout)
148 | self.dropout_y2 = Dropout(dropout)
149 | self.dropout_y3 = Dropout(dropout)
150 |
151 | self.activation = F.relu
152 |
153 | def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor):
154 | """ Pass the input through the encoder layer.
155 | X: (bs, n, d)
156 | E: (bs, n, n, d)
157 | y: (bs, dy)
158 | node_mask: (bs, n) Mask for the src keys per batch (optional)
159 | Output: newX, newE, new_y with the same shape.
160 | """
161 |
162 | newX, newE, new_y = self.self_attn(X, E, y, node_mask=node_mask)
163 |
164 | newX_d = self.dropoutX1(newX)
165 | X = self.normX1(X + newX_d)
166 |
167 | newE_d = self.dropoutE1(newE)
168 | E = self.normE1(E + newE_d)
169 |
170 | new_y_d = self.dropout_y1(new_y)
171 | y = self.norm_y1(y + new_y_d)
172 |
173 | ff_outputX = self.linX2(self.dropoutX2(self.activation(self.linX1(X))))
174 | ff_outputX = self.dropoutX3(ff_outputX)
175 | X = self.normX2(X + ff_outputX)
176 |
177 | ff_outputE = self.linE2(self.dropoutE2(self.activation(self.linE1(E))))
178 | ff_outputE = self.dropoutE3(ff_outputE)
179 | E = self.normE2(E + ff_outputE)
180 |
181 | ff_output_y = self.lin_y2(self.dropout_y2(self.activation(self.lin_y1(y))))
182 | ff_output_y = self.dropout_y3(ff_output_y)
183 | y = self.norm_y2(y + ff_output_y)
184 |
185 | return X, E, y
186 |
187 |
188 | class NodeEdgeBlock(nn.Module):
189 | """ Self attention layer that also updates the representations on the edges. """
190 | def __init__(self, dx, de, dy, n_head, **kwargs):
191 | super().__init__()
192 | assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
193 | self.dx = dx
194 | self.de = de
195 | self.dy = dy
196 | self.df = int(dx / n_head)
197 | self.n_head = n_head
198 |
199 | # Attention
200 | self.q = Linear(dx, dx)
201 | self.k = Linear(dx, dx)
202 | self.v = Linear(dx, dx)
203 |
204 | # FiLM E to X
205 | self.e_add = Linear(de, dx)
206 | self.e_mul = Linear(de, dx)
207 |
208 | # FiLM y to E
209 | self.y_e_mul = Linear(dy, dx) # Warning: here it's dx and not de
210 | self.y_e_add = Linear(dy, dx)
211 |
212 | # FiLM y to X
213 | self.y_x_mul = Linear(dy, dx)
214 | self.y_x_add = Linear(dy, dx)
215 |
216 | # Process y
217 | self.y_y = Linear(dy, dy)
218 | self.x_y = Xtoy(dx, dy)
219 | self.e_y = Etoy(de, dy)
220 |
221 | # Output layers
222 | self.x_out = Linear(dx, dx)
223 | self.e_out = Linear(dx, de)
224 | self.y_out = nn.Sequential(nn.Linear(dy, dy), nn.ReLU(), nn.Linear(dy, dy))
225 |
226 | def forward(self, X, E, y, node_mask):
227 | """
228 | :param X: bs, n, d node features
229 | :param E: bs, n, n, d edge features
230 |         :param y: bs, dy global features
231 | :param node_mask: bs, n
232 | :return: newX, newE, new_y with the same shape.
233 | """
234 | x_mask = node_mask.unsqueeze(-1) # bs, n, 1
235 | e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1
236 | e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1
237 |
238 | # 1. Map X to keys and queries
239 | Q = self.q(X) * x_mask # (bs, n, dx)
240 | K = self.k(X) * x_mask # (bs, n, dx)
241 | assert_correctly_masked(Q, x_mask)
242 | # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df
243 |
244 | Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df))
245 | K = K.reshape((K.size(0), K.size(1), self.n_head, self.df))
246 |
247 | Q = Q.unsqueeze(2) # (bs, 1, n, n_head, df)
248 |         K = K.unsqueeze(1)                              # (bs, n, 1, n_head, df)
249 |
250 | # Compute unnormalized attentions. Y is (bs, n, n, n_head, df)
251 | Y = Q * K
252 | Y = Y / math.sqrt(Y.size(-1))
253 | assert_correctly_masked(Y, (e_mask1 * e_mask2).unsqueeze(-1))
254 |
255 | E1 = self.e_mul(E) * e_mask1 * e_mask2 # bs, n, n, dx
256 | E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))
257 |
258 | E2 = self.e_add(E) * e_mask1 * e_mask2 # bs, n, n, dx
259 | E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))
260 |
261 | # Incorporate edge features to the self attention scores.
262 | Y = Y * (E1 + 1) + E2 # (bs, n, n, n_head, df)
263 |
264 | # Incorporate y to E
265 | newE = Y.flatten(start_dim=3) # bs, n, n, dx
266 | ye1 = self.y_e_add(y).unsqueeze(1).unsqueeze(1) # bs, 1, 1, de
267 | ye2 = self.y_e_mul(y).unsqueeze(1).unsqueeze(1)
268 | newE = ye1 + (ye2 + 1) * newE
269 |
270 | # Output E
271 | newE = self.e_out(newE) * e_mask1 * e_mask2 # bs, n, n, de
272 | assert_correctly_masked(newE, e_mask1 * e_mask2)
273 |
274 | # Compute attentions. attn is still (bs, n, n, n_head, df)
275 | attn = F.softmax(Y, dim=2)
276 |
277 | V = self.v(X) * x_mask # bs, n, dx
278 | V = V.reshape((V.size(0), V.size(1), self.n_head, self.df))
279 | V = V.unsqueeze(1) # (bs, 1, n, n_head, df)
280 |
281 | # Compute weighted values
282 | weighted_V = attn * V
283 | weighted_V = weighted_V.sum(dim=2)
284 |
285 | # Send output to input dim
286 | weighted_V = weighted_V.flatten(start_dim=2) # bs, n, dx
287 |
288 | # Incorporate y to X
289 | yx1 = self.y_x_add(y).unsqueeze(1)
290 | yx2 = self.y_x_mul(y).unsqueeze(1)
291 | newX = yx1 + (yx2 + 1) * weighted_V
292 |
293 | # Output X
294 | newX = self.x_out(newX) * x_mask
295 | assert_correctly_masked(newX, x_mask)
296 |
297 |         # Process y based on X and E
298 | y = self.y_y(y)
299 | e_y = self.e_y(E)
300 | x_y = self.x_y(X)
301 | new_y = y + x_y + e_y
302 | new_y = self.y_out(new_y) # bs, dy
303 |
304 | return newX, newE, new_y
305 |
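The conditioning used throughout the block above is FiLM-style modulation: a
learned multiplicative term and a learned additive term, arranged so that
zero-valued projections reduce to the identity map:

    # out = add + (mul + 1) * h
    # Y    = Y * (E1 + 1) + E2              edges modulate attention scores
    # newE = ye1 + (ye2 + 1) * newE         globals modulate edge features
    # newX = yx1 + (yx2 + 1) * weighted_V   globals modulate node features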
306 |
307 | class GraphTransformer(nn.Module):
308 | """
309 | n_layers : int -- number of layers
310 | dims : dict -- contains dimensions for each feature type
311 | """
312 | def __init__(self, n_layers: int, input_dims: dict, hidden_mlp_dims: dict, hidden_dims: dict,
313 |                  output_dims: dict, act_fn_in: nn.Module = nn.ReLU(), act_fn_out: nn.Module = nn.ReLU()):
314 | super().__init__()
315 | self.n_layers = n_layers
316 | self.out_dim_X = output_dims['X']
317 | self.out_dim_E = output_dims['E']
318 | self.out_dim_y = output_dims['y']
319 |
320 | self.mlp_in_X = nn.Sequential(nn.Linear(input_dims['X'], hidden_mlp_dims['X']), act_fn_in,
321 | nn.Linear(hidden_mlp_dims['X'], hidden_dims['dx']), act_fn_in)
322 |
323 | self.mlp_in_E = nn.Sequential(nn.Linear(input_dims['E'], hidden_mlp_dims['E']), act_fn_in,
324 | nn.Linear(hidden_mlp_dims['E'], hidden_dims['de']), act_fn_in)
325 |
326 | self.mlp_in_y = nn.Sequential(nn.Linear(input_dims['y'], hidden_mlp_dims['y']), act_fn_in,
327 | nn.Linear(hidden_mlp_dims['y'], hidden_dims['dy']), act_fn_in)
328 |
329 | self.tf_layers = nn.ModuleList([XEyTransformerLayer(dx=hidden_dims['dx'],
330 | de=hidden_dims['de'],
331 | dy=hidden_dims['dy'],
332 | n_head=hidden_dims['n_head'],
333 | dim_ffX=hidden_dims['dim_ffX'],
334 | dim_ffE=hidden_dims['dim_ffE'])
335 | for i in range(n_layers)])
336 |
337 | self.mlp_out_X = nn.Sequential(nn.Linear(hidden_dims['dx'], hidden_mlp_dims['X']), act_fn_out,
338 | nn.Linear(hidden_mlp_dims['X'], output_dims['X']))
339 |
340 | # Note: we change the activation function here to sigmoid!
341 | self.mlp_out_E = nn.Sequential(nn.Linear(hidden_dims['de'], hidden_mlp_dims['E']), nn.Sigmoid(),
342 | nn.Linear(hidden_mlp_dims['E'], output_dims['E']))
343 |
344 | self.mlp_out_y = nn.Sequential(nn.Linear(hidden_dims['dy'], hidden_mlp_dims['y']), act_fn_out,
345 | nn.Linear(hidden_mlp_dims['y'], output_dims['y']))
346 |
347 | def forward(self, X, E, y, node_mask):
348 | bs, n = X.shape[0], X.shape[1]
349 |
350 | diag_mask = torch.eye(n)
351 | diag_mask = ~diag_mask.type_as(E).bool()
352 | diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)
353 |
354 | X_to_out = X[..., :self.out_dim_X]
355 | E_to_out = E[..., :self.out_dim_E]
356 | y_to_out = y[..., :self.out_dim_y]
357 |
358 | new_E = self.mlp_in_E(E)
359 | new_E = (new_E + new_E.transpose(1, 2)) / 2
360 | after_in = PlaceHolder(X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)).mask(node_mask)
361 | X, E, y = after_in.X, after_in.E, after_in.y
362 |
363 | for layer in self.tf_layers:
364 | X, E, y = layer(X, E, y, node_mask)
365 |
366 | X = self.mlp_out_X(X)
367 | E = self.mlp_out_E(E)
368 | y = self.mlp_out_y(y)
369 |
370 | X = (X + X_to_out)
371 | E = (E + E_to_out) * diag_mask
372 | y = y + y_to_out
373 |
374 | E = 1/2 * (E + torch.transpose(E, 1, 2))
375 |
376 | return PlaceHolder(X=X, E=E, y=y).mask(node_mask)
377 |
378 |
379 | class DiGressGNN(nn.Module):
380 |
381 | def __init__(self, input_dim, t_limit):
382 | super().__init__()
383 | self.creation_args = {}
384 | self.t_limit = t_limit
385 | self.model = GraphTransformer(
386 | 5,
387 | {'X': input_dim, 'E': 1, 'y': 1},
388 | {'X': 256, 'E': 128, 'y': 128},
389 | {'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128},
390 | {'X': input_dim, 'E': 1, 'y': 1},
391 | nn.ReLU(), nn.ReLU()
392 | )
393 | self.sigmoid = torch.nn.Sigmoid()
394 | self.bce_loss = torch.nn.BCELoss()
395 |
396 | def forward(self, data, t):
397 | # Add extra attribute
398 | data.edge_attr = torch.ones(data.edge_index.shape[1], 1, device=data.x.device)
399 |
400 | # Convert to proper format
401 | dense_data, node_mask = to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
402 | dense_data = dense_data.mask(node_mask)
403 | X, E = dense_data.X.float(), dense_data.E.float()
404 |
405 | # Encode time as y
406 | y = (t[data.ptr[:-1]] / self.t_limit)[:, None]
407 |
408 | # Extract out edge predictions
409 | pred_object = self.model(X, E, y, node_mask)
410 | edge_preds = pred_object.E[:, :, :, 0] # Shape: B x V_max x V_max
411 |
412 | # Convert edge predictions into canonical tensor; do it the same way as
413 | # pyg_data_to_edge_vector
414 | graph_sizes = torch.diff(data.ptr)
415 |
416 | if torch.max(data.batch) >= len(graph_sizes):
417 | edge_preds = edge_preds[:len(graph_sizes)]
418 |
419 |         # Create boolean mask of only the upper triangle of each B x V x V
420 | # matrix, ignoring the diagonal
421 | triu_mask = torch.triu(torch.ones_like(edge_preds), diagonal=1) == 1
422 | for i, size in enumerate(torch.diff(data.ptr)):
423 | # For each individual graph, limit the upper-triangle mask to
424 | # only the size of that graph
425 | triu_mask[i, :, size:] = False
426 |
427 | # Edge vector is all entries in the graph-specific upper triangle of each
428 | # individual adjacency matrix
429 | edge_vec = edge_preds[triu_mask]
430 |
431 | # Finally, pass through a sigmoid
432 | return self.sigmoid(edge_vec)
433 |
434 | def loss(self, pred_probs, true_probs):
435 | return self.bce_loss(pred_probs, true_probs)
436 |
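A sketch of the edge-vector ordering produced above (illustrative 3-node
graph): boolean indexing with the upper-triangle mask returns entries in
row-major order, so one graph with V = 3 contributes predictions for the
node pairs

    # (0, 1), (0, 2), (1, 2)
    # and a batch concatenates these per-graph vectors in graph order, which
    # is the ordering that the pyg_data_to_edge_vector helper mentioned in
    # the comment above is expected to use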
437 | if __name__ == "__main__":
438 | import networkx as nx
439 |
440 | if torch.cuda.is_available():
441 | DEVICE = "cuda"
442 | else:
443 | DEVICE = "cpu"
444 |
445 | batch_size = 32
446 | num_nodes = 10
447 | node_dim = 5
448 | t_limit = 1
449 |
450 | # Prepare a data and time object just as the model would see
451 | edge_index = []
452 | for i in range(batch_size):
453 | edges = torch.tensor(list(nx.erdos_renyi_graph(num_nodes, 0.5).edges)) + (i * num_nodes)
454 | edge_index.append(edges)
455 | edge_index.append(torch.flip(edges, (1,)))
456 | edge_index = torch.concat(edge_index, dim=0)
457 | batch = torch.repeat_interleave(torch.arange(batch_size), num_nodes, 0)
458 | ptr = torch.concat([
459 | torch.zeros(1, device=batch.device),
460 | torch.where(torch.diff(batch) != 0)[0] + 1,
461 | torch.tensor([len(batch)], device=batch.device)
462 | ]).long()
463 |
464 | data = torch_geometric.data.Data(
465 | x=torch.rand((batch_size * num_nodes, node_dim)),
466 | edge_index=edge_index,
467 | batch=batch,
468 | ptr=ptr
469 | ).to(DEVICE)
470 | t = torch.rand(len(data.x)).to(DEVICE)
471 |
472 | # Run the model
473 | model = DiGressGNN(node_dim, t_limit).to(DEVICE)
474 | edge_pred = model(data, t)
475 |
--------------------------------------------------------------------------------
/src/model/discrete_diffusers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # Define device
4 | if torch.cuda.is_available():
5 | DEVICE = "cuda"
6 | else:
7 | DEVICE = "cpu"
8 |
9 |
10 | class DiscreteDiffuser:
11 | # Base class for discrete diffusers
12 | def __init__(self, input_shape, seed=None):
13 | """
14 | Arguments:
15 | `input_shape`: a tuple of ints which is the shape of input tensors
16 | x; does not include batch dimension
17 | `seed`: random seed for sampling and running the diffusion process
18 | """
19 | self.input_shape = input_shape
20 | self.rng = torch.Generator(device=DEVICE)
21 | if seed:
22 | self.rng.manual_seed(seed)
23 |
24 | def _inflate_dims(self, v):
25 | """
26 | Given a tensor vector `v`, appends dimensions of size 1 so that it has
27 | the same number of dimensions as `self.input_shape`. For example, if
28 | `self.input_shape` is (3, 50, 50), then this function turns `v` from a
29 | B-tensor to a B x 1 x 1 x 1 tensor. This is useful for combining the
30 | tensor with things shaped like the input later.
31 | Arguments:
32 | `v`: a B-tensor
33 |         Returns a B x 1 x ... x 1 tensor with `len(self.input_shape)` singleton dims.
34 | """
35 | return v[(slice(None),) + ((None,) * len(self.input_shape))]
36 |
37 | def forward(self, x0, t, return_posterior=True):
38 | """
39 | Runs diffusion process forward given starting point `x0` and a time `t`.
40 | Optionally also returns a tensor which represents the posterior of
41 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
42 | Arguments:
43 | `x0`: a B x `self.input_shape` tensor containing the data at some
44 | time points
45 | `t`: a B-tensor containing the time in the diffusion process for
46 | each input
47 | Returns a B x `self.input_shape` tensor to represent xt. If
48 | `return_posterior` is True, then also returns a B x `self.input_shape`
49 | tensor which is a parameter of the posterior.
50 | """
51 | if return_posterior:
52 | return torch.zeros_like(x0), torch.zeros_like(x0)
53 | else:
54 | return torch.zeros_like(x0)
55 |
56 | def reverse_step(self, xt, t, post):
57 | """
58 | Performs a reverse sampling step to compute x_{t-1} given xt and the
59 | posterior quantity (or an estimate of it) defined in `posterior`.
60 | Arguments:
61 | `xt`: a B x `self.input_shape` tensor containing the data at time t
62 | `t`: a B-tensor containing the time in the diffusion process for
63 | each input
64 | `post`: a B x `self.input_shape` tensor containing the posterior
65 | quantity (or a model-predicted estimate of it) as defined in
66 | `posterior`
67 | Returns a B x `self.input_shape` tensor for x_{t-1}.
68 | """
69 |         return torch.zeros_like(xt)  # Placeholder; subclasses implement the actual step
70 |
71 | def sample_prior(self, num_samples, t):
72 | """
73 | Samples from the prior distribution specified by the diffusion process
74 | at time `t`.
75 | Arguments:
76 | `num_samples`: B, the number of samples to return
77 | `t`: a B-tensor containing the time in the diffusion process for
78 | each input
79 | Returns a B x `self.input_shape` tensor for the `xt` values that are
80 | sampled.
81 | """
82 | return torch.zeros(torch.Size([num_samples] + list(self.input_shape)))
83 |
84 | def __str__(self):
85 | return "Base Discrete Diffuser"
86 |
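For example (repeating the docstring's case), with input_shape = (3, 50, 50)
a B-tensor of times becomes a broadcastable B x 1 x 1 x 1 tensor:

    # v.shape == (B,)
    # diffuser._inflate_dims(v).shape == (B, 1, 1, 1)
    # equivalent to v[:, None, None, None]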
87 |
88 | class GaussianDiffuser(DiscreteDiffuser):
89 | # Diffuser which adds Gaussian noise to the inputs
90 | def __init__(self, beta_1, delta_beta, input_shape, seed=None):
91 | """
92 | Arguments:
93 | `beta_1`: beta(1), the first value of beta at t = 1
94 | `delta_beta`: beta(t) will be linear with this slope
95 | `input_shape`: a tuple of ints which is the shape of input tensors
96 | x; does not include batch dimension
97 | `seed`: random seed for sampling and running the diffusion process
98 | """
99 | super().__init__(input_shape, seed)
100 |
101 | self.beta_1 = torch.tensor(beta_1).to(DEVICE)
102 | self.delta_beta = torch.tensor(delta_beta).to(DEVICE)
103 | self.string = "Gaussian Diffuser (beta(t) = %.2f + %.2ft)" % (
104 | beta_1, delta_beta
105 | )
106 |
107 | def _beta(self, t):
108 | """
109 | Computes beta(t).
110 | Arguments:
111 | `t`: a B-tensor of times
112 | Returns a B-tensor of beta values.
113 | """
114 | # Subtract 1 from t: when t = 1, beta(t) = beta_1
115 | return self.beta_1 + (self.delta_beta * (t - 1))
116 |
117 | def _alpha(self, t):
118 | """
119 | Computes alpha(t).
120 | Arguments:
121 | `t`: a B-tensor of times
122 | Returns a B-tensor of alpha values.
123 | """
124 | return 1 - self._beta(t)
125 |
126 | def _alpha_bar(self, t):
127 | """
128 | Computes alpha-bar(t).
129 | Arguments:
130 | `t`: a B-tensor of times
131 | Returns a B-tensor of alpha-bar values.
132 | """
133 | max_t = torch.max(t)
134 | t_range = torch.arange(max_t.int() + 1, device=DEVICE)
135 | alphas = self._alpha(t_range)
136 | alphas_prod = torch.cumprod(alphas, dim=0)
137 | return alphas_prod[t.long()]
138 |
139 | def forward(self, x0, t, return_posterior=True):
140 | """
141 | Runs diffusion process forward given starting point `x0` and a time `t`.
142 | Optionally also returns a tensor which represents the posterior of
143 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
144 | Arguments:
145 | `x0`: a B x `self.input_shape` tensor containing the data at some
146 | time points
147 | `t`: a B-tensor containing the time in the diffusion process for
148 | each input
149 | Returns a B x `self.input_shape` tensor to represent xt. If
150 | `return_posterior` is True, then also returns a B x `self.input_shape`
151 | tensor which is a parameter of the posterior.
152 | """
153 | z = torch.normal(
154 | torch.zeros_like(x0), torch.ones_like(x0), generator=self.rng
155 | ) # Shape: B x ...
156 |
157 | alpha_bar = self._inflate_dims(self._alpha_bar(t))
158 | xt = (torch.sqrt(alpha_bar) * x0) + (torch.sqrt(1 - alpha_bar) * z)
159 |
160 | if return_posterior:
161 | return xt, z
162 | return xt
163 |
164 | def reverse_step(self, xt, t, post):
165 | """
166 | Performs a reverse sampling step to compute x_{t-1} given xt and the
167 | posterior quantity (or an estimate of it) defined in `posterior`.
168 | Arguments:
169 | `xt`: a B x `self.input_shape` tensor containing the data at time t
170 | `t`: a B-tensor containing the time in the diffusion process for
171 | each input
172 | `post`: a B x `self.input_shape` tensor containing the posterior
173 | quantity (or a model-predicted estimate of it) as defined in
174 | `posterior`
175 | Returns a B x `self.input_shape` tensor for x_{t-1}.
176 | """
177 | beta = self._beta(t)
178 | alpha = self._alpha(t)
179 | alpha_bar = self._alpha_bar(t)
180 | alpha_bar_1 = self._alpha_bar(t - 1)
181 | beta_tilde = beta * (1 - alpha_bar_1) / (1 - alpha_bar)
182 |
183 | z = torch.normal(
184 | torch.zeros_like(xt), torch.ones_like(xt), generator=self.rng
185 | )
186 |
187 | d = xt - (post * self._inflate_dims(beta / torch.sqrt(1 - alpha_bar)))
188 | d = d / self._inflate_dims(torch.sqrt(alpha))
189 | std = torch.sqrt(beta_tilde)
190 | std[t == 1] = 0 # No noise for the last step
191 | return d + (z * self._inflate_dims(std))
192 |
193 | def sample_prior(self, num_samples, t):
194 | """
195 | Samples from the prior distribution specified by the diffusion process
196 | at time `t`.
197 | Arguments:
198 | `num_samples`: B, the number of samples to return
199 | `t`: a B-tensor containing the time in the diffusion process for
200 | each input
201 | Returns a B x `self.input_shape` tensor for the `xt` values that are
202 | sampled.
203 | """
204 | # Ignore t
205 | size = torch.Size([num_samples] + list(self.input_shape))
206 | return torch.normal(
207 | torch.zeros(size, device=DEVICE), torch.ones(size, device=DEVICE),
208 | generator=self.rng
209 | )
210 |
211 | def __str__(self):
212 | return self.string
213 |
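For reference, the forward and reverse updates implemented above are the
standard DDPM equations (Ho et al., 2020), with the posterior quantity being
the noise z:

    # forward:  x_t = sqrt(alpha_bar(t)) * x_0 + sqrt(1 - alpha_bar(t)) * z
    # reverse:  x_{t-1} = (x_t - beta(t) / sqrt(1 - alpha_bar(t)) * z_hat)
    #                     / sqrt(alpha(t))
    #                     + sqrt(beta_tilde(t)) * z'
    # where beta_tilde(t) = beta(t) * (1 - alpha_bar(t-1)) / (1 - alpha_bar(t))
    # and the added noise is zeroed at t = 1, the final reverse step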
214 |
215 | class BernoulliDiffuser(DiscreteDiffuser):
216 | # Diffuser which flips bits in the inputs
217 | def __init__(self, a, b, input_shape, seed=None):
218 | """
219 | Arguments:
220 | `a`: stretch factor of logistic beta function
221 | `b`: shift factor of logistic beta function
222 | `input_shape`: a tuple of ints which is the shape of input tensors
223 | x; does not include batch dimension
224 | `seed`: random seed for sampling and running the diffusion process
225 | """
226 | super().__init__(input_shape, seed)
227 |
228 | self.a = torch.tensor(a).to(DEVICE)
229 | self.b = torch.tensor(b).to(DEVICE)
230 |         self.string = "Bernoulli Diffuser (beta(t) = s((t / %.2f) - %.2f))" % (
231 | a, b
232 | )
233 |
234 | epsilon = 1e-6 # For numerical stability
235 | self.half = torch.tensor(0.5, device=DEVICE)
236 |         self.half_eps = self.half - epsilon
237 | self.log2 = torch.log(torch.tensor(2, device=DEVICE))
238 |
239 | def _beta(self, t):
240 | """
241 | Computes beta(t), which is the probability of a flip at time `t`.
242 | Arguments:
243 | `t`: a B-tensor of times
244 | Returns a B-tensor of beta values.
245 | """
246 |         # Logistic schedule, capped just below 1/2 so flips never become a fair coin
247 | return torch.minimum(
248 | torch.sigmoid((t / self.a) - self.b), self.half_eps
249 | )
250 |
251 | def _beta_bar(self, t):
252 | """
253 | Computes beta-bar(t), which is the probability of a flip from time 0 to
254 | time `t` (flip is relative to time 0).
255 | Arguments:
256 | `t`: a B-tensor of times
257 | Returns a B-tensor of beta-bar values.
258 | """
259 | max_range = torch.arange(1, torch.max(t) + 1, device=DEVICE)
260 | betas = self._beta(max_range) # Shape: maxT
261 |
262 | betas_tiled = torch.tile(betas, (t.shape[0], 1)) # Shape: B x maxT
263 | biases = self.half - betas_tiled
264 |
265 | # Anything that ran over a time t, set to 1
266 | mask = max_range[None] > t[:, None]
267 | biases[mask] = 1
268 |
269 | # Do the product in log-space and transform back for stability
270 | coef = (t - 1) * self.log2
271 | log_prod = torch.sum(torch.log(biases), dim=1)
272 | prod = torch.exp(coef + log_prod)
273 |
274 | return self.half - prod
275 |
276 | def forward(self, x0, t, return_posterior=True):
277 | """
278 | Runs diffusion process forward given starting point `x0` and a time `t`.
279 | Optionally also returns a tensor which represents the posterior of
280 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
281 | Arguments:
282 | `x0`: a B x `self.input_shape` tensor containing the data at some
283 | time points
284 | `t`: a B-tensor containing the time in the diffusion process for
285 | each input
286 | Returns a B x `self.input_shape` tensor to represent xt. If
287 | `return_posterior` is True, then also returns a B x `self.input_shape`
288 | tensor which is a parameter of the posterior.
289 | """
290 | beta_t = self._inflate_dims(self._beta(t))
291 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
292 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1))
293 |
294 | prob_flip = torch.tile(beta_bar_t, (1, *x0.shape[1:]))
295 | indicators = torch.bernoulli(prob_flip)
296 |
297 | # Perform flips
298 | xt = x0.clone()
299 | mask = indicators == 1
300 | xt[mask] = 1 - x0[mask]
301 |
302 | if return_posterior:
303 | term_1 = ((1 - xt) * beta_t) + (xt * (1 - beta_t))
304 | term_2 = ((1 - x0) * beta_bar_t_1) + (x0 * (1 - beta_bar_t_1))
305 | x0_xor_xt = torch.square(x0 - xt)
306 | term_3 = (x0_xor_xt * beta_bar_t) + \
307 | ((1 - x0_xor_xt) * (1 - beta_bar_t))
308 |
309 | post = term_1 * term_2 / term_3
310 | # Due to small numerical instabilities (particularly at t = 0), clip
311 | # the probabilities
312 | post = torch.clamp(post, 0, 1)
313 | return xt, post
314 | return xt
315 |
316 | def reverse_step(self, xt, t, post):
317 | """
318 | Performs a reverse sampling step to compute x_{t-1} given xt and the
319 | posterior quantity (or an estimate of it) defined in `posterior`.
320 | Arguments:
321 | `xt`: a B x `self.input_shape` tensor containing the data at time t
322 | `t`: a B-tensor containing the time in the diffusion process for
323 | each input
324 | `post`: a B x `self.input_shape` tensor containing the posterior
325 | quantity (or a model-predicted estimate of it) as defined in
326 | `posterior`
327 | Returns a B x `self.input_shape` tensor for x_{t-1}.
328 | """
329 | # Ignore xt and t
330 | return torch.bernoulli(post)
331 |
332 | def sample_prior(self, num_samples, t):
333 | """
334 | Samples from the prior distribution specified by the diffusion process
335 | at time `t`.
336 | Arguments:
337 | `num_samples`: B, the number of samples to return
338 | `t`: a B-tensor containing the time in the diffusion process for
339 | each input
340 | Returns a B x `self.input_shape` tensor for the `xt` values that are
341 | sampled.
342 | """
343 | # Ignore t
344 | size = torch.Size([num_samples] + list(self.input_shape))
345 | probs = torch.tile(self.half, size)
346 | return torch.bernoulli(probs)
347 |
348 | def __str__(self):
349 | return self.string
350 |
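The closed form computed in _beta_bar above follows from the flip recurrence:
writing q(t) = 1/2 - beta_bar(t),

    # beta_bar(t) = beta_bar(t-1) * (1 - beta(t))
    #               + (1 - beta_bar(t-1)) * beta(t)
    # => q(t) = q(t-1) * (1 - 2 * beta(t)),  with q(0) = 1/2
    # => beta_bar(t) = 1/2 - 2^(t-1) * prod_{s=1..t} (1/2 - beta(s))
    # which is exactly the log-space product and 2^(t-1) factor in the code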
351 |
352 | class BernoulliOneDiffuser(DiscreteDiffuser):
353 | # Diffuser which sets bits to 1 in the inputs
354 | def __init__(self, a, b, input_shape, seed=None, reverse_reflect=False):
355 | """
356 | Arguments:
357 | `a`: stretch factor of logistic beta function
358 | `b`: shift factor of logistic beta function
359 | `input_shape`: a tuple of ints which is the shape of input tensors
360 | x; does not include batch dimension
361 | `seed`: random seed for sampling and running the diffusion process
362 | `reverse_reflect`: if True, enforce that the reverse-diffusion
363 | process only sets entries to 0; if False, the reverse-diffusion
364 | process samples from the Bernoulli posterior as usual
365 | """
366 | super().__init__(input_shape, seed)
367 |
368 | self.a = torch.tensor(a).to(DEVICE)
369 | self.b = torch.tensor(b).to(DEVICE)
370 |         self.string = "Bernoulli One Diffuser (beta(t) = s((t / %.2f) - %.2f))" \
371 | % (a, b)
372 | self.reverse_reflect = reverse_reflect
373 |
374 | def _beta(self, t):
375 | """
376 | Computes beta(t), which is the probability of a flip to 1 at time `t`.
377 | Arguments:
378 | `t`: a B-tensor of times
379 | Returns a B-tensor of beta values.
380 | """
381 |         # Logistic schedule: beta(t) = s((t / a) - b)
382 | return torch.sigmoid((t / self.a) - self.b)
383 |
384 | def _beta_bar(self, t):
385 | """
386 | Computes beta-bar(t), which is the probability of a flip to 1 from time
387 | 0 to time `t` (flip is assuming the bit at time 0 is 0).
388 | Arguments:
389 | `t`: a B-tensor of times
390 | Returns a B-tensor of beta-bar values.
391 | """
392 | max_range = torch.arange(1, torch.max(t) + 1, device=DEVICE)
393 | betas = self._beta(max_range) # Shape: maxT
394 |
395 | betas_tiled = torch.tile(betas, (t.shape[0], 1)) # Shape: B x maxT
396 | comps = 1 - betas_tiled
397 |
398 | # Anything that ran over a time t, set to 1
399 | mask = max_range[None] > t[:, None]
400 | comps[mask] = 1
401 |
402 | # Do the product in log-space and transform back for stability
403 | log_prod = torch.sum(torch.log(comps), dim=1)
404 | prod = torch.exp(log_prod)
405 |
406 | return 1 - prod
407 |
408 | def forward(self, x0, t, return_posterior=True):
409 | """
410 | Runs diffusion process forward given starting point `x0` and a time `t`.
411 | Optionally also returns a tensor which represents the posterior of
412 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
413 | Arguments:
414 | `x0`: a B x `self.input_shape` tensor containing the data at some
415 | time points
416 | `t`: a B-tensor containing the time in the diffusion process for
417 | each input
418 | Returns a B x `self.input_shape` tensor to represent xt. If
419 | `return_posterior` is True, then also returns a B x `self.input_shape`
420 | tensor which is a parameter of the posterior.
421 | """
422 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
423 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1))
424 |
425 | prob_one = x0 + ((1 - x0) * beta_bar_t)
426 | indicators = torch.bernoulli(prob_one)
427 |
428 | # Perform sampling
429 | xt = x0.clone()
430 | mask = indicators == 1
431 | xt[mask] = 1 # Set to 1
432 |
433 | if return_posterior:
434 | term_1 = xt
435 | term_2 = x0 + ((1 - x0) * beta_bar_t_1)
436 | term_3 = (x0 * xt) + ((1 - x0) * xt * beta_bar_t) + \
437 | ((1 - x0) * (1 - xt) * (1 - beta_bar_t))
438 |
439 | post = term_1 * term_2 / term_3
440 | # Due to small numerical instabilities (particularly at t = 0), clip
441 | # the probabilities
442 | post = torch.clamp(post, 0, 1)
443 | return xt, post
444 | return xt
445 |
446 | def reverse_step(self, xt, t, post):
447 | """
448 | Performs a reverse sampling step to compute x_{t-1} given xt and the
449 | posterior quantity (or an estimate of it) defined in `posterior`.
450 | Arguments:
451 | `xt`: a B x `self.input_shape` tensor containing the data at time t
452 | `t`: a B-tensor containing the time in the diffusion process for
453 | each input
454 | `post`: a B x `self.input_shape` tensor containing the posterior
455 | quantity (or a model-predicted estimate of it) as defined in
456 | `posterior`
457 | Returns a B x `self.input_shape` tensor for x_{t-1}.
458 | """
459 | # Ignore xt and t
460 | if self.reverse_reflect:
461 | xt_1 = xt.clone()
462 | indicators = torch.bernoulli(post)
463 | xt_1[indicators == 0] = 0 # Only set to 0 when a 0 is sampled
464 | return xt_1
465 | return torch.bernoulli(post)
466 |
467 | def sample_prior(self, num_samples, t):
468 | """
469 | Samples from the prior distribution specified by the diffusion process
470 | at time `t`.
471 | Arguments:
472 | `num_samples`: B, the number of samples to return
473 | `t`: a B-tensor containing the time in the diffusion process for
474 | each input
475 | Returns a B x `self.input_shape` tensor for the `xt` values that are
476 | sampled.
477 | """
478 | # Ignore t
479 | size = torch.Size([num_samples] + list(self.input_shape))
480 | return torch.ones(size, device=DEVICE)
481 |
482 | def __str__(self):
483 | return self.string
484 |
485 |
486 | class BernoulliZeroDiffuser(DiscreteDiffuser):
487 | # Diffuser which sets bits to 0 in the inputs
488 | def __init__(self, a, b, input_shape, seed=None, reverse_reflect=False):
489 | """
490 | Arguments:
491 | `a`: stretch factor of logistic beta function
492 | `b`: shift factor of logistic beta function
493 | `input_shape`: a tuple of ints which is the shape of input tensors
494 | x; does not include batch dimension
495 | `seed`: random seed for sampling and running the diffusion process
496 | `reverse_reflect`: if True, enforce that the reverse-diffusion
497 | process only sets entries to 1; if False, the reverse-diffusion
498 | process samples from the Bernoulli posterior as usual
499 | """
500 | super().__init__(input_shape, seed)
501 |
502 | self.a = torch.tensor(a).to(DEVICE)
503 | self.b = torch.tensor(b).to(DEVICE)
504 |         self.string = "Bernoulli Zero Diffuser (beta(t) = s((t / %.2f) - %.2f))" \
505 | % (a, b)
506 | self.reverse_reflect = reverse_reflect
507 |
508 | def _beta(self, t):
509 | """
510 | Computes beta(t), which is the probability of a flip to 0 at time `t`.
511 | Arguments:
512 | `t`: a B-tensor of times
513 | Returns a B-tensor of beta values.
514 | """
515 |         # Logistic schedule: beta(t) = s((t / a) - b)
516 | return torch.sigmoid((t / self.a) - self.b)
517 |
518 | def _beta_bar(self, t):
519 | """
520 | Computes beta-bar(t), which is the probability of a flip to 0 from time
521 | 0 to time `t` (flip is assuming the bit at time 0 is 1).
522 | Arguments:
523 | `t`: a B-tensor of times
524 | Returns a B-tensor of beta-bar values.
525 | """
526 | max_range = torch.arange(1, torch.max(t) + 1, device=DEVICE)
527 | betas = self._beta(max_range) # Shape: maxT
528 |
529 | betas_tiled = torch.tile(betas, (t.shape[0], 1)) # Shape: B x maxT
530 | comps = 1 - betas_tiled
531 |
532 | # Anything that ran over a time t, set to 1
533 | mask = max_range[None] > t[:, None]
534 | comps[mask] = 1
535 |
536 | # Do the product in log-space and transform back for stability
537 | log_prod = torch.sum(torch.log(comps), dim=1)
538 | prod = torch.exp(log_prod)
539 |
540 | return 1 - prod
541 |
542 | def forward(self, x0, t, return_posterior=True):
543 | """
544 | Runs diffusion process forward given starting point `x0` and a time `t`.
545 | Optionally also returns a tensor which represents the posterior of
546 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
547 | Arguments:
548 | `x0`: a B x `self.input_shape` tensor containing the data at some
549 | time points
550 | `t`: a B-tensor containing the time in the diffusion process for
551 | each input
552 | Returns a B x `self.input_shape` tensor to represent xt. If
553 | `return_posterior` is True, then also returns a B x `self.input_shape`
554 | tensor which is a parameter of the posterior.
555 | """
556 | beta_t = self._inflate_dims(self._beta(t))
557 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
558 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1))
559 |
560 | prob_one = x0 * (1 - beta_bar_t)
561 | indicators = torch.bernoulli(prob_one)
562 |
563 | # Perform sampling
564 | xt = x0.clone()
565 | mask = indicators == 0
566 | xt[mask] = 0 # Set to 0
567 |
568 | if return_posterior:
569 | term_1 = xt + beta_t - (2 * xt * beta_t)
570 | term_2 = x0 * (1 - beta_bar_t_1)
571 | term_3 = ((1 - x0) * (1 - xt)) + (x0 * (1 - xt) * beta_bar_t) + \
572 | (x0 * xt * (1 - beta_bar_t))
573 |
574 | post = term_1 * term_2 / term_3
575 | # Due to small numerical instabilities (particularly at t = 0), clip
576 | # the probabilities
577 | post = torch.clamp(post, 0, 1)
578 | return xt, post
579 | return xt
580 |
581 | def reverse_step(self, xt, t, post):
582 | """
583 | Performs a reverse sampling step to compute x_{t-1} given xt and the
584 | posterior quantity (or an estimate of it) defined in `posterior`.
585 | Arguments:
586 | `xt`: a B x `self.input_shape` tensor containing the data at time t
587 | `t`: a B-tensor containing the time in the diffusion process for
588 | each input
589 | `post`: a B x `self.input_shape` tensor containing the posterior
590 | quantity (or a model-predicted estimate of it) as defined in
591 | `posterior`
592 | Returns a B x `self.input_shape` tensor for x_{t-1}.
593 | """
594 | # Ignore xt and t
595 | if self.reverse_reflect:
596 | xt_1 = xt.clone()
597 | indicators = torch.bernoulli(post)
598 | xt_1[indicators == 1] = 1 # Only set to 1 when a 1 is sampled
599 | return xt_1
600 | return torch.bernoulli(post)
601 |
602 | def sample_prior(self, num_samples, t):
603 | """
604 | Samples from the prior distribution specified by the diffusion process
605 | at time `t`.
606 | Arguments:
607 | `num_samples`: B, the number of samples to return
608 | `t`: a B-tensor containing the time in the diffusion process for
609 | each input
610 | Returns a B x `self.input_shape` tensor for the `xt` values that are
611 | sampled.
612 | """
613 | # Ignore t
614 | size = torch.Size([num_samples] + list(self.input_shape))
615 | return torch.zeros(size, device=DEVICE)
616 |
617 | def __str__(self):
618 | return self.string
619 |
620 |
621 | class BernoulliSkipDiffuser(BernoulliDiffuser):
622 |     # Diffuser which flips bits in the inputs, but whose "posterior" is
623 |     # directly the probability of the original bits (i.e. x0)
624 | def forward(self, x0, t, return_posterior=True):
625 | """
626 | Runs diffusion process forward given starting point `x0` and a time `t`.
627 | Optionally also returns a tensor which represents the posterior of
628 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
629 | Arguments:
630 |         `x0`: a B x `self.input_shape` tensor containing the data at
631 |             time 0
632 | `t`: a B-tensor containing the time in the diffusion process for
633 | each input
634 | Returns a B x `self.input_shape` tensor to represent xt. If
635 | `return_posterior` is True, then also returns a B x `self.input_shape`
636 | tensor which is a parameter of the posterior.
637 | """
638 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
639 | prob_flip = torch.tile(beta_bar_t, (1, *x0.shape[1:]))
640 | indicators = torch.bernoulli(prob_flip)
641 |
642 | # Perform flips
643 | xt = x0.clone()
644 | mask = indicators == 1
645 | xt[mask] = 1 - x0[mask]
646 |
647 | if return_posterior:
648 | return xt, x0 # Here, our "posterior" is just x0
649 | return xt
650 |
651 | def reverse_step(self, xt, t, post):
652 | """
653 | Performs a reverse sampling step to compute x_{t-1} given xt and the
654 | posterior quantity (or an estimate of it) defined in `posterior`.
655 | Arguments:
656 | `xt`: a B x `self.input_shape` tensor containing the data at time t
657 | `t`: a B-tensor containing the time in the diffusion process for
658 | each input
659 | `post`: a B x `self.input_shape` tensor containing the posterior
660 | quantity (or a model-predicted estimate of it) as defined in
661 | `posterior`
662 | Returns a B x `self.input_shape` tensor for x_{t-1}.
663 | """
664 | last_step_mask = None
665 | if not torch.all(t > 1):
666 | last_step_mask = t == 1
667 |
668 | x0 = torch.bernoulli(post)
669 | # In this case, what we're calling the "posterior" is just x0
670 |
671 | beta_t = self._inflate_dims(self._beta(t))
672 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
673 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1))
674 |
675 | term_1 = ((1 - xt) * beta_t) + (xt * (1 - beta_t))
676 | term_2 = ((1 - x0) * beta_bar_t_1) + (x0 * (1 - beta_bar_t_1))
677 | x0_xor_xt = torch.square(x0 - xt)
678 | term_3 = (x0_xor_xt * beta_bar_t) + ((1 - x0_xor_xt) * (1 - beta_bar_t))
679 |
680 | posterior = term_1 * term_2 / term_3 # p(x_{t-1} = 1 | xt, x0)
681 | # Due to small numerical instabilities (particularly at t = 0), clip
682 | # the probabilities
683 | posterior = torch.clamp(posterior, 0, 1)
684 | xt_1 = torch.bernoulli(posterior)
685 |
686 | if last_step_mask is not None:
687 | # For the last step, don't do this posterior calculation; just use
688 | # the x0 we are given
689 | xt_1[last_step_mask] = x0[last_step_mask]
690 | return xt_1
691 |
692 |
693 | class BernoulliOneSkipDiffuser(BernoulliOneDiffuser):
694 |     # Diffuser which sets bits to 1 in the inputs, but whose "posterior" is
695 |     # directly the probability of the original bits (i.e. x0)
696 | def forward(self, x0, t, return_posterior=True):
697 | """
698 | Runs diffusion process forward given starting point `x0` and a time `t`.
699 | Optionally also returns a tensor which represents the posterior of
700 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
701 | Arguments:
702 |         `x0`: a B x `self.input_shape` tensor containing the data at
703 |             time 0
704 | `t`: a B-tensor containing the time in the diffusion process for
705 | each input
706 | Returns a B x `self.input_shape` tensor to represent xt. If
707 | `return_posterior` is True, then also returns a B x `self.input_shape`
708 | tensor which is a parameter of the posterior.
709 | """
710 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
711 |
712 | prob_one = x0 + ((1 - x0) * beta_bar_t)
713 | indicators = torch.bernoulli(prob_one)
714 |
715 | # Perform sampling
716 | xt = x0.clone()
717 | mask = indicators == 1
718 | xt[mask] = 1 # Set to 1
719 |
720 | if return_posterior:
721 | return xt, x0 # Here, our "posterior" is just x0
722 | return xt
723 |
724 | def reverse_step(self, xt, t, post):
725 | """
726 | Performs a reverse sampling step to compute x_{t-1} given xt and the
727 | posterior quantity (or an estimate of it) defined in `posterior`.
728 | Arguments:
729 | `xt`: a B x `self.input_shape` tensor containing the data at time t
730 | `t`: a B-tensor containing the time in the diffusion process for
731 | each input
732 | `post`: a B x `self.input_shape` tensor containing the posterior
733 | quantity (or a model-predicted estimate of it) as defined in
734 | `posterior`
735 | Returns a B x `self.input_shape` tensor for x_{t-1}.
736 | """
737 | last_step_mask = None
738 | if not torch.all(t > 1):
739 | last_step_mask = t == 1
740 |
741 | x0 = torch.bernoulli(post)
742 | # In this case, what we're calling the "posterior" is just x0
743 |
744 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
745 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1))
746 |
747 | term_1 = xt
748 | term_2 = x0 + ((1 - x0) * beta_bar_t_1)
749 | term_3 = (x0 * xt) + ((1 - x0) * xt * beta_bar_t) + \
750 | ((1 - x0) * (1 - xt) * (1 - beta_bar_t))
751 |
752 | posterior = term_1 * term_2 / term_3 # p(x_{t-1} = 1 | xt, x0)
753 | # Due to small numerical instabilities (particularly at t = 0), clip
754 | # the probabilities
755 | posterior = torch.nan_to_num(posterior, 1)
756 | posterior = torch.clamp(posterior, 0, 1)
757 | xt_1 = torch.bernoulli(posterior)
758 |
759 | if last_step_mask is not None:
760 | # For the last step, don't do this posterior calculation; just use
761 | # the x0 we are given
762 | xt_1[last_step_mask] = x0[last_step_mask]
763 |
764 | if self.reverse_reflect:
765 | mask = (xt == 0) & (xt_1 == 1)
766 | xt_1[mask] = 0 # Disallow flipping to 1 from t to t-1
767 |
768 | return xt_1
769 |
770 |
771 | class BernoulliZeroSkipDiffuser(BernoulliZeroDiffuser):
772 |     # Diffuser which sets bits to 0 in the inputs, but whose "posterior" is
773 |     # directly the probability of the original bits (i.e. x0)
774 | def forward(self, x0, t, return_posterior=True):
775 | """
776 | Runs diffusion process forward given starting point `x0` and a time `t`.
777 | Optionally also returns a tensor which represents the posterior of
778 | x_{t-1} given xt and/or x0 (e.g. probability, mean, noise, etc.)
779 | Arguments:
780 |         `x0`: a B x `self.input_shape` tensor containing the data at
781 |             time 0
782 | `t`: a B-tensor containing the time in the diffusion process for
783 | each input
784 | Returns a B x `self.input_shape` tensor to represent xt. If
785 | `return_posterior` is True, then also returns a B x `self.input_shape`
786 | tensor which is a parameter of the posterior.
787 | """
788 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
789 |
790 | prob_one = x0 * (1 - beta_bar_t)
791 | indicators = torch.bernoulli(prob_one)
792 |
793 | # Perform sampling
794 | xt = x0.clone()
795 | mask = indicators == 0
796 | xt[mask] = 0 # Set to 0
797 |
798 | if return_posterior:
799 | return xt, x0 # Here, our "posterior" is just x0
800 | return xt
801 |
802 | def reverse_step(self, xt, t, post):
803 | """
804 | Performs a reverse sampling step to compute x_{t-1} given xt and the
805 | posterior quantity (or an estimate of it) defined in `posterior`.
806 | Arguments:
807 | `xt`: a B x `self.input_shape` tensor containing the data at time t
808 | `t`: a B-tensor containing the time in the diffusion process for
809 | each input
810 | `post`: a B x `self.input_shape` tensor containing the posterior
811 | quantity (or a model-predicted estimate of it) as defined in
812 | `posterior`
813 | Returns a B x `self.input_shape` tensor for x_{t-1}.
814 | """
815 | last_step_mask = None
816 | if not torch.all(t > 1):
817 | last_step_mask = t == 1
818 |
819 | x0 = torch.bernoulli(post)
820 | # In this case, what we're calling the "posterior" is just x0
821 |
822 | beta_t = self._inflate_dims(self._beta(t))
823 | beta_bar_t = self._inflate_dims(self._beta_bar(t))
824 | beta_bar_t_1 = self._inflate_dims(self._beta_bar(t - 1))
825 |
826 | term_1 = xt + beta_t - (2 * xt * beta_t)
827 | term_2 = x0 * (1 - beta_bar_t_1)
828 | term_3 = ((1 - x0) * (1 - xt)) + (x0 * (1 - xt) * beta_bar_t) + \
829 | (x0 * xt * (1 - beta_bar_t))
830 |
831 | posterior = term_1 * term_2 / term_3 # p(x_{t-1} = 1 | xt, x0)
832 | # Due to small numerical instabilities (particularly at t = 0), clip
833 | # the probabilities
834 | posterior = torch.nan_to_num(posterior, 1)
835 | posterior = torch.clamp(posterior, 0, 1)
836 | xt_1 = torch.bernoulli(posterior)
837 |
838 | if last_step_mask is not None:
839 | # For the last step, don't do this posterior calculation; just use
840 | # the x0 we are given
841 | xt_1[last_step_mask] = x0[last_step_mask]
842 |
843 | if self.reverse_reflect:
844 | mask = (xt == 1) & (xt_1 == 0)
845 | xt_1[mask] = 1 # Disallow flipping to 0 from t to t-1
846 |
847 | return xt_1
848 |
849 |
850 | if __name__ == "__main__":
851 | input_shape = (1, 28, 28)
852 | diffuser = GaussianDiffuser(1e-4, 1.2e-4, input_shape)
853 |     x0 = torch.rand((32,) + input_shape, device=DEVICE)
854 | t = torch.randint(1, 1001, (32,), device=DEVICE)
855 | xt, z = diffuser.forward(x0, t)
856 | xt1 = diffuser.reverse_step(xt, t, z)
857 | xT = diffuser.sample_prior(32, None)
858 |
--------------------------------------------------------------------------------
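Note on the posterior expressions above (each `term_1 * term_2 / term_3`): they
are per-bit Bayes' rule,

    p(x_{t-1} = 1 \mid x_t, x_0)
        = \frac{p(x_t \mid x_{t-1} = 1) \; p(x_{t-1} = 1 \mid x_0)}
               {p(x_t \mid x_0)},

with `term_1`, `term_2`, and `term_3` corresponding to the three factors in
order; beta(t) is the per-step flip probability and beta-bar(t) the cumulative
one. Correspondingly, `_beta_bar` evaluates the identity

    \bar{\beta}_t = 1 - \prod_{s=1}^{t} (1 - \beta_s)

in log-space for numerical stability.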
/src/model/generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | import feature.graph_conversions as graph_conversions
4 |
5 | # Define device
6 | if torch.cuda.is_available():
7 | DEVICE = "cuda"
8 | else:
9 | DEVICE = "cpu"
10 |
11 |
12 | def generate_samples(
13 | model, diffuser, num_samples=64, t_start=0, t_limit=1000,
14 | initial_samples=None, return_all_times=False, verbose=False
15 | ):
16 | """
17 | Generates samples from a trained posterior model and discrete diffuser. This
18 |     first generates a sample from the prior distribution at `t_limit`, then steps
19 | backward through time to generate new data points.
20 | Arguments:
21 | `model`: a trained model which takes in x, t and predicts a posterior
22 | `diffuser`: a DiscreteDiffuser object
23 | `num_samples`: number of objects to return
24 |         `t_start`: last time step to stop at (a non-negative integer smaller
25 |             than `t_limit`)
26 | `t_limit`: the time step to start generating at (a larger positive
27 | integer than `t_start`)
28 | `initial_samples`: if given, this is a tensor which contains the samples
29 | to start from initially, to be used instead of sampling from the
30 | diffuser's defined prior
31 | `return_all_times`: if True, instead of returning a tensor at `t_start`,
32 | return a larger tensor where the first dimension is
33 |             `t_limit - t_start + 1`, and a tensor of times; each entry is the
34 | reconstruction of the object for that time; the first entry will be
35 | the object at `t_limit`, and the last entry will be the object at
36 | `t_start`
37 | `verbose`: if True, print out progress bar
38 | Returns a tensor of size `num_samples` x .... If `return_all_times` is True,
39 | returns a tensor of size T x `num_samples` x ... of reconstructions and a
40 | tensor of size T for times.
41 | """
42 | # First, sample from the prior distribution at some late time t
43 | if initial_samples is not None:
44 | xt = initial_samples
45 | assert len(xt) == num_samples
46 | else:
47 | t = (torch.ones(num_samples) * t_limit).to(DEVICE)
48 | xt = diffuser.sample_prior(num_samples, t)
49 |
50 | if return_all_times:
51 | all_xt = torch.empty(
52 | (t_limit - t_start + 1,) + xt.shape,
53 | dtype=xt.dtype, device=xt.device
54 | )
55 | all_xt[0] = xt
56 | all_t = torch.arange(t_limit, t_start - 1, step=-1).to(DEVICE)
57 |
58 | # Disable gradient computation in model
59 | model.eval()
60 | torch.set_grad_enabled(False)
61 |
62 | time_steps = torch.arange(t_limit, t_start, step=-1).to(DEVICE)
63 | # (descending order)
64 |
65 | # Step backward through time starting at xt
66 | x = xt
67 | t_iter = tqdm.tqdm(enumerate(time_steps), total=len(time_steps)) if verbose \
68 | else enumerate(time_steps)
69 | for t_i, time_step in t_iter:
70 | t = torch.ones(num_samples).to(DEVICE) * time_step
71 | post = model(xt, t)
72 | xt = diffuser.reverse_step(xt, t, post)
73 |
74 | if return_all_times:
75 | all_xt[t_i + 1] = xt
76 |
77 | if return_all_times:
78 | return all_xt, all_t
79 | return xt
80 |
81 |
82 | def generate_graph_samples(
83 | model, diffuser, initial_samples, t_start=0, t_limit=1000,
84 | return_all_times=False, verbose=False
85 | ):
86 | """
87 |     Generates samples from a trained posterior model and discrete diffuser.
88 |     Starting from the given samples at `t_limit`, this steps backward through
89 |     time to generate new data points.
90 | Arguments:
91 | `model`: a trained model which takes in x, t and predicts a posterior on
92 | edges in canonical order
93 | `diffuser`: a DiscreteDiffuser object
94 | `initial_samples`: a torch-geometric Data object which contains the
95 | samples to start from initially, at `t_limit`
96 |         `t_start`: last time step to stop at (a non-negative integer smaller
97 |             than `t_limit`)
98 | `t_limit`: the time step to start generating at (a larger positive
99 | integer than `t_start`)
100 |         `return_all_times`: if True, instead of returning a single
101 |             torch-geometric Data object at `t_start`, return a list of Data
102 |             objects of length `t_limit - t_start + 1` and a parallel tensor
103 |             of times; each Data object is the reconstruction of the graph at
104 |             that time; the first entry will be the object at `t_limit`, and
105 |             the last entry will be the object at `t_start`
106 | `verbose`: if True, print out progress bar
107 | Returns a torch-geometric Data object. If `return_all_times` is True,
108 | returns a list of T torch-geometric Data objects and a T-tensor of times.
109 | """
110 | xt = initial_samples
111 |
112 | if return_all_times:
113 | all_xt = [xt]
114 | all_t = torch.arange(t_limit, t_start - 1, step=-1).to(DEVICE)
115 |
116 | # Disable gradient computation in model
117 | model.eval()
118 | torch.set_grad_enabled(False)
119 |
120 | time_steps = torch.arange(t_limit, t_start, step=-1).to(DEVICE)
121 | # (descending order)
122 |
123 | # Step backward through time starting at xt
124 | x = xt
125 | t_iter = tqdm.tqdm(enumerate(time_steps), total=len(time_steps)) if verbose \
126 | else enumerate(time_steps)
127 | for t_i, time_step in t_iter:
128 | edges = graph_conversions.pyg_data_to_edge_vector(xt)
129 |
130 | t_v = torch.tile(
131 | torch.tensor([time_step], device=DEVICE), (xt.x.shape[0],)
132 | ) # Shape: V
133 | t_e = torch.tile(
134 | torch.tensor([time_step], device=DEVICE), edges.shape
135 | ) # Shape: E
136 |
137 | post = model(xt, t_v)
138 | post_edges = diffuser.reverse_step(
139 | edges[:, None], t_e, post[:, None]
140 | )[:, 0] # Do everything on E x 1 tensors, and then squeeze back
141 |
142 | # Make copy of Data object
143 | xt = xt.clone()
144 |
145 | xt.edge_index = graph_conversions.edge_vector_to_pyg_data(
146 | xt, post_edges
147 | )
148 |
149 | if return_all_times:
150 | all_xt.append(xt)
151 |
152 | if return_all_times:
153 | return all_xt, all_t
154 | return xt
155 |
--------------------------------------------------------------------------------
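A minimal usage sketch of `generate_samples` (illustrative, not part of the
repository). The checkpoint path is hypothetical, and the diffuser construction
simply mirrors the smoke test at the bottom of discrete_diffusers.py; in
practice, the model/diffuser pairing must be the same one used during training.

import torch
import model.util as util
import model.generate as generate
from model.image_unet import MNISTProbUNetTimeAdd
from model.discrete_diffusers import GaussianDiffuser

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical checkpoint saved by train_model.py
model = util.load_model(MNISTProbUNetTimeAdd, "models/last_ckpt.pth").to(DEVICE)
diffuser = GaussianDiffuser(1e-4, 1.2e-4, (1, 28, 28))
samples = generate.generate_samples(
    model, diffuser, num_samples=16, t_limit=1000, verbose=True
)  # Shape: 16 x 1 x 28 x 28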
/src/model/gnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric
3 | import numpy as np
4 | import networkx as nx
5 | from model.util import sanitize_sacred_arguments
6 | import feature.graph_conversions as graph_conversions
7 |
8 | class GraphLinkGIN(torch.nn.Module):
9 |
10 | def __init__(
11 | self, input_dim, t_limit, num_gnn_layers=4, hidden_dim=10,
12 | time_embed_size=256, virtual_node=False
13 | ):
14 | """
15 | Initialize a time-dependent GNN which predicts bit probabilities for
16 | each edge.
17 | Arguments:
18 | `input_dim`: the dimension of the input node features
19 | `t_limit`: maximum time horizon
20 | `num_gnn_layers`: number of GNN layers to have
21 | `hidden_dim`: the dimension of the hidden node embeddings
22 | `time_embed_size`: size of the time embeddings
23 | `virtual_node`: if True, include a virtual node in the architecture
24 | """
25 | super().__init__()
26 |
27 | self.creation_args = locals()
28 | del self.creation_args["self"]
29 | del self.creation_args["__class__"]
30 | self.creation_args = sanitize_sacred_arguments(self.creation_args)
31 |
32 | self.t_limit = t_limit
33 | self.num_gnn_layers = num_gnn_layers
34 | self.virtual_node = virtual_node
35 |
36 | self.time_embed_dense = torch.nn.Linear(3, time_embed_size)
37 |
38 | self.swish = lambda x: x * torch.sigmoid(x)
39 | self.relu = torch.nn.ReLU()
40 |
41 | # GNN layers
42 | self.gnn_layers = torch.nn.ModuleList()
43 | self.gnn_batch_norms = torch.nn.ModuleList()
44 | for i in range(num_gnn_layers):
45 | gnn_nn = torch.nn.Sequential(
46 | torch.nn.Linear(
47 | input_dim + time_embed_size if i == 0 else hidden_dim,
48 | hidden_dim * 2
49 | ),
50 | self.relu,
51 | torch.nn.Linear(hidden_dim * 2, hidden_dim)
52 | )
53 | gnn_layer = torch_geometric.nn.GINConv(gnn_nn, train_eps=True)
54 | gnn_batch_norm = torch_geometric.nn.LayerNorm(hidden_dim)
55 |
56 | self.gnn_layers.append(gnn_layer)
57 | self.gnn_batch_norms.append(gnn_batch_norm)
58 |
59 | # Link prediction
60 | self.link_dense = torch.nn.Linear(hidden_dim, 1)
61 |
62 | # Loss
63 | self.bce_loss = torch.nn.BCELoss()
64 |
65 | def forward(self, data, t):
66 | """
67 | Forward pass of the network.
68 | Arguments:
69 | `data`: a (batched) torch-geometric Data object
70 | `t`: a V-tensor containing the time to train on for each node; note
71 | that the time should be the same for nodes belonging to the same
72 | individual graph
73 | Returns an E-tensor of probabilities of each edge at time t - 1, where E
74 | is the total possible number of edges, and is in canonical ordering.
75 | """
76 | if self.virtual_node:
77 | # Add a virtual node
78 | data_copy = data.clone() # Don't modify the original
79 | graph_conversions.add_virtual_nodes(data_copy) # Modify `data_copy`
80 | # Also extend `t` by the number of virtual nodes added (use 0,
81 | # since the node is fake)
82 | # TODO: it would be better to use the same time as its host graph
83 | t = torch.cat([
84 | t, torch.zeros(len(data.ptr) - 1, device=t.device)
85 | ])
86 | data = data_copy # Use our modified Data object
87 |
88 | # Get the time embeddings for `t`
89 | time_embed_args = t[:, None] / self.t_limit # Shape: V x 1
90 | time_embed = self.swish(self.time_embed_dense(
91 | torch.cat([
92 | torch.sin(time_embed_args * (np.pi / 2)),
93 | torch.cos(time_embed_args * (np.pi / 2)),
94 | time_embed_args
95 | ], dim=1)
96 | )) # Shape: V x Dt
97 |
98 | # Concatenate initial node features and time embedding
99 | node_embed = torch.cat([data.x.float(), time_embed], dim=1)
100 | # Shape: V x D
101 |
102 | # GNN layers
103 | for i in range(self.num_gnn_layers):
104 | node_embed = self.gnn_batch_norms[i](
105 | self.gnn_layers[i](node_embed, data.edge_index),
106 | data.batch
107 | )
108 |
109 | # For all possible edges (i.e. node pairs), compute probability
110 | edge_inds = graph_conversions.edge_vector_to_pyg_data(
111 | data, 1, reflect=False
112 | ) # Shape: 2 x E
113 | node_embed_1 = node_embed[edge_inds[0]] # Shape: E x D'
114 | node_embed_2 = node_embed[edge_inds[1]] # Shape: E x D'
115 | node_prod = node_embed_1 * node_embed_2
116 |
117 | edge_probs = torch.sigmoid(self.link_dense(node_prod))[:, 0]
118 |
119 | return edge_probs
120 |
121 | def loss(self, pred_probs, true_probs):
122 | """
123 | Computes the loss of a batch.
124 | Arguments:
125 | `pred_probs`: an E-tensor of predicted probabilities
126 | `true_probs`: an E-tensor of true probabilities
127 | Returns a scalar loss value.
128 | """
129 | return self.bce_loss(pred_probs, true_probs)
130 |
131 |
132 | class GraphAttentionLayer(torch.nn.Module):
133 |
134 | def __init__(self, input_dim, num_heads=8, att_hidden_dim=32, dropout=0.1):
135 | """
136 | Initialize a graph attention layer which computes attention between all
137 | nodes.
138 | Arguments:
139 | `input_dim`: the dimension of the input node features, D
140 | `num_heads`: number of attention heads
141 | `att_hidden_dim`: the dimension of the hidden node embeddings in the
142 | GAT
143 | `dropout`: dropout rate for post-GAT layers
144 | """
145 |
146 | super().__init__()
147 |
148 | self.input_dim = input_dim
149 | self.num_heads = num_heads
150 | self.att_hidden_dim = att_hidden_dim
151 |
152 | self.spectral_dense = torch.nn.Linear(input_dim, input_dim)
153 |
154 | self.gat = torch_geometric.nn.GATv2Conv(
155 | input_dim * 2, att_hidden_dim, heads=num_heads, edge_dim=1
156 | # edge_dim is 1 for presence/absence of edge
157 | )
158 |
159 | self.dropout_1 = torch.nn.Dropout(dropout)
160 | self.norm_1 = torch_geometric.nn.LayerNorm(
161 | att_hidden_dim * num_heads
162 | )
163 | self.dense_2 = torch.nn.Linear(
164 | att_hidden_dim * num_heads, input_dim
165 | )
166 | self.dropout_2 = torch.nn.Dropout(dropout)
167 | self.norm_2 = torch_geometric.nn.LayerNorm(input_dim)
168 |
169 | self.dense_3 = torch.nn.Linear(input_dim, input_dim)
170 | self.dropout_3 = torch.nn.Dropout(dropout)
171 | self.norm_3 = torch_geometric.nn.LayerNorm(input_dim)
172 |
173 | self.relu = torch.nn.ReLU()
174 |
175 | def forward(
176 | self, x, full_edge_index, edge_indicators, batch, spectrum_mats
177 | ):
178 | """
179 | Forward pass of the attention layer.
180 | Arguments:
181 | `x`: V x D tensor of node features
182 | `full_edge_index`: 2 x E tensor of edge indices for the attention
183 | mechanism, denoting all possible edges within each subgraph in
184 | the batch
185 |             `edge_indicators`: E x 1 tensor of edge features, denoting which
186 |                 edges actually exist in each subgraph (e.g. 1 if the edge
187 |                 exists, 0 otherwise)
188 | `batch`: V-tensor of which nodes belong to which batch
189 | `spectrum_mats`: list of B matrices, where each matrix is n x m,
190 | where n is the number of nodes in each graph and m is the
191 | number of eigenvectors (columns); this matrix transforms _out_
192 | of the spectral domain when left-multiplying node features; B is
193 | the number of graphs in the batch
194 | Returns a V x D tensor of updated node features.
195 | """
196 | # Perform spectral convolution
197 | specconv_out = x.clone()
198 | for i in range(len(spectrum_mats)):
199 | batch_mask = batch == i
200 | x_in = x[batch_mask] # Shape: n x d
201 | mat = spectrum_mats[i] # Shape: n x m
202 | # Transform into spectral domain
203 | x_spec = torch.matmul(torch.transpose(mat, 0, 1), x_in)
204 | # Shape: m x d
205 | # Perform convolution by combining channels across each "frequency"
206 | x_spec_out = self.spectral_dense(x_spec) # Shape: m x d
207 | # Transform back into feature domain
208 | x_out = torch.matmul(mat, x_spec_out) # Shape: n x d
209 | specconv_out[batch_mask] = x_out
210 |
211 | # Attention on x and output of spectral convolution
212 | gat_out = self.gat(
213 | torch.cat([x, specconv_out], dim=1), full_edge_index,
214 | edge_attr=edge_indicators
215 | )
216 | x_out_1 = self.norm_1(x + self.relu(self.dropout_1(gat_out)))
217 |
218 | # Post-attention dense layers
219 | x_out_2 = self.norm_2(x_out_1 + self.dropout_2(self.dense_2(x_out_1)))
220 | x_out_3 = self.norm_3(x_out_2 + self.dropout_3(self.dense_3(x_out_2)))
221 | return x_out_3
222 |
223 |
224 | class GraphLinkGAT(torch.nn.Module):
225 |
226 | def __init__(
227 | self, input_dim, t_limit, num_gnn_layers=4, gat_num_heads=8,
228 | gat_hidden_dim=32, hidden_dim=256, time_embed_size=256, spectrum_dim=5,
229 | epsilon=1e-6
230 | ):
231 | """
232 | Initialize a time-dependent GNN which predicts bit probabilities for
233 | each edge.
234 | Arguments:
235 | `input_dim`: the dimension of the input node features
236 | `t_limit`: maximum time horizon
237 | `num_gnn_layers`: number of GNN layers to have
238 | `gat_num_heads`: number of attention heads
239 | `gat_hidden_dim`: the dimension of the hidden node embeddings in the
240 | GAT
241 | `hidden_dim`: size of hidden dimension before and after attention
242 | layers
243 | `time_embed_size`: size of the time embeddings
244 | `spectrum_dim`: number of spectral features to use (i.e. the number
245 | of eigenvectors to use)
246 | `epsilon`: small number for numerical stability when computing graph
247 | Laplacian
248 | """
249 | super().__init__()
250 |
251 | self.creation_args = locals()
252 | del self.creation_args["self"]
253 | del self.creation_args["__class__"]
254 | self.creation_args = sanitize_sacred_arguments(self.creation_args)
255 |
256 | self.t_limit = t_limit
257 | self.num_gnn_layers = num_gnn_layers
258 | self.spectrum_dim = spectrum_dim
259 | self.epsilon = epsilon
260 |
261 | self.time_embed_dense = torch.nn.Linear(3, time_embed_size)
262 |
263 | self.swish = lambda x: x * torch.sigmoid(x)
264 | self.relu = torch.nn.ReLU()
265 |
266 | # Pre-GNN linear layers
267 | self.pregnn_dense_1 = torch.nn.Linear(
268 | input_dim + time_embed_size, hidden_dim
269 | )
270 | self.pregnn_dense_2 = torch.nn.Linear(hidden_dim, hidden_dim)
271 |
272 | # GNN layers
273 | self.gnn_layers = torch.nn.ModuleList()
274 | for i in range(num_gnn_layers):
275 | self.gnn_layers.append(GraphAttentionLayer(
276 | hidden_dim if i == 0 else gat_num_heads * gat_hidden_dim,
277 | num_heads=gat_num_heads, att_hidden_dim=gat_hidden_dim
278 | ))
279 |
280 |         # Post-GNN linear layers
281 | self.postgnn_dense_1 = torch.nn.Linear(
282 | gat_hidden_dim * gat_num_heads, hidden_dim
283 | )
284 | self.postgnn_dense_2 = torch.nn.Linear(hidden_dim, hidden_dim)
285 |
286 | # Link prediction
287 | self.link_dense = torch.nn.Linear(hidden_dim, 1)
288 |
289 | # Loss
290 | self.bce_loss = torch.nn.BCELoss()
291 |
292 | def forward(self, data, t):
293 | """
294 | Forward pass of the network.
295 | Arguments:
296 | `data`: a (batched) torch-geometric Data object
297 | `t`: a V-tensor containing the time to train on for each node; note
298 | that the time should be the same for nodes belonging to the same
299 | individual graph
300 | Returns an E-tensor of probabilities of each edge at time t - 1, where E
301 | is the total possible number of edges, and is in canonical ordering.
302 | """
303 | # Get the time embeddings for `t`
304 | time_embed_args = t[:, None] / self.t_limit # Shape: V x 1
305 | time_embed = self.swish(self.time_embed_dense(
306 | torch.cat([
307 | torch.sin(time_embed_args * (np.pi / 2)),
308 | torch.cos(time_embed_args * (np.pi / 2)),
309 | time_embed_args
310 | ], dim=1)
311 | )) # Shape: V x Dt
312 |
313 | # Concatenate initial node features and time embedding
314 | node_embed = torch.cat([data.x.float(), time_embed], dim=1)
315 | # Shape: V x D
316 |
317 | # Pre-GNN dense layers on node features
318 | node_embed = self.relu(self.pregnn_dense_1(node_embed))
319 | node_embed = self.relu(self.pregnn_dense_2(node_embed))
320 |
321 | # Create edge_index specifying the full dense subgraphs
322 | full_edge_index = graph_conversions.edge_vector_to_pyg_data(
323 | data, 1, reflect=False
324 | ) # Shape: 2 x E
325 |
326 |         # Create edge features denoting which edges actually exist
327 | edge_indicators = \
328 | graph_conversions.pyg_data_to_edge_vector(data)[:, None]
329 | # Shape: E x 1
330 |
331 | # Compute the Laplacian for each graph in the batch
332 | adj_mat = torch_geometric.utils.to_dense_adj(
333 | data.edge_index, data.batch
334 | )
335 | deg = torch.sum(adj_mat, dim=2)
336 | sqrt_deg = 1 / torch.sqrt(deg + self.epsilon)
337 | sqrt_deg_mat = torch.diag_embed(sqrt_deg)
338 |
339 | identity = torch.eye(adj_mat.size(1), device=adj_mat.device)[None]
340 | laplacian = identity - \
341 | torch.matmul(torch.matmul(sqrt_deg_mat, adj_mat), sqrt_deg_mat)
342 |
343 | # Compute spectrum transformation matrix (i.e. eigenvalues/eigenvectors)
344 | # This is done separately for each graph, as each graph may have a
345 | # different size
346 | spectrum_mats = []
347 | for i, graph_size in enumerate(torch.diff(data.ptr)):
348 | # We only compute the eigendecomposition on the graph-size-limited
349 | # Laplacian, since this function always sorts eigenvalues
350 | evals, evecs = torch.linalg.eigh(
351 | laplacian[i, :graph_size, :graph_size]
352 | )
353 | # Limit the eigenvectors to the smallest eigenvalues if needed
354 | if self.spectrum_dim < graph_size:
355 | evecs = evecs[:, :self.spectrum_dim]
356 | spectrum_mats.append(evecs)
357 |
358 | # GNN layers
359 | for i in range(self.num_gnn_layers):
360 | node_embed = self.gnn_layers[i](
361 | node_embed, full_edge_index, edge_indicators, data.batch,
362 | spectrum_mats
363 | )
364 |
365 | # Post-GNN dense layers on node features
366 | node_embed = self.relu(self.postgnn_dense_1(node_embed))
367 | node_embed = self.relu(self.postgnn_dense_2(node_embed))
368 |
369 | # For all possible edges (i.e. node pairs), compute probability
370 | node_embed_1 = node_embed[full_edge_index[0]] # Shape: E x D'
371 | node_embed_2 = node_embed[full_edge_index[1]] # Shape: E x D'
372 | node_prod = node_embed_1 * node_embed_2
373 |
374 | edge_probs = torch.sigmoid(self.link_dense(node_prod))[:, 0]
375 |
376 | return edge_probs
377 |
378 | def loss(self, pred_probs, true_probs):
379 | """
380 | Computes the loss of a batch.
381 | Arguments:
382 | `pred_probs`: an E-tensor of predicted probabilities
383 | `true_probs`: an E-tensor of true probabilities
384 | Returns a scalar loss value.
385 | """
386 | return self.bce_loss(pred_probs, true_probs)
387 |
--------------------------------------------------------------------------------
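A minimal sketch (not part of the repository) running GraphLinkGIN on a toy
batch; the 5-dimensional node features and the times here are arbitrary. The
output is one probability per possible node pair, in the canonical ordering
defined by feature.graph_conversions.

import torch
import torch_geometric
from model.gnn import GraphLinkGIN

graphs = [
    torch_geometric.data.Data(
        x=torch.ones(n, 5),
        edge_index=torch.tensor([[0, 1], [1, 0]]),  # one undirected edge
    )
    for n in (4, 6)
]
batch = torch_geometric.data.Batch.from_data_list(graphs)

model = GraphLinkGIN(input_dim=5, t_limit=1000)
t = torch.full((batch.num_nodes,), 500.0)  # same time for all nodes of a graph
edge_probs = model(batch, t)  # E-tensor, one entry per possible edge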
/src/model/image_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from model.util import sanitize_sacred_arguments
4 |
5 | class MNISTProbUNetTimeConcat(torch.nn.Module):
6 |
7 | def __init__(
8 | self, t_limit, enc_dims=[32, 64, 128, 256], dec_dims=[32, 64, 128],
9 | time_embed_size=32, data_channels=1
10 | ):
11 | """
12 | Initialize a time-dependent U-net for MNIST, where time embeddings are
13 | concatenated to image representations.
14 | Arguments:
15 | `t_limit`: maximum time horizon
16 | `enc_dims`: the number of channels in each encoding layer
17 | `dec_dims`: the number of channels in each decoding layer (given in
18 | reverse order of usage)
19 | `time_embed_size`: size of the time embeddings
20 | `data_channels`: number of channels in input image
21 | """
22 | super().__init__()
23 |
24 | assert len(enc_dims) == 4
25 | assert len(dec_dims) == 3
26 |
27 | self.creation_args = locals()
28 | del self.creation_args["self"]
29 | del self.creation_args["__class__"]
30 | self.creation_args = sanitize_sacred_arguments(self.creation_args)
31 |
32 | self.t_limit = t_limit
33 |
34 | # Encoders: receptive field increases and depth increases
35 | self.conv_e1 = torch.nn.Conv2d(
36 | data_channels + time_embed_size, enc_dims[0], kernel_size=3,
37 | stride=1, bias=False
38 | )
39 | self.time_dense_e1 = torch.nn.Linear(2, time_embed_size)
40 | self.norm_e1 = torch.nn.GroupNorm(4, num_channels=enc_dims[0])
41 | self.conv_e2 = torch.nn.Conv2d(
42 | enc_dims[0] + time_embed_size, enc_dims[1], kernel_size=3, stride=2,
43 | bias=False
44 | )
45 | self.time_dense_e2 = torch.nn.Linear(2, time_embed_size)
46 | self.norm_e2 = torch.nn.GroupNorm(32, num_channels=enc_dims[1])
47 | self.conv_e3 = torch.nn.Conv2d(
48 | enc_dims[1] + time_embed_size, enc_dims[2], kernel_size=3, stride=2,
49 | bias=False
50 | )
51 | self.time_dense_e3 = torch.nn.Linear(2, time_embed_size)
52 | self.norm_e3 = torch.nn.GroupNorm(32, num_channels=enc_dims[2])
53 | self.conv_e4 = torch.nn.Conv2d(
54 | enc_dims[2] + time_embed_size, enc_dims[3], kernel_size=3, stride=2,
55 | bias=False
56 | )
57 | self.time_dense_e4 = torch.nn.Linear(2, time_embed_size)
58 | self.norm_e4 = torch.nn.GroupNorm(32, num_channels=enc_dims[3])
59 |
60 | # Decoders: depth decreases
61 | self.conv_d4 = torch.nn.ConvTranspose2d(
62 | enc_dims[3] + time_embed_size, dec_dims[2], 3, stride=2, bias=False
63 | )
64 | self.time_dense_d4 = torch.nn.Linear(2, time_embed_size)
65 | self.norm_d4 = torch.nn.GroupNorm(32, num_channels=dec_dims[2])
66 | self.conv_d3 = torch.nn.ConvTranspose2d(
67 | dec_dims[2] + enc_dims[2] + time_embed_size, dec_dims[1], 3,
68 | stride=2, output_padding=1, bias=False
69 | )
70 | self.time_dense_d3 = torch.nn.Linear(2, time_embed_size)
71 | self.norm_d3 = torch.nn.GroupNorm(32, num_channels=dec_dims[1])
72 | self.conv_d2 = torch.nn.ConvTranspose2d(
73 | dec_dims[1] + enc_dims[1] + time_embed_size, dec_dims[0], 3,
74 | stride=2, output_padding=1, bias=False
75 | )
76 | self.time_dense_d2 = torch.nn.Linear(2, time_embed_size)
77 | self.norm_d2 = torch.nn.GroupNorm(32, num_channels=dec_dims[0])
78 | self.conv_d1 = torch.nn.ConvTranspose2d(
79 | dec_dims[0] + enc_dims[0], data_channels, 3, stride=1, bias=True
80 | )
81 |
82 | # Activation functions
83 | self.swish = lambda x: x * torch.sigmoid(x)
84 |
85 | # Loss
86 | self.bce_loss = torch.nn.BCELoss()
87 |
88 | def forward(self, xt, t):
89 | """
90 | Forward pass of the network.
91 | Arguments:
92 | `xt`: B x 1 x H x W tensor containing the images to train on
93 | `t`: B-tensor containing the times to train the network for each
94 | image
95 | Returns a B x 1 x H x W tensor which consists of the prediction.
96 | """
97 | # Get the time embeddings for `t`
98 | # We embed the time as cos((t/T) * (2pi)) and sin((t/T) * (2pi))
99 | time_embed_args = (t[:, None] / self.t_limit) * (2 * np.pi)
100 | # Shape: B x 1
101 | time_embed = self.swish(
102 | torch.cat([
103 | torch.sin(time_embed_args), torch.cos(time_embed_args)
104 | ], dim=1)
105 | )
106 | # Shape: B x 2
107 |
108 | # Encoding
109 | enc_1_out = self.swish(self.norm_e1(self.conv_e1(
110 | torch.cat([
111 | xt,
112 | torch.tile(
113 | self.time_dense_e1(time_embed)[:, :, None, None],
114 | (1, 1) + xt.shape[2:]
115 | )
116 | ], dim=1)
117 | )))
118 | enc_2_out = self.swish(self.norm_e2(self.conv_e2(
119 | torch.cat([
120 | enc_1_out,
121 | torch.tile(
122 | self.time_dense_e2(time_embed)[:, :, None, None],
123 | (1, 1) + enc_1_out.shape[2:]
124 | )
125 | ], dim=1)
126 | )))
127 | enc_3_out = self.swish(self.norm_e3(self.conv_e3(
128 | torch.cat([
129 | enc_2_out,
130 | torch.tile(
131 | self.time_dense_e3(time_embed)[:, :, None, None],
132 | (1, 1) + enc_2_out.shape[2:]
133 | )
134 | ], dim=1)
135 | )))
136 | enc_4_out = self.swish(self.norm_e4(self.conv_e4(
137 | torch.cat([
138 | enc_3_out,
139 | torch.tile(
140 | self.time_dense_e4(time_embed)[:, :, None, None],
141 | (1, 1) + enc_3_out.shape[2:]
142 | )
143 | ], dim=1)
144 | )))
145 |
146 | # Decoding
147 | dec_4_out = self.swish(self.norm_d4(self.conv_d4(
148 | torch.cat([
149 | enc_4_out,
150 | torch.tile(
151 | self.time_dense_d4(time_embed)[:, :, None, None],
152 | (1, 1) + enc_4_out.shape[2:]
153 | )
154 | ], dim=1)
155 | )))
156 | dec_3_out = self.swish(self.norm_d3(self.conv_d3(
157 | torch.cat([
158 | dec_4_out, enc_3_out,
159 | torch.tile(
160 | self.time_dense_d3(time_embed)[:, :, None, None],
161 | (1, 1) + dec_4_out.shape[2:]
162 | )
163 | ], dim=1)
164 | )))
165 | dec_2_out = self.swish(self.norm_d2(self.conv_d2(
166 | torch.cat([
167 | dec_3_out, enc_2_out,
168 | torch.tile(
169 | self.time_dense_d2(time_embed)[:, :, None, None],
170 | (1, 1) + dec_3_out.shape[2:]
171 | )
172 | ], dim=1)
173 | )))
174 | dec_1_out = self.conv_d1(torch.cat([dec_2_out, enc_1_out], dim=1))
175 | return torch.sigmoid(dec_1_out)
176 |
177 | def loss(self, pred_probs, true_probs):
178 | """
179 | Computes the loss of the neural network.
180 | Arguments:
181 | `pred_probs`: a B x 1 x H x W tensor of predicted probabilities
182 | `true_probs`: a B x 1 x H x W tensor of true probabilities
183 | Returns a scalar loss of binary cross-entropy values, averaged across
184 | all dimensions.
185 | """
186 | return self.bce_loss(pred_probs, true_probs)
187 |
188 |
189 | class MNISTProbUNetTimeAdd(torch.nn.Module):
190 |
191 | def __init__(
192 | self, t_limit, enc_dims=[32, 64, 128, 256], dec_dims=[32, 64, 128],
193 | time_embed_size=32, time_embed_std=30, use_time_embed_dense=False,
194 | data_channels=1
195 | ):
196 | """
197 | Initialize a time-dependent U-net for MNIST, where time embeddings are
198 | added to image representations.
199 | Arguments:
200 | `t_limit`: maximum time horizon
201 | `enc_dims`: the number of channels in each encoding layer
202 | `dec_dims`: the number of channels in each decoding layer (given in
203 | reverse order of usage)
204 | `time_embed_size`: size of the time embeddings
205 | `time_embed_std`: standard deviation of random weights to sample for
206 | time embeddings
207 | `use_time_embed_dense`: if True, have a dense layer on top of time
208 | embeddings initially
209 | `data_channels`: number of channels in input image
210 | """
211 | super().__init__()
212 |
213 | assert len(enc_dims) == 4
214 | assert len(dec_dims) == 3
215 | assert time_embed_size % 2 == 0
216 |
217 | self.creation_args = locals()
218 | del self.creation_args["self"]
219 | del self.creation_args["__class__"]
220 | self.creation_args = sanitize_sacred_arguments(self.creation_args)
221 |
222 | self.t_limit = t_limit
223 | self.use_time_embed_dense = use_time_embed_dense
224 |
225 | # Random embedding layer for time; the random weights are set at the
226 | # start and are not trainable
227 | self.time_embed_rand_weights = torch.nn.Parameter(
228 | torch.randn(time_embed_size // 2) * time_embed_std,
229 | requires_grad=False
230 | )
231 | if use_time_embed_dense:
232 | self.time_embed_dense = torch.nn.Linear(
233 | time_embed_size, time_embed_size
234 | )
235 |
236 | # Encoders: receptive field increases and depth increases
237 | self.conv_e1 = torch.nn.Conv2d(
238 | data_channels, enc_dims[0], kernel_size=3, stride=1, bias=False
239 | )
240 | self.time_dense_e1 = torch.nn.Linear(time_embed_size, enc_dims[0])
241 | self.norm_e1 = torch.nn.GroupNorm(4, num_channels=enc_dims[0])
242 | self.conv_e2 = torch.nn.Conv2d(
243 | enc_dims[0], enc_dims[1], kernel_size=3, stride=2, bias=False
244 | )
245 | self.time_dense_e2 = torch.nn.Linear(time_embed_size, enc_dims[1])
246 | self.norm_e2 = torch.nn.GroupNorm(32, num_channels=enc_dims[1])
247 | self.conv_e3 = torch.nn.Conv2d(
248 | enc_dims[1], enc_dims[2], kernel_size=3, stride=2, bias=False
249 | )
250 | self.time_dense_e3 = torch.nn.Linear(time_embed_size, enc_dims[2])
251 | self.norm_e3 = torch.nn.GroupNorm(32, num_channels=enc_dims[2])
252 | self.conv_e4 = torch.nn.Conv2d(
253 | enc_dims[2], enc_dims[3], kernel_size=3, stride=2, bias=False
254 | )
255 | self.time_dense_e4 = torch.nn.Linear(time_embed_size, enc_dims[3])
256 | self.norm_e4 = torch.nn.GroupNorm(32, num_channels=enc_dims[3])
257 |
258 | # Decoders: depth decreases
259 | self.conv_d4 = torch.nn.ConvTranspose2d(
260 | enc_dims[3], dec_dims[2], 3, stride=2, bias=False
261 | )
262 | self.time_dense_d4 = torch.nn.Linear(time_embed_size, dec_dims[2])
263 | self.norm_d4 = torch.nn.GroupNorm(32, num_channels=dec_dims[2])
264 | self.conv_d3 = torch.nn.ConvTranspose2d(
265 | dec_dims[2] + enc_dims[2], dec_dims[1], 3, stride=2,
266 | output_padding=1, bias=False
267 | )
268 | self.time_dense_d3 = torch.nn.Linear(time_embed_size, dec_dims[1])
269 | self.norm_d3 = torch.nn.GroupNorm(32, num_channels=dec_dims[1])
270 | self.conv_d2 = torch.nn.ConvTranspose2d(
271 | dec_dims[1] + enc_dims[1], dec_dims[0], 3, stride=2,
272 | output_padding=1, bias=False
273 | )
274 | self.time_dense_d2 = torch.nn.Linear(time_embed_size, dec_dims[0])
275 | self.norm_d2 = torch.nn.GroupNorm(32, num_channels=dec_dims[0])
276 | self.conv_d1 = torch.nn.ConvTranspose2d(
277 | dec_dims[0] + enc_dims[0], data_channels, 3, stride=1, bias=True
278 | )
279 |
280 | # Activation functions
281 | self.swish = lambda x: x * torch.sigmoid(x)
282 |
283 | # Loss
284 | self.bce_loss = torch.nn.BCELoss()
285 |
286 | def forward(self, xt, t):
287 | """
288 | Forward pass of the network.
289 | Arguments:
290 | `xt`: B x 1 x H x W tensor containing the images to train on
291 | `t`: B-tensor containing the times to train the network for each
292 | image
293 | Returns a B x 1 x H x W tensor which consists of the prediction.
294 | """
295 | # Get the time embeddings for `t`
296 |         # z was sampled at init from a zero-mean Gaussian of fixed variance;
297 |         # we embed the time as cos((t/T) * (2pi) * z) and sin((t/T) * (2pi) * z)
298 | time_embed_args = (t[:, None] / self.t_limit) * \
299 | self.time_embed_rand_weights[None, :] * (2 * np.pi)
300 | # Shape: B x (E / 2)
301 |
302 | time_embed = torch.cat([
303 | torch.sin(time_embed_args), torch.cos(time_embed_args)
304 | ], dim=1)
305 | if self.use_time_embed_dense:
306 | time_embed = self.swish(self.time_embed_dense(time_embed))
307 | else:
308 | time_embed = self.swish(time_embed)
309 | # Shape: B x E
310 |
311 | # Encoding
312 | enc_1_out = self.swish(self.norm_e1(
313 | self.conv_e1(xt) +
314 | self.time_dense_e1(time_embed)[:, :, None, None]
315 | ))
316 | enc_2_out = self.swish(self.norm_e2(
317 | self.conv_e2(enc_1_out) +
318 | self.time_dense_e2(time_embed)[:, :, None, None]
319 | ))
320 | enc_3_out = self.swish(self.norm_e3(
321 | self.conv_e3(enc_2_out) +
322 | self.time_dense_e3(time_embed)[:, :, None, None]
323 | ))
324 | enc_4_out = self.swish(self.norm_e4(
325 | self.conv_e4(enc_3_out) +
326 | self.time_dense_e4(time_embed)[:, :, None, None]
327 | ))
328 |
329 | # Decoding
330 | dec_4_out = self.swish(self.norm_d4(
331 | self.conv_d4(enc_4_out) +
332 | self.time_dense_d4(time_embed)[:, :, None, None]
333 | ))
334 | dec_3_out = self.swish(self.norm_d3(
335 | self.conv_d3(torch.cat([dec_4_out, enc_3_out], dim=1)) +
336 | self.time_dense_d3(time_embed)[:, :, None, None]
337 | ))
338 | dec_2_out = self.swish(self.norm_d2(
339 | self.conv_d2(torch.cat([dec_3_out, enc_2_out], dim=1)) +
340 | self.time_dense_d2(time_embed)[:, :, None, None]
341 | ))
342 | dec_1_out = self.conv_d1(torch.cat([dec_2_out, enc_1_out], dim=1))
343 | return torch.sigmoid(dec_1_out)
344 |
345 | def loss(self, pred_probs, true_probs):
346 | """
347 | Computes the loss of the neural network.
348 | Arguments:
349 | `pred_probs`: a B x 1 x H x W tensor of predicted probabilities
350 | `true_probs`: a B x 1 x H x W tensor of true probabilities
351 | Returns a scalar loss of binary cross-entropy values, averaged across
352 | all dimensions.
353 | """
354 | return self.bce_loss(pred_probs, true_probs)
355 |
--------------------------------------------------------------------------------
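A minimal shape check for the U-nets above (illustrative, not part of the
repository). The encoder/decoder strides are sized for 28 x 28 inputs, which is
what the skip connections assume.

import torch
from model.image_unet import MNISTProbUNetTimeAdd

model = MNISTProbUNetTimeAdd(t_limit=1000)
xt = torch.rand(8, 1, 28, 28)  # in training these would be binarized images
t = torch.randint(1, 1001, (8,)).float()
pred = model(xt, t)  # B x 1 x H x W of probabilities in (0, 1)
assert pred.shape == xt.shape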
/src/model/train_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import tqdm
4 | import os
5 | import sacred
6 | import model.util as util
7 | import feature.graph_conversions as graph_conversions
8 | import model.generate as generate
9 | import analysis.graph_metrics as graph_metrics
10 | import analysis.mmd as mmd
11 |
12 |
13 | MODEL_DIR = os.environ.get(
14 | "MODEL_DIR",
15 | "/gstore/data/resbioai/tsenga5/branched_diffusion/models/trained_models/misc"
16 | )
17 |
18 | train_ex = sacred.Experiment("train")
19 |
20 | train_ex.observers.append(
21 | sacred.observers.FileStorageObserver.create(MODEL_DIR)
22 | )
23 |
24 | # Define device
25 | if torch.cuda.is_available():
26 | DEVICE = "cuda"
27 | else:
28 | DEVICE = "cpu"
29 |
30 |
31 | @train_ex.config
32 | def config():
33 | # Number of training epochs
34 | num_epochs = 30
35 |
36 | # Learning rate
37 | learning_rate = 0.001
38 |
39 |
40 | @train_ex.command
41 | def train_model(
42 | model, diffuser, data_loader, num_epochs, learning_rate, _run, t_limit=1000
43 | ):
44 | """
45 | Trains a diffusion model using the given instantiated model and discrete
46 | diffuser object.
47 | Arguments:
48 | `model`: an instantiated model which takes in x, t and predicts a
49 | posterior
50 | `diffuser`: a DiscreteDiffuser object
51 | `data_loader`: a DataLoader object that yields batches of data as
52 | tensors in pairs: x, y
53 |         `_run`: the Sacred run object for this experiment (injected
54 |             automatically by Sacred)
55 | `num_epochs`: number of epochs to train for
56 | `learning_rate`: learning rate to use for training
57 | `t_limit`: training will occur between time 1 and `t_limit`
58 | """
59 | run_num = _run._id
60 | output_dir = os.path.join(MODEL_DIR, str(run_num))
61 |
62 | model.train()
63 | torch.set_grad_enabled(True)
64 | optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
65 |
66 | for epoch_num in range(num_epochs):
67 | batch_losses = []
68 | t_iter = tqdm.tqdm(data_loader)
69 | for x0, y in t_iter:
70 | x0 = x0.to(DEVICE).float()
71 |
72 | # Sample random times between 1 and t_limit (inclusive)
73 | t = torch.randint(
74 | t_limit, size=(x0.shape[0],), device=DEVICE
75 | ) + 1
76 |
77 | # Run diffusion forward to get xt and the posterior parameter to
78 | # predict
79 | xt, true_post = diffuser.forward(x0, t)
80 |
81 | # Get model-predicted posterior parameter
82 | pred_post = model(xt, t)
83 |
84 | # Compute loss
85 | loss = model.loss(pred_post, true_post)
86 | loss_val = loss.item()
87 | t_iter.set_description("Loss: %.4f" % loss_val)
88 |
89 | if np.isnan(loss_val):
90 | continue
91 |
92 | optim.zero_grad()
93 | loss.backward()
94 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
95 | optim.step()
96 |
97 | batch_losses.append(loss_val)
98 |
99 | epoch_loss = np.mean(batch_losses)
100 | print("Epoch %d average Loss: %.4f" % (epoch_num + 1, epoch_loss))
101 |
102 | _run.log_scalar("train_epoch_loss", epoch_loss)
103 | _run.log_scalar("train_batch_losses", batch_losses)
104 |
105 | model_path = os.path.join(
106 | output_dir, "epoch_%d_ckpt.pth" % (epoch_num + 1)
107 | )
108 | link_path = os.path.join(output_dir, "last_ckpt.pth")
109 |
110 | # Save model
111 | util.save_model(model, model_path)
112 |
113 | # Create symlink to last epoch
114 | if os.path.islink(link_path):
115 | os.remove(link_path)
116 | os.symlink(os.path.basename(model_path), link_path)
117 |
118 |
119 | @train_ex.command
120 | def train_graph_model(
121 | model, diffuser, data_loader, num_epochs, learning_rate, _run, t_limit=1000,
122 | compute_mmd=False, val_data_loader=None, mmd_sample_size=200
123 | ):
124 | """
125 | Trains a diffusion model on graphs using the given instantiated model and
126 | discrete diffuser object.
127 | Arguments:
128 | `model`: an instantiated model which takes in x, t and predicts a
129 | posterior on edges in canonical order
130 | `diffuser`: a DiscreteDiffuser object
131 | `data_loader`: a DataLoader object that yields torch-geometric Data
132 | objects
133 |         `_run`: the Sacred run object for this experiment (injected
134 |             automatically by Sacred)
135 | `num_epochs`: number of epochs to train for
136 | `learning_rate`: learning rate to use for training
137 | `t_limit`: training will occur between time 1 and `t_limit`
138 | `compute_mmd`: if True, compute some performance metrics at the end of
139 | training
140 | `val_data_loader`: if `compute_mmd` is True, this must be another data
141 | loader (like `data_loader`) which yields validation-set objects
142 | `mmd_sample_size`: number of graphs to compute MMD over
143 | """
144 | run_num = _run._id
145 | output_dir = os.path.join(MODEL_DIR, str(run_num))
146 |
147 | model.train()
148 | torch.set_grad_enabled(True)
149 | optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
150 |
151 | for epoch_num in range(num_epochs):
152 | batch_losses = []
153 | t_iter = tqdm.tqdm(data_loader)
154 | for data in t_iter:
155 |
156 | e0, edge_batch_inds = graph_conversions.pyg_data_to_edge_vector(
157 | data, return_batch_inds=True
158 | ) # Shape: E
159 |
160 | # Pick some random times t between 1 and t_limit (inclusive), one
161 | # value for each individual graph
162 | graph_sizes = torch.diff(data.ptr)
163 | graph_times = torch.randint(
164 | t_limit, size=(graph_sizes.shape[0],), device=DEVICE
165 | ) + 1
166 |
167 | # Tile the graph times to the size of all nodes
168 | t_v = graph_times[data.batch].float() # Shape: V
169 | t_e = graph_times[edge_batch_inds].float() # Shape: E
170 |
171 | # Add noise to edges from time 0 to time t
172 | et, true_post = diffuser.forward(e0[:, None], t_e)
173 | # Do the noising on E x 1 tensors
174 | et, true_post = et[:, 0], true_post[:, 0]
175 | data.edge_index = graph_conversions.edge_vector_to_pyg_data(
176 | data, et
177 | )
178 | # Note: this modifies `data`
179 |
180 | # Get model-predicted posterior parameter
181 | pred_post = model(data, t_v)
182 |
183 | # Compute loss
184 | loss = model.loss(pred_post, true_post)
185 | loss_val = loss.item()
186 | t_iter.set_description("Loss: %.4f" % loss_val)
187 |
188 | if np.isnan(loss_val):
189 | continue
190 |
191 | optim.zero_grad()
192 | loss.backward()
193 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
194 | optim.step()
195 |
196 | batch_losses.append(loss_val)
197 |
198 | epoch_loss = np.mean(batch_losses)
199 | print("Epoch %d average Loss: %.4f" % (epoch_num + 1, epoch_loss))
200 |
201 | _run.log_scalar("train_epoch_loss", epoch_loss)
202 | _run.log_scalar("train_batch_losses", batch_losses)
203 |
204 | model_path = os.path.join(
205 | output_dir, "epoch_%d_ckpt.pth" % (epoch_num + 1)
206 | )
207 | link_path = os.path.join(output_dir, "last_ckpt.pth")
208 |
209 | # Save model
210 | util.save_model(model, model_path)
211 |
212 | # Create symlink to last epoch
213 | if os.path.islink(link_path):
214 | os.remove(link_path)
215 | os.symlink(os.path.basename(model_path), link_path)
216 |
217 | # If required, compute MMD metrics
218 | if compute_mmd:
219 | num_batches = int(np.ceil(mmd_sample_size / data_loader.batch_size))
220 | train_graphs_1, train_graphs_2 = [], []
221 | gen_graphs = []
222 | data_iter_1, data_iter_2 = iter(data_loader), iter(val_data_loader)
223 | print("Generating %d graphs over %d batches" % (
224 | mmd_sample_size, num_batches)
225 | )
226 | for i in range(num_batches):
227 | print("Batch %d/%d" % (i + 1, num_batches))
228 | data = next(data_iter_1)
229 | train_graphs_1.extend(
230 | graph_conversions.split_pyg_data_to_nx_graphs(data)
231 | )
232 | data = next(data_iter_2)
233 | train_graphs_2.extend(
234 | graph_conversions.split_pyg_data_to_nx_graphs(data)
235 | )
236 | edges = graph_conversions.pyg_data_to_edge_vector(data)
237 | sampled_edges = diffuser.sample_prior(
238 | edges.shape[0], # Samples will be E x 1
239 | torch.tile(torch.tensor([t_limit], device=DEVICE), edges.shape)
240 | )[:, 0] # Shape: E
241 | data.edge_index = graph_conversions.edge_vector_to_pyg_data(
242 | data, sampled_edges
243 | )
244 |
245 | samples = generate.generate_graph_samples(
246 | model, diffuser, data, t_limit=t_limit, verbose=True
247 | )
248 | gen_graphs.extend(
249 | graph_conversions.split_pyg_data_to_nx_graphs(samples)
250 | )
251 |
252 | train_graphs_1 = train_graphs_1[:mmd_sample_size]
253 | train_graphs_2 = train_graphs_2[:mmd_sample_size]
254 | gen_graphs = gen_graphs[:mmd_sample_size]
255 | assert len(train_graphs_1) == mmd_sample_size
256 | assert len(train_graphs_2) == mmd_sample_size
257 | assert len(gen_graphs) == mmd_sample_size
258 | all_graphs = train_graphs_1 + train_graphs_2 + gen_graphs
259 |
260 | # Compute MMD values
261 | print("MMD (squared) values:")
262 | square_func = np.square
263 | kernel_type = "gaussian_total_variation"
264 |
265 | degree_hists = mmd.make_histograms(
266 | graph_metrics.get_degrees(all_graphs), bin_width=1
267 | )
268 | degree_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy(
269 | degree_hists[:mmd_sample_size], degree_hists[-mmd_sample_size:],
270 | kernel_type, sigma=1
271 | ))
272 | degree_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy(
273 | degree_hists[:mmd_sample_size],
274 | degree_hists[mmd_sample_size:-mmd_sample_size],
275 | kernel_type, sigma=1
276 | ))
277 | _run.log_scalar("degree_mmd", degree_mmd_1)
278 | _run.log_scalar("degree_mmd_baseline", degree_mmd_2)
279 | print("Degree MMD ratio: %.8f/%.8f = %.8f" % (
280 | degree_mmd_1, degree_mmd_2, degree_mmd_1 / degree_mmd_2
281 | ))
282 |
283 | cluster_coef_hists = mmd.make_histograms(
284 | graph_metrics.get_clustering_coefficients(all_graphs), num_bins=100
285 | )
286 | cluster_coef_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy(
287 | cluster_coef_hists[:mmd_sample_size],
288 | cluster_coef_hists[-mmd_sample_size:],
289 | kernel_type, sigma=0.1
290 | ))
291 | cluster_coef_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy(
292 | cluster_coef_hists[:mmd_sample_size],
293 | cluster_coef_hists[mmd_sample_size:-mmd_sample_size],
294 | kernel_type, sigma=0.1
295 | ))
296 | _run.log_scalar("cluster_coef_mmd", cluster_coef_mmd_1)
297 | _run.log_scalar("cluster_coef_mmd_baseline", cluster_coef_mmd_2)
298 | print("Clustering coefficient MMD ratio: %.8f/%.8f = %.8f" % (
299 | cluster_coef_mmd_1, cluster_coef_mmd_2,
300 | cluster_coef_mmd_1 / cluster_coef_mmd_2
301 | ))
302 |
303 | spectra_hists = mmd.make_histograms(
304 | graph_metrics.get_spectra(all_graphs),
305 | bin_array=np.linspace(-1e-5, 2, 200 + 1)
306 | )
307 | spectra_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy(
308 | spectra_hists[:mmd_sample_size], spectra_hists[-mmd_sample_size:],
309 | kernel_type, sigma=1
310 | ))
311 | spectra_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy(
312 | spectra_hists[:mmd_sample_size],
313 | spectra_hists[mmd_sample_size:-mmd_sample_size],
314 | kernel_type, sigma=1
315 | ))
316 | _run.log_scalar("spectra_mmd", spectra_mmd_1)
317 | _run.log_scalar("spectra_mmd_baseline", spectra_mmd_2)
318 | print("Spectrum MMD ratio: %.8f/%.8f = %.8f" % (
319 | spectra_mmd_1, spectra_mmd_2, spectra_mmd_1 / spectra_mmd_2
320 | ))
321 |
322 | orbit_counts = graph_metrics.get_orbit_counts(all_graphs)
323 | orbit_counts = np.stack([
324 | np.mean(counts, axis=0) for counts in orbit_counts
325 | ])
326 | orbit_mmd_1 = square_func(mmd.compute_maximum_mean_discrepancy(
327 | orbit_counts[:mmd_sample_size], orbit_counts[-mmd_sample_size:],
328 | kernel_type, normalize=False, sigma=30
329 | ))
330 | orbit_mmd_2 = square_func(mmd.compute_maximum_mean_discrepancy(
331 | orbit_counts[:mmd_sample_size],
332 | orbit_counts[mmd_sample_size:-mmd_sample_size],
333 | kernel_type, normalize=False, sigma=30
334 | ))
335 | _run.log_scalar("orbit_mmd", orbit_mmd_1)
336 | _run.log_scalar("orbit_mmd_baseline", orbit_mmd_2)
337 | print("Orbit MMD ratio: %.8f/%.8f = %.8f" % (
338 | orbit_mmd_1, orbit_mmd_2, orbit_mmd_1 / orbit_mmd_2
339 | ))
340 |
--------------------------------------------------------------------------------
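The MMD block above always concatenates three equal-sized groups as
`train_1 + train_2 + gen` and slices them back out by position. A small
illustrative sketch of that slicing convention (not part of the repository):

S = 4  # stands in for mmd_sample_size
all_items = (["train1"] * S) + (["train2"] * S) + (["gen"] * S)
assert all_items[:S] == ["train1"] * S    # training graphs
assert all_items[S:-S] == ["train2"] * S  # validation graphs (baseline)
assert all_items[-S:] == ["gen"] * S      # generated graphs

Each logged `*_mmd` value therefore compares generated graphs against training
graphs, each `*_mmd_baseline` compares validation graphs against training
graphs, and a ratio near 1 means the generated graphs are roughly as close to
the training distribution as held-out data is.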
/src/model/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def sanitize_sacred_arguments(args):
4 | """
5 | This function goes through and sanitizes the arguments to native types.
6 | Lists and dictionaries passed through Sacred automatically become
7 | ReadOnlyLists and ReadOnlyDicts. This function will go through and
8 | recursively change them to native lists and dicts.
9 | `args` can be a single token, a list of items, or a dictionary of items.
10 | The return type will be a native token, list, or dictionary.
11 | """
12 | if isinstance(args, list): # Captures ReadOnlyLists
13 | return [
14 | sanitize_sacred_arguments(item) for item in args
15 | ]
16 | elif isinstance(args, dict): # Captures ReadOnlyDicts
17 | return {
18 | str(key) : sanitize_sacred_arguments(val) \
19 | for key, val in args.items()
20 | }
21 | else: # Return the single token as-is
22 | return args
23 |
24 |
25 | def save_model(model, save_path):
26 | """
27 | Saves the given model at the given path. This saves the state of the model
28 | (i.e. trained layers and parameters), and the arguments used to create the
29 | model (i.e. a dictionary of the original arguments).
30 | """
31 | save_dict = {
32 | "model_state": model.state_dict(),
33 | "model_creation_args": model.creation_args
34 | }
35 | torch.save(save_dict, save_path)
36 |
37 |
38 | def load_model(model_class, load_path):
39 | """
40 | Restores a model from the given path. `model_class` must be the class for
41 | which the saved model was created from. This will create a model of this
42 | class, using the loaded creation arguments. It will then restore the learned
43 | parameters to the model.
44 | """
45 | load_dict = torch.load(load_path)
46 | model_state = load_dict["model_state"]
47 | model_creation_args = load_dict["model_creation_args"]
48 | model = model_class(**model_creation_args)
49 | model.load_state_dict(model_state)
50 | return model
51 |
--------------------------------------------------------------------------------
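A minimal save/load round trip (illustrative, not part of the repository); the
/tmp path is hypothetical. The models shown above all record
`self.creation_args` at construction time, which is what makes these helpers
work.

import torch
from model.image_unet import MNISTProbUNetTimeAdd
from model.util import save_model, load_model

model = MNISTProbUNetTimeAdd(t_limit=1000)
save_model(model, "/tmp/unet_ckpt.pth")
restored = load_model(MNISTProbUNetTimeAdd, "/tmp/unet_ckpt.pth")
for p1, p2 in zip(model.state_dict().values(), restored.state_dict().values()):
    assert torch.equal(p1, p2)  # restored weights match exactly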
/src/plot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Genentech/GraphGUIDE/dad0dd371268684a5203839441febe0484d8a3e4/src/plot/__init__.py
--------------------------------------------------------------------------------
/src/plot/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | def plot_mnist_digits(
5 | digits, grid_size=(10, 1), scale=1, clip=True, title=None
6 | ):
7 | """
8 | Plots MNIST digits. No normalization will be done.
9 | Arguments:
10 | `digits`: a B x 1 x 28 x 28 NumPy array of numbers between
11 | 0 and 1
12 | `grid_size`: a pair of integers denoting the number of digits
13 | to plot horizontally and vertically (in that order); if
14 | more digits are provided than spaces in the grid, leftover
15 | digits will not be plotted; if fewer digits are provided
16 |             than spaces in the grid, only complete rows will be
17 |             plotted
18 | `scale`: amount to scale figure size by
19 | `clip`: if True, clip values to between 0 and 1
20 | `title`: if given, title for the plot
21 | """
22 | digits = np.transpose(digits, (0, 2, 3, 1))
23 | if clip:
24 | digits = np.clip(digits, 0, 1)
25 |
26 | width, height = grid_size
27 | num_digits = len(digits)
28 | height = min(height, num_digits // width)
29 |
30 | figsize = (width * scale, (height * scale) + 0.5)
31 |
32 | fig, ax = plt.subplots(
33 | ncols=width, nrows=height,
34 | figsize=figsize, gridspec_kw={"wspace": 0, "hspace": 0}
35 | )
36 | if height == 1:
37 | ax = [ax]
38 | if width == 1:
39 | ax = [[a] for a in ax]
40 |
41 | for j in range(height):
42 | for i in range(width):
43 | index = i + (width * j)
44 | ax[j][i].imshow(digits[index], cmap="gray", aspect="auto")
45 | ax[j][i].axis("off")
46 | if title is not None:
47 | ax[0][0].set_title(title)
48 | plt.subplots_adjust(bottom=0.25)
49 | plt.show()
50 |
--------------------------------------------------------------------------------
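A minimal usage sketch (not part of the repository): plotting 20 random
"digits" in a 10 x 2 grid.

import numpy as np
from plot.plot import plot_mnist_digits

digits = np.random.rand(20, 1, 28, 28)  # stand-in for generated samples
plot_mnist_digits(digits, grid_size=(10, 2), scale=1.5, title="Samples")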